#include <cctype> #include <cstdlib> #include <string> #include <sstream> #include "posint.hpp" using namespace std; /******************** BASE ********************/ int PosInt::B = 0x8000; int PosInt::Bbase = 2; int PosInt::Bpow = 15; void PosInt::setBase(int base, int pow) { Bbase = base; Bpow = pow; B = base; while (pow > 1) { B *= Bbase; --pow; } } /******************** I/O ********************/ void PosInt::read (const char* s) { string str(s); istringstream sin (str); read(sin); } void PosInt::set(int x) { digits.clear(); if (x < 0) throw "Can't set PosInt to negative value"; while (x > 0) { digits.push_back(x % B); x /= B; } } void PosInt::set (const PosInt& rhs) { if (this != &rhs) digits.assign (rhs.digits.begin(), rhs.digits.end()); } void PosInt::print_array(ostream& out) const { out << "[ls"; for (int i=0; i<digits.size(); ++i) out << ' ' << digits[i]; out << " ms]"; } void PosInt::print(ostream& out) const { if (digits.empty()) out << 0; else { int i = digits.size()-1; int pow = B/Bbase; int digit = digits[i]; for (; digit < pow; pow /= Bbase); while (true) { for (; pow>0; pow /= Bbase) { int subdigit = digit / pow; if (subdigit < 10) out << subdigit; else out << (char)('A' + (subdigit - 10)); digit -= subdigit*pow; } if (--i < 0) break; digit = digits[i]; pow = B/Bbase; } } } void PosInt::read (istream& in) { vector<int> digstack; while (isspace(in.peek())) in.get(); int pow = B/Bbase; int digit = 0; int subdigit; while (true) { int next = in.peek(); if (isdigit(next)) subdigit = next-'0'; else if (islower(next)) subdigit = next - 'a' + 10; else if (isupper(next)) subdigit = next - 'A' + 10; else subdigit = Bbase; if (subdigit >= Bbase) break; digit += pow*subdigit; in.get(); if (pow == 1) { digstack.push_back (digit); pow = B/Bbase; digit = 0; } else pow /= Bbase; } pow *= Bbase; if (pow == B && !digstack.empty()) { pow = 1; digit = digstack.back(); digstack.pop_back(); } int pmul = B/pow; digits.assign (1, digit/pow); for (int i=digstack.size()-1; i >= 0; --i) { digits.back() += (digstack[i] % pow) * pmul; digits.push_back (digstack[i] / pow); } normalize(); } int PosInt::convert () const { int val = 0; int pow = 1; for (int i = 0; i < digits.size(); ++i) { val += pow * digits[i]; pow *= B; } return val; } ostream& operator<< (ostream& out, const PosInt& x) { x.print(out); return out; } istream& operator>> (istream& in, PosInt& x) { x.read(in); return in; } /******************** RANDOM NUMBERS ********************/ // Produces a random number between 0 and n-1 static int randomInt (int n) { int max = RAND_MAX - ((RAND_MAX-n+1) % n); int r; do { r = rand(); } while (r > max); return r % n; } // Sets this PosInt to a random number between 0 and x-1 void PosInt::rand (const PosInt& x) { if (this == &x) { PosInt xcopy(x); rand(xcopy); } else { PosInt max; max.digits.assign (x.digits.size(), 0); max.digits.push_back(1); mod (max, x); max.sub(*this); do { digits.resize(x.digits.size()); for (int i=0; i<digits.size(); ++i) digits[i] = randomInt(B); normalize(); } while (compare(max) >= 0); mod(x); } } /******************** UTILITY ********************/ // Removes leading 0 digits void PosInt::normalize () { int i; for (i = digits.size()-1; i >= 0 && digits[i] == 0; --i); if (i+1 < digits.size()) digits.resize(i+1); } bool PosInt::isEven() const { if (B % 2 == 0) return digits.empty() || (digits[0] % 2 == 0); int sum = 0; for (int i = 0; i < digits.size(); ++i) sum += digits[i] % 2; return sum % 2 == 0; } // Result is -1, 0, or 1 if a is <, =, or > than b, // up to the specified length. int PosInt::compareDigits (const int* a, int alen, const int* b, int blen) { int i = max(alen, blen)-1; for (; i >= blen; --i) { if (a[i] > 0) return 1; } for (; i >= alen; --i) { if (b[i] > 0) return -1; } for (; i >= 0; --i) { if (a[i] < b[i]) return -1; else if (a[i] > b[i]) return 1; } return 0; } // Result is -1, 0, or 1 if this is <, =, or > than rhs. int PosInt::compare (const PosInt& x) const { if (digits.size() < x.digits.size()) return -1; else if (digits.size() > x.digits.size()) return 1; else return compareDigits (&digits[0], digits.size(), &x.digits[0], x.digits.size()); } /******************** ADDITION ********************/ // Computes dest += x, digit-wise // REQUIREMENT: dest has enough space to hold the complete sum. void PosInt::addArray (int* dest, const int* x, int len) { int i; for (i=0; i < len; ++i) dest[i] += x[i]; for (i=0; i+1 < len; ++i) { if (dest[i] >= B) { dest[i] -= B; ++dest[i+1]; } } for ( ; dest[i] >= B; ++i) { dest[i] -= B; ++dest[i+1]; } } // this = x + y PosInt& PosInt::add (const PosInt& x, const PosInt& y) { if (&x == this) return add(y); else if (&y == this) return add(x); else { set(x); return add(y); } } // this = this + x PosInt& PosInt::add (const PosInt& x) { digits.resize(max(digits.size(), x.digits.size())+1, 0); addArray (&digits[0], &x.digits[0], x.digits.size()); normalize(); return *this; } /******************** SUBTRACTION ********************/ // Computes dest -= x, digit-wise // REQUIREMENT: dest >= x, so the difference is non-negative void PosInt::subArray (int* dest, const int* x, int len) { int i = 0; for ( ; i < len; ++i) dest[i] -= x[i]; for (i=0; i+1 < len; ++i) { if (dest[i] < 0) { dest[i] += B; --dest[i+1]; } } for ( ; dest[i] < 0; ++i) { dest[i] += B; --dest[i+1]; } } // this = x - y PosInt& PosInt::sub (const PosInt& x, const PosInt& y) { if (&x == this) return sub(y); else if (&y == this) { PosInt temp (y); set(x); return sub(temp); } else { set(x); return sub(y); } } // this = this - x PosInt& PosInt::sub (const PosInt& x) { if (compare(x) < 0) throw "Subtraction would result in negative number"; subArray (&digits[0], &x.digits[0], x.digits.size()); normalize(); return *this; } /******************** MULTIPLICATION ********************/ // Computes dest = dest * x, digit-wise // REQUIREMENT: dest has enough space to hold any overflow. void PosInt::mulArray (int* dest, int x, int len) { int i; for (i=0; i<len; ++i) dest[i] *= x; for (i=0; i+1<len; ++i) { dest[i+1] += dest[i] / B; dest[i] %= B; } for (; dest[i] >= B; ++i) { dest[i+1] += dest[i] / B; dest[i] %= B; } } // this = x * y PosInt& PosInt::mul(const PosInt& x, const PosInt& y) { int xlen = x.digits.size(); int ylen = y.digits.size(); if (xlen == 0 || ylen == 0) set(0); else if (this == &x) { PosInt xcopy(x); mul(xcopy, y); } else if (this == &y) { PosInt ycopy(y); mul(x, ycopy); } else { int* temp = new int[xlen+1]; digits.assign (xlen+ylen, 0); for (int i = 0; i < ylen; ++i) { // set temp = x for (int j = 0; j < xlen; ++j) temp[j] = x.digits[j]; temp[xlen] = 0; // temp = x * i'th digit of y mulArray (temp, y.digits[i], xlen); // this = this + temp*B^i addArray (&digits[i], temp, xlen+1); } delete [] temp; normalize(); } return *this; } /******************** DIVISION ********************/ // Computes dest = dest / x, digit-wise, and returns dest % x int PosInt::divArray (int* dest, int x, int len) { int r = 0; for (int i = len-1; i >= 0; --i) { dest[i] += B*r; r = dest[i] % x; dest[i] /= x; } return r; } // Computes division with remainder, digit-wise. // REQUIREMENTS: // - length of q is at least xlen-ylen+1 // - length of r is at least xlen // - q and r are distinct from all other arrays // - most significant digit of divisor (y) is at least B/2 void PosInt::divremArray (int* q, int* r, const int* x, int xlen, const int* y, int ylen) { // Copy x into r for (int i=0; i<xlen; ++i) r[i] = x[i]; // Create temporary array to hold a digit-multiple of y int* temp = new int[ylen+1]; int qind = xlen - ylen; int rind = xlen - 1; q[qind] = 0; while (true) { // Do "correction" if the quotient digit is off by a few while (compareDigits (y, ylen, r + qind, xlen-qind) <= 0) { subArray (r + qind, y, ylen); ++q[qind]; } // Test if we're done, otherwise move to the next digit if (qind == 0) break; --qind; --rind; // (Under)-estimate the next digit, and subtract out the multiple. int quoest = (r[rind] + B*r[rind+1]) / y[ylen-1] - 2; if (quoest <= 0) q[qind] = 0; else { q[qind] = quoest; for (int i=0; i<ylen; ++i) temp[i] = y[i]; temp[ylen] = 0; mulArray (temp, quoest, ylen+1); subArray (r+qind, temp, ylen+1); } } delete [] temp; } // Computes division with remainder. After the call, we have // x = q*y + r, and 0 <= r < y. void PosInt::divrem (PosInt& q, PosInt& r, const PosInt& x, const PosInt& y) { if (y.digits.empty()) throw "Divide by zero"; else if (&q == &r) throw "Quotient and remainder can't be the same"; else if (x.compare(y) < 0) { r.set(x); q.set(0); return; } else if (y.digits.size() == 1) { int divdig = y.digits[0]; q.set(x); r.set (divArray (&q.digits[0], divdig, q.digits.size())); } else if (2*y.digits.back() < B) { int ylen = y.digits.size(); int fac = 1; int* scaley = new int[ylen]; for (int i=0; i<ylen; ++i) scaley[i] = y.digits[i]; do { mulArray (scaley, 2, ylen); fac *= 2; } while (2*scaley[ylen-1] < B); int xlen = x.digits.size()+1; int* scalex = new int[xlen]; for (int i=0; i<xlen-1; ++i) scalex[i] = x.digits[i]; scalex[xlen-1] = 0; mulArray (scalex, fac, xlen); q.digits.resize(xlen - ylen + 1); r.digits.resize(xlen); divremArray (&q.digits[0], &r.digits[0], scalex, xlen, scaley, ylen); divArray (&r.digits[0], fac, xlen); } else { int xlen = x.digits.size(); int ylen = y.digits.size(); int* xarr = NULL; int* yarr = NULL; if (&x == &q || &x == &r) { xarr = new int[xlen]; for (int i=0; i<xlen; ++i) xarr[i] = x.digits[i]; } if (&y == &q || &y == &r) { yarr = new int[ylen]; for (int i=0; i<ylen; ++i) yarr[i] = y.digits[i]; } q.digits.resize(xlen - ylen + 1); r.digits.resize(xlen); divremArray (&q.digits[0], &r.digits[0], (xarr == NULL ? (&x.digits[0]) : xarr), xlen, (yarr == NULL ? (&y.digits[0]) : yarr), ylen); if (xarr != NULL) delete [] xarr; if (yarr != NULL) delete [] yarr; } q.normalize(); r.normalize(); } /******************** EXPONENTIATION ********************/ // this = x ^ y PosInt& PosInt::pow (const PosInt& x, const PosInt& y) { static const PosInt one(1); if (this == &x) { PosInt xcopy(x); pow(xcopy, y); } else { PosInt expon(y); set(1); while (!expon.isZero()) { mul(x); expon.sub(one); } } return *this; } /******************** GCDs ********************/ // this = gcd(x,y) PosInt& PosInt::gcd (const PosInt& x, const PosInt& y) { PosInt b(y); set(x); PosInt r; while (!b.isZero()) { r.mod(*this,b); set(b); b.set(r); } return *this; } // this = gcd(x,y) = s*x - t*y // NOTE THE MINUS SIGN! This is required so that both s and t are // always non-negative. PosInt& PosInt::xgcd (PosInt& s, PosInt& t, const PosInt& x, const PosInt& y) { if (this == &s || this == &t || &s == &t) throw "Arguments to xgcd must be distinct"; else if (x.isZero()) throw "First argument to xgcd must be nonzero"; else if (x.compare(y) < 0) { // Flip the arguments around if x < y xgcd (t, s, y, x); t.sub (x, t); s.sub (y, s); } else if (y.isZero()) { set(x); s.set(1); t.set(0); } else { PosInt quo; PosInt rem; divrem (quo, rem, x, y); xgcd (t, s, y, rem); quo.mul(s); t.sub (x, t); t.sub (quo); s.sub (y, s); } return *this; } /******************** Primality Testing ********************/ // returns true if this is PROBABLY prime bool PosInt::MillerRabin () const { static PosInt one (1); static PosInt two (2); if (compare(one) <= 0) return false; if (compare(two) == 0) return true; PosInt a; PosInt nminus1 (*this); nminus1.sub(one); // Choose a = random(2..n-1) PosInt nminus3(nminus1); nminus3.sub(two); a.rand(nminus3); a.add(two); PosInt d(nminus1); PosInt k(0); while (d.isEven()) { d.div(two); k.add(one); } // Compute x = a^d mod n PosInt powof2(one); while (powof2.compare(d) < 0) powof2.mul(two); PosInt x(1); while (!powof2.isZero()) { x.mul(x); if (powof2.compare(d) <= 0) { x.mul(a); d.sub(powof2); } x.mod(*this); powof2.div(two); } x.mul(x); x.mod(*this); // x = x^2 mod n if (x.compare(one) == 0) return true; if (x.compare(nminus1) == 0) return true; while (k.compare(one) > 0) { // while (k > 1) x.mul(x); x.mod(*this); // x = x^2 mod n if (x.compare(one) == 0) return false; if (x.compare(nminus1) == 0) return true; k.sub(one); } return false; }