Add support for pow for big integers

This commit is contained in:
Eugen Wissner 2016-12-22 21:51:16 +01:00
parent f7fb89fed0
commit 38addb7a5b
2 changed files with 149 additions and 88 deletions

View File

@ -10,7 +10,6 @@
*/ */
module tanya.math.mp; module tanya.math.mp;
import core.exception;
import std.algorithm.iteration; import std.algorithm.iteration;
import std.algorithm.searching; import std.algorithm.searching;
import std.algorithm.mutation; import std.algorithm.mutation;
@ -24,7 +23,7 @@ import tanya.memory;
*/ */
struct Integer struct Integer
{ {
private ubyte[] rep; package ubyte[] rep;
private bool sign; private bool sign;
/** /**
@ -48,14 +47,8 @@ struct Integer
} }
else if (value.length > 0) else if (value.length > 0)
{ {
rep = () @trusted { allocator.resize!(ubyte, false)(rep, value.length);
return cast(ubyte[]) allocator_.allocate(value.length); rep[] = value.rep[];
}();
if (rep is null)
{
onOutOfMemoryError();
}
value.rep.copy(rep);
sign = value.sign; sign = value.sign;
} }
} }
@ -84,6 +77,8 @@ struct Integer
assert(h2.sign); assert(h2.sign);
} }
@disable this(this);
/** /**
* Destroys the internal representation. * Destroys the internal representation.
*/ */
@ -125,14 +120,8 @@ struct Integer
} }
--size; --size;
} }
rep = () @trusted { allocator.resize!(ubyte, false)(rep, size);
void[] rep = this.rep;
if (!allocator.reallocate(rep, size))
{
onOutOfMemoryError();
}
return cast(ubyte[]) rep;
}();
/* Work backward through the int, masking off each byte (up to the /* Work backward through the int, masking off each byte (up to the
first 0 byte) and copy it into the internal representation in first 0 byte) and copy it into the internal representation in
big-endian format. */ big-endian format. */
@ -162,8 +151,8 @@ struct Integer
} }
else else
{ {
allocator.resizeArray(rep, value.length); allocator.resize!(ubyte, false)(rep, value.length);
value.rep.copy(rep); rep[0 .. $] = value.rep[0 .. $];
sign = value.sign; sign = value.sign;
} }
return this; return this;
@ -209,7 +198,7 @@ struct Integer
/// Ditto. /// Ditto.
bool opEquals(in ref Integer h) const nothrow @safe @nogc bool opEquals(in ref Integer h) const nothrow @safe @nogc
{ {
return rep == h.rep; return rep == h.rep;
} }
/// ///
@ -278,7 +267,7 @@ struct Integer
assert(h1 > h2); assert(h1 > h2);
} }
private void add(in ref ubyte[] h) nothrow @trusted @nogc private void add(in ref ubyte[] h) nothrow @safe @nogc
{ {
uint sum; uint sum;
uint carry = 0; uint carry = 0;
@ -286,11 +275,7 @@ struct Integer
if (h.length > length) if (h.length > length)
{ {
tmp = cast(ubyte[]) allocator.allocate(h.length); allocator.resize!(ubyte, false)(tmp, h.length);
if (tmp is null)
{
onOutOfMemoryError();
}
tmp[0 .. h.length] = 0; tmp[0 .. h.length] = 0;
tmp[h.length - length .. $] = rep[0 .. length]; tmp[h.length - length .. $] = rep[0 .. length];
swap(rep, tmp); swap(rep, tmp);
@ -319,13 +304,12 @@ struct Integer
if (carry) if (carry)
{ {
// Still overflowed; allocate more space // Still overflowed; allocate more space
void[]* vtmp = cast(void[]*) &tmp; allocator.resize!(ubyte, false)(tmp, length + 1);
allocator.reallocate(*vtmp, length + 1);
tmp[1 .. $] = rep[0 .. length]; tmp[1 .. $] = rep[0 .. length];
tmp[0] = 0x01; tmp[0] = 0x01;
swap(rep, tmp); swap(rep, tmp);
} }
allocator.deallocate(tmp); allocator.dispose(tmp);
} }
private void subtract(in ref ubyte[] h) nothrow @trusted @nogc private void subtract(in ref ubyte[] h) nothrow @trusted @nogc
@ -363,7 +347,7 @@ struct Integer
if (offset > 0) if (offset > 0)
{ {
ubyte[] tmp = cast(ubyte[]) allocator.allocate(rep.length - offset); ubyte[] tmp = cast(ubyte[]) allocator.allocate(rep.length - offset);
rep[offset .. $].copy(tmp); tmp[0 .. $] = rep[offset .. $];
allocator.deallocate(rep); allocator.deallocate(rep);
rep = tmp; rep = tmp;
} }
@ -505,6 +489,15 @@ struct Integer
body body
{ {
auto i = h.rep.length; auto i = h.rep.length;
if (length == 0)
{
return this;
}
else if (i == 0)
{
opAssign(0);
return this;
}
auto temp = Integer(this, allocator); auto temp = Integer(this, allocator);
immutable sign = sign == h.sign ? false : true; immutable sign = sign == h.sign ? false : true;
@ -536,6 +529,11 @@ struct Integer
assert(cast(long) h1 == 56088); assert(cast(long) h1 == 56088);
} }
private unittest
{
assert((Integer(1) * Integer()).length == 0);
}
/// Ditto. /// Ditto.
ref Integer opOpAssign(string op)(in auto ref Integer h) nothrow @safe @nogc ref Integer opOpAssign(string op)(in auto ref Integer h) nothrow @safe @nogc
if ((op == "/") || (op == "%")) if ((op == "/") || (op == "%"))
@ -555,9 +553,8 @@ struct Integer
} }
static if (op == "/") static if (op == "/")
{ {
auto quotient = (() @trusted => ubyte[] quotient;
cast(ubyte[]) allocator.allocate(bitSize / 8 + 1) allocator.resize!(ubyte, false)(quotient, bitSize / 8 + 1);
)();
} }
// "bitPosition" keeps track of which bit, of the quotient, // "bitPosition" keeps track of which bit, of the quotient,
@ -587,7 +584,7 @@ struct Integer
static if (op == "/") static if (op == "/")
{ {
() @trusted { allocator.deallocate(rep); }(); allocator.dispose(rep);
rep = quotient; rep = quotient;
sign = sign == h.sign ? false : true; sign = sign == h.sign ? false : true;
} }
@ -727,7 +724,7 @@ struct Integer
immutable size = rep.retro.countUntil!((const ref a) => a != 0); immutable size = rep.retro.countUntil!((const ref a) => a != 0);
if (rep[0] == 1) if (rep[0] == 1)
{ {
allocator.resizeArray(rep, rep.length - 1); allocator.resize!(ubyte, false)(rep, rep.length - 1);
rep[0 .. $] = typeof(rep[0]).max; rep[0 .. $] = typeof(rep[0]).max;
} }
else else
@ -739,13 +736,11 @@ struct Integer
private void increment() nothrow @safe @nogc private void increment() nothrow @safe @nogc
{ {
auto size = rep auto size = rep.retro.countUntil!((const ref a) => a != typeof(rep[0]).max);
.retro
.countUntil!((const ref a) => a != typeof(rep[0]).max);
if (size == -1) if (size == -1)
{ {
size = length; size = length;
allocator.resizeArray(rep, rep.length + 1); allocator.resize!(ubyte, false)(rep, rep.length + 1);
rep[0] = 1; rep[0] = 1;
} }
else else
@ -856,24 +851,30 @@ struct Integer
return length == 0 ? false : true; return length == 0 ? false : true;
} }
/** /// Ditto.
* Casting to integer types. T opCast(T)() const pure nothrow @safe @nogc
* if (isIntegral!T && isSigned!T)
* Params:
* T = Target type.
*
* Returns: Signed integer.
*/
T opCast(T : long)() const pure nothrow @safe @nogc
{ {
ulong ret; ulong ret;
for (size_t i = length, j; i > 0 && j <= 32; --i, j += 8) for (size_t i = length, j; i > 0 && j <= T.sizeof * 4; --i, j += 8)
{ {
ret |= cast(long) (rep[i - 1]) << j; ret |= cast(T) (rep[i - 1]) << j;
} }
return sign ? -ret : ret; return sign ? -ret : ret;
} }
/// Ditto.
T opCast(T)() const pure nothrow @safe @nogc
if (isIntegral!T && isUnsigned!T)
{
ulong ret;
for (size_t i = length, j; i > 0 && j <= T.sizeof * 8; --i, j += 8)
{
ret |= cast(T) (rep[i - 1]) << j;
}
return ret;
}
/// ///
unittest unittest
{ {
@ -912,7 +913,7 @@ struct Integer
if (step >= rep.length) if (step >= rep.length)
{ {
allocator.resizeArray(rep, 0); allocator.resize!(ubyte, false)(rep, 0);
return this; return this;
} }
@ -934,7 +935,7 @@ struct Integer
rep[j] = (rep[i] >> bit | oldCarry); rep[j] = (rep[i] >> bit | oldCarry);
++j; ++j;
} }
allocator.resizeArray(rep, rep.length - n / 8 - (i == j ? 0 : 1)); allocator.resize!(ubyte, false)(rep, rep.length - n / 8 - (i == j ? 0 : 1));
return this; return this;
} }
@ -993,12 +994,12 @@ struct Integer
if (cast(ubyte) (rep[0] >> delta)) if (cast(ubyte) (rep[0] >> delta))
{ {
allocator.resizeArray(rep, i + n / 8 + 1); allocator.resize!(ubyte, false)(rep, i + n / 8 + 1);
j = i + 1; j = i + 1;
} }
else else
{ {
allocator.resizeArray(rep, i + n / 8); allocator.resize!(ubyte, false)(rep, i + n / 8);
j = i; j = i;
} }
do do

View File

@ -10,7 +10,9 @@
*/ */
module tanya.math; module tanya.math;
import std.traits;
public import tanya.math.mp; public import tanya.math.mp;
public import tanya.math.random;
version (unittest) version (unittest)
{ {
@ -20,26 +22,32 @@ version (unittest)
/** /**
* Computes $(D_PARAM x) to the power $(D_PARAM y) modulo $(D_PARAM z). * Computes $(D_PARAM x) to the power $(D_PARAM y) modulo $(D_PARAM z).
* *
* If $(D_PARAM I) is an $(D_PSYMBOL Integer), the allocator of $(D_PARAM x)
* is used to allocate the result.
*
* Params: * Params:
* I = Base type.
* G = Exponent type.
* H = Divisor type:
* x = Base. * x = Base.
* y = Exponent. * y = Exponent.
* z = Divisor. * z = Divisor.
* *
* Returns: Reminder of the division of $(D_PARAM x) to the power $(D_PARAM y) * Returns: Reminder of the division of $(D_PARAM x) to the power $(D_PARAM y)
* by $(D_PARAM z). * by $(D_PARAM z).
*
* Precondition: $(D_INLINECODE z > 0)
*/ */
ulong pow(ulong x, ulong y, ulong z) nothrow pure @safe @nogc H pow(I, G, H)(in auto ref I x, in auto ref G y, in auto ref H z)
if (isIntegral!I && isIntegral!G && isIntegral!H)
in in
{ {
assert(z > 0); assert(z > 0, "Division by zero.");
}
out (result)
{
assert(result >= 0);
} }
body body
{ {
ulong mask = ulong.max / 2 + 1, result; G mask = G.max / 2 + 1;
H result;
if (y == 0) if (y == 0)
{ {
@ -49,40 +57,92 @@ body
{ {
return x % z; return x % z;
} }
do do
{ {
auto bit = y & mask; immutable bit = y & mask;
if (!result && bit) if (!result && bit)
{ {
result = x; result = x;
continue; continue;
} }
result *= result; result *= result;
if (bit) if (bit)
{ {
result *= x; result *= x;
} }
result %= z; result %= z;
}
} while (mask >>= 1);
while (mask >>= 1);
return result; return result;
} }
/// Ditto.
I pow(I)(in auto ref I x, in auto ref I y, in auto ref I z)
if (is(I == Integer))
in
{
assert(z.length > 0, "Division by zero.");
}
body
{
size_t i = y.length;
auto tmp2 = Integer(x.allocator), tmp1 = Integer(x, x.allocator);
Integer result = Integer(x.allocator);
if (x.length == 0 && i != 0)
{
i = 0;
}
else
{
result = 1;
}
while (i)
{
--i;
for (ubyte mask = 0x01; mask; mask <<= 1)
{
if (y.rep[i] & mask)
{
result *= tmp1;
result %= z;
}
tmp2 = tmp1;
tmp1 *= tmp2;
tmp1 %= z;
}
}
return result;
}
///
pure nothrow @safe @nogc unittest
{
assert(pow(3, 5, 7) == 5);
assert(pow(2, 2, 1) == 0);
assert(pow(3, 3, 3) == 0);
assert(pow(7, 4, 2) == 1);
assert(pow(53, 0, 2) == 1);
assert(pow(53, 1, 3) == 2);
assert(pow(53, 2, 5) == 4);
assert(pow(0, 0, 5) == 1);
assert(pow(0, 5, 5) == 0);
}
/// ///
unittest unittest
{ {
assert(pow(3, 5, 7) == 5); assert(cast(long) pow(Integer(3), Integer(5), Integer(7)) == 5);
assert(pow(2, 2, 1) == 0); assert(cast(long) pow(Integer(2), Integer(2), Integer(1)) == 0);
assert(pow(3, 3, 3) == 0); assert(cast(long) pow(Integer(3), Integer(3), Integer(3)) == 0);
assert(pow(7, 4, 2) == 1); assert(cast(long) pow(Integer(7), Integer(4), Integer(2)) == 1);
assert(pow(53, 0, 2) == 1); assert(cast(long) pow(Integer(53), Integer(0), Integer(2)) == 1);
assert(pow(53, 1, 3) == 2); assert(cast(long) pow(Integer(53), Integer(1), Integer(3)) == 2);
assert(pow(53, 2, 5) == 4); assert(cast(long) pow(Integer(53), Integer(2), Integer(5)) == 4);
assert(pow(0, 0, 5) == 1); assert(cast(long) pow(Integer(0), Integer(0), Integer(5)) == 1);
assert(pow(0, 5, 5) == 0); assert(cast(long) pow(Integer(0), Integer(5), Integer(5)) == 0);
} }
/** /**