about | summary | refs | log | tree | commit | diff
path: root/src/math
diff options
context:
space:
mode:
Diffstat (limited to 'src/math')
-rw-r--r-- src/math/big/arith_test.go 5
-rw-r--r-- src/math/big/float.go 8
-rw-r--r-- src/math/big/int.go 20
-rw-r--r-- src/math/big/nat.go 310
-rw-r--r-- src/math/big/nat_test.go 74
-rw-r--r-- src/math/big/natconv.go 25
-rw-r--r-- src/math/big/natconv_test.go 21
-rw-r--r-- src/math/big/natdiv.go 98
-rw-r--r-- src/math/big/prime.go 42
-rw-r--r-- src/math/big/prime_test.go 17
-rw-r--r-- src/math/big/rat.go 73
-rw-r--r-- src/math/big/ratconv.go 28
12 files changed, 429 insertions, 292 deletions
diff --git a/src/math/big/arith_test.go b/src/math/big/arith_test.go
index 64225bbd53..feffa1bc95 100644
--- a/src/math/big/arith_test.go
+++ b/src/math/big/arith_test.go
@@ -368,9 +368,12 @@ func TestShiftOverlap(t *testing.T) {
}
func TestIssue31084(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
// compute 10^n via 5^n << n.
const n = 165
- p := nat(nil).expNN(nat{5}, nat{n}, nil, false)
+ p := nat(nil).expNN(stk, nat{5}, nat{n}, nil, false)
p = p.shl(p, n)
got := string(p.utoa(10))
want := "1" + strings.Repeat("0", n)
diff --git a/src/math/big/float.go b/src/math/big/float.go
index e1d20d8bb4..2c5234a4ce 100644
--- a/src/math/big/float.go
+++ b/src/math/big/float.go
@@ -1327,9 +1327,9 @@ func (z *Float) umul(x, y *Float) {
e := int64(x.exp) + int64(y.exp)
if x == y {
- z.mant = z.mant.sqr(x.mant)
+ z.mant = z.mant.sqr(nil, x.mant)
} else {
- z.mant = z.mant.mul(x.mant, y.mant)
+ z.mant = z.mant.mul(nil, x.mant, y.mant)
}
z.setExpAndRound(e-fnorm(z.mant), 0)
}
@@ -1363,8 +1363,10 @@ func (z *Float) uquo(x, y *Float) {
d := len(xadj) - len(y.mant)
// divide
+ stk := getStack()
+ defer stk.free()
var r nat
- z.mant, r = z.mant.div(nil, xadj, y.mant)
+ z.mant, r = z.mant.div(stk, nil, xadj, y.mant)
e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W
// The result is long enough to include (at least) the rounding bit.
diff --git a/src/math/big/int.go b/src/math/big/int.go
index 0b710c6968..cb7221250d 100644
--- a/src/math/big/int.go
+++ b/src/math/big/int.go
@@ -181,16 +181,20 @@ func (z *Int) Sub(x, y *Int) *Int {
// Mul sets z to the product x*y and returns z.
func (z *Int) Mul(x, y *Int) *Int {
+ return z.mul(nil, x, y)
+}
+
+func (z *Int) mul(stk *stack, x, y *Int) *Int {
// x * y == x * y
// x * (-y) == -(x * y)
// (-x) * y == -(x * y)
// (-x) * (-y) == x * y
if x == y {
- z.abs = z.abs.sqr(x.abs)
+ z.abs = z.abs.sqr(stk, x.abs)
z.neg = false
return z
}
- z.abs = z.abs.mul(x.abs, y.abs)
+ z.abs = z.abs.mul(stk, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
return z
}
@@ -213,7 +217,7 @@ func (z *Int) MulRange(a, b int64) *Int {
a, b = -b, -a
}
- z.abs = z.abs.mulRange(uint64(a), uint64(b))
+ z.abs = z.abs.mulRange(nil, uint64(a), uint64(b))
z.neg = neg
return z
}
@@ -264,7 +268,7 @@ func (z *Int) Binomial(n, k int64) *Int {
// If y == 0, a division-by-zero run-time panic occurs.
// Quo implements truncated division (like Go); see [Int.QuoRem] for more details.
func (z *Int) Quo(x, y *Int) *Int {
- z.abs, _ = z.abs.div(nil, x.abs, y.abs)
+ z.abs, _ = z.abs.div(nil, nil, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
return z
}
@@ -273,7 +277,7 @@ func (z *Int) Quo(x, y *Int) *Int {
// If y == 0, a division-by-zero run-time panic occurs.
// Rem implements truncated modulus (like Go); see [Int.QuoRem] for more details.
func (z *Int) Rem(x, y *Int) *Int {
- _, z.abs = nat(nil).div(z.abs, x.abs, y.abs)
+ _, z.abs = nat(nil).div(nil, z.abs, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg // 0 has no sign
return z
}
@@ -290,7 +294,7 @@ func (z *Int) Rem(x, y *Int) *Int {
// (See Daan Leijen, “Division and Modulus for Computer Scientists”.)
// See [Int.DivMod] for Euclidean division and modulus (unlike Go).
func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) {
- z.abs, r.abs = z.abs.div(r.abs, x.abs, y.abs)
+ z.abs, r.abs = z.abs.div(nil, r.abs, x.abs, y.abs)
z.neg, r.neg = len(z.abs) > 0 && x.neg != y.neg, len(r.abs) > 0 && x.neg // 0 has no sign
return z, r
}
@@ -589,7 +593,7 @@ func (z *Int) exp(x, y, m *Int, slow bool) *Int {
mWords = m.abs // m.abs may be nil for m == 0
}
- z.abs = z.abs.expNN(xWords, yWords, mWords, slow)
+ z.abs = z.abs.expNN(nil, xWords, yWords, mWords, slow)
z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1 // 0 has no sign
if z.neg && len(mWords) > 0 {
// make modulus result positive
@@ -1298,6 +1302,6 @@ func (z *Int) Sqrt(x *Int) *Int {
panic("square root of negative number")
}
z.neg = false
- z.abs = z.abs.sqrt(x.abs)
+ z.abs = z.abs.sqrt(nil, x.abs)
return z
}
diff --git a/src/math/big/nat.go b/src/math/big/nat.go
index 541da229d6..ec75c8f6fd 100644
--- a/src/math/big/nat.go
+++ b/src/math/big/nat.go
@@ -17,6 +17,7 @@ import (
"internal/byteorder"
"math/bits"
"math/rand"
+ "slices"
"sync"
)
@@ -262,9 +263,9 @@ var karatsubaThreshold = 40 // computed by calibrate_test.go
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
-// power of 2. The result vector z must have len(z) >= 6*n.
-// The (non-normalized) result is placed in z[0 : 2*n].
-func karatsuba(z, x, y nat) {
+// power of 2. The result vector z must have len(z) == len(x)+len(y).
+// The (non-normalized) result is placed in z.
+func karatsuba(stk *stack, z, x, y nat) {
n := len(y)
// Switch to basic multiplication if numbers are odd or small.
@@ -304,29 +305,19 @@ func karatsuba(z, x, y nat) {
x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0
y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0
- // z is used for the result and temporary storage:
- //
- // 6*n 5*n 4*n 3*n 2*n 1*n 0*n
- // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
- //
- // For each recursive call of karatsuba, an unused slice of
- // z is passed in that has (at least) half the length of the
- // caller's z.
-
// compute z0 and z2 with the result "in place" in z
- karatsuba(z, x0, y0) // z0 = x0*y0
- karatsuba(z[n:], x1, y1) // z2 = x1*y1
+ karatsuba(stk, z, x0, y0) // z0 = x0*y0
+ karatsuba(stk, z[n:], x1, y1) // z2 = x1*y1
- // compute xd (or the negative value if underflow occurs)
+ // compute xd, yd (or the negative value if underflow occurs)
s := 1 // sign of product xd*yd
- xd := z[2*n : 2*n+n2]
+ defer stk.restore(stk.save())
+ xd := stk.nat(n2)
+ yd := stk.nat(n2)
if subVV(xd, x1, x0) != 0 { // x1-x0
s = -s
subVV(xd, x0, x1) // x0-x1
}
-
- // compute yd (or the negative value if underflow occurs)
- yd := z[2*n+n2 : 3*n]
if subVV(yd, y0, y1) != 0 { // y0-y1
s = -s
subVV(yd, y1, y0) // y1-y0
@@ -334,12 +325,12 @@ func karatsuba(z, x, y nat) {
// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
- p := z[n*3:]
- karatsuba(p, xd, yd)
+ p := stk.nat(2 * n2)
+ karatsuba(stk, p, xd, yd)
// save original z2:z0
// (ok to use upper half of z since we're done recurring)
- r := z[n*4:]
+ r := stk.nat(n * 2)
copy(r, z[:n*2])
// add up all partial products
@@ -396,13 +387,15 @@ func karatsubaLen(n, threshold int) int {
return n << i
}
-func (z nat) mul(x, y nat) nat {
+// mul sets z = x*y, using stk for temporary storage.
+// The caller may pass stk == nil to request that mul obtain and release one itself.
+func (z nat) mul(stk *stack, x, y nat) nat {
m := len(x)
n := len(y)
switch {
case m < n:
- return z.mul(y, x)
+ return z.mul(stk, y, x)
case m == 0 || n == 0:
return z[:0]
case n == 1:
@@ -432,12 +425,16 @@ func (z nat) mul(x, y nat) nat {
k := karatsubaLen(n, karatsubaThreshold)
// k <= n
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
+ }
+
// multiply x0 and y0 via Karatsuba
- x0 := x[0:k] // x0 is not normalized
- y0 := y[0:k] // y0 is not normalized
- z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
- karatsuba(z, x0, y0)
- z = z[0 : m+n] // z has final length but may be incomplete
+ x0 := x[0:k] // x0 is not normalized
+ y0 := y[0:k] // y0 is not normalized
+ z = z.make(m + n) // enough space for full result of x*y
+ karatsuba(stk, z, x0, y0)
clear(z[2*k:]) // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
// If xh != 0 or yh != 0, add the missing terms to z. For
@@ -454,13 +451,13 @@ func (z nat) mul(x, y nat) nat {
// be a larger valid threshold contradicting the assumption about k.
//
if k < n || m != n {
- tp := getNat(3 * k)
- t := *tp
+ defer stk.restore(stk.save())
+ t := stk.nat(3 * k)
// add x0*y1*b
x0 := x0.norm()
- y1 := y[k:] // y1 is normalized because y is
- t = t.mul(x0, y1) // update t so we don't lose t's underlying array
+ y1 := y[k:] // y1 is normalized because y is
+ t = t.mul(stk, x0, y1) // update t so we don't lose t's underlying array
addAt(z, t, k)
// add xi*y0<<i, xi*y1*b<<(i+k)
@@ -471,13 +468,11 @@ func (z nat) mul(x, y nat) nat {
xi = xi[:k]
}
xi = xi.norm()
- t = t.mul(xi, y0)
+ t = t.mul(stk, xi, y0)
addAt(z, t, i)
- t = t.mul(xi, y1)
+ t = t.mul(stk, xi, y1)
addAt(z, t, i+k)
}
-
- putNat(tp)
}
return z.norm()
@@ -487,10 +482,10 @@ func (z nat) mul(x, y nat) nat {
// by about a factor of 2, but slower for small arguments due to overhead.
// Requirements: len(x) > 0, len(z) == 2*len(x)
// The (non-normalized) result is placed in z.
-func basicSqr(z, x nat) {
+func basicSqr(stk *stack, z, x nat) {
n := len(x)
- tp := getNat(2 * n)
- t := *tp // temporary variable to hold the products
+ defer stk.restore(stk.save())
+ t := stk.nat(2 * n)
clear(t)
z[1], z[0] = mulWW(x[0], x[0]) // the initial square
for i := 1; i < n; i++ {
@@ -502,38 +497,37 @@ func basicSqr(z, x nat) {
}
t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
addVV(z, z, t) // combine the result
- putNat(tp)
}
// karatsubaSqr squares x and leaves the result in z.
-// len(x) must be a power of 2 and len(z) >= 6*len(x).
-// The (non-normalized) result is placed in z[0 : 2*len(x)].
+// len(x) must be a power of 2 and len(z) == 2*len(x).
+// The (non-normalized) result is placed in z.
//
// The algorithm and the layout of z are the same as for karatsuba.
-func karatsubaSqr(z, x nat) {
+func karatsubaSqr(stk *stack, z, x nat) {
n := len(x)
if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
- basicSqr(z[:2*n], x)
+ basicSqr(stk, z[:2*n], x)
return
}
n2 := n >> 1
x1, x0 := x[n2:], x[0:n2]
- karatsubaSqr(z, x0)
- karatsubaSqr(z[n:], x1)
+ karatsubaSqr(stk, z, x0)
+ karatsubaSqr(stk, z[n:], x1)
// s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
- xd := z[2*n : 2*n+n2]
+ defer stk.restore(stk.save())
+ p := stk.nat(2 * n2)
+ r := stk.nat(n * 2)
+ xd := r[:n2]
if subVV(xd, x1, x0) != 0 {
subVV(xd, x0, x1)
}
- p := z[n*3:]
- karatsubaSqr(p, xd)
-
- r := z[n*4:]
+ karatsubaSqr(stk, p, xd)
copy(r, z[:n*2])
karatsubaAdd(z[n2:], r, n)
@@ -547,8 +541,9 @@ func karatsubaSqr(z, x nat) {
var basicSqrThreshold = 20 // computed by calibrate_test.go
var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
-// z = x*x
-func (z nat) sqr(x nat) nat {
+// sqr sets z = x*x, using stk for temporary storage.
+// The caller may pass stk == nil to request that sqr obtain and release one itself.
+func (z nat) sqr(stk *stack, x nat) nat {
n := len(x)
switch {
case n == 0:
@@ -563,15 +558,20 @@ func (z nat) sqr(x nat) nat {
if alias(z, x) {
z = nil // z is an alias for x - cannot reuse
}
+ z = z.make(2 * n)
if n < basicSqrThreshold {
- z = z.make(2 * n)
basicMul(z, x, x)
return z.norm()
}
+
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
+ }
+
if n < karatsubaSqrThreshold {
- z = z.make(2 * n)
- basicSqr(z, x)
+ basicSqr(stk, z, x)
return z.norm()
}
@@ -583,22 +583,18 @@ func (z nat) sqr(x nat) nat {
k := karatsubaLen(n, karatsubaSqrThreshold)
x0 := x[0:k]
- z = z.make(max(6*k, 2*n))
- karatsubaSqr(z, x0) // z = x0^2
- z = z[0 : 2*n]
+ karatsubaSqr(stk, z, x0) // z = x0^2
clear(z[2*k:])
if k < n {
- tp := getNat(2 * k)
- t := *tp
+ t := stk.nat(2 * k)
x0 := x0.norm()
x1 := x[k:]
- t = t.mul(x0, x1)
+ t = t.mul(stk, x0, x1)
addAt(z, t, k)
addAt(z, t, k) // z = 2*x1*x0*b + x0^2
- t = t.sqr(x1)
+ t = t.sqr(stk, x1)
addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
- putNat(tp)
}
return z.norm()
@@ -606,7 +602,8 @@ func (z nat) sqr(x nat) nat {
// mulRange computes the product of all the unsigned integers in the
// range [a, b] inclusively. If a > b (empty range), the result is 1.
-func (z nat) mulRange(a, b uint64) nat {
+// The caller may pass stk == nil to request that mulRange obtain and release one itself.
+func (z nat) mulRange(stk *stack, a, b uint64) nat {
switch {
case a == 0:
// cut long ranges short (optimization)
@@ -616,34 +613,79 @@ func (z nat) mulRange(a, b uint64) nat {
case a == b:
return z.setUint64(a)
case a+1 == b:
- return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
+ return z.mul(stk, nat(nil).setUint64(a), nat(nil).setUint64(b))
+ }
+
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
}
+
m := a + (b-a)/2 // avoid overflow
- return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
+ return z.mul(stk, nat(nil).mulRange(stk, a, m), nat(nil).mulRange(stk, m+1, b))
}
-// getNat returns a *nat of len n. The contents may not be zero.
-// The pool holds *nat to avoid allocation when converting to interface{}.
-func getNat(n int) *nat {
- var z *nat
- if v := natPool.Get(); v != nil {
- z = v.(*nat)
- }
- if z == nil {
- z = new(nat)
- }
- *z = z.make(n)
- if n > 0 {
- (*z)[0] = 0xfedcb // break code expecting zero
+// A stack provides temporary storage for complex calculations
+// such as multiplication and division.
+// The stack is a simple slice of words, extended as needed
+// to hold all the temporary storage for a calculation.
+// In general, if a function takes a *stack, it expects a non-nil *stack.
+// However, certain functions may allow passing a nil *stack instead,
+// so that they can handle trivial stack-free cases without forcing the
+// caller to obtain and free a stack that will be unused. These functions
+// document that they accept a nil *stack in their doc comments.
+type stack struct {
+ w []Word
+}
+
+var stackPool sync.Pool
+
+// getStack returns a temporary stack.
+// The caller must call [stack.free] to give up use of the stack when finished.
+func getStack() *stack {
+ s, _ := stackPool.Get().(*stack)
+ if s == nil {
+ s = new(stack)
}
- return z
+ return s
+}
+
+// free returns the stack for use by another calculation.
+func (s *stack) free() {
+ s.w = s.w[:0]
+ stackPool.Put(s)
}
-func putNat(x *nat) {
- natPool.Put(x)
+// save returns the current stack pointer.
+// A future call to restore with the same value
+// frees any temporaries allocated on the stack after the call to save.
+func (s *stack) save() int {
+ return len(s.w)
}
-var natPool sync.Pool
+// restore restores the stack pointer to n.
+// It is almost always invoked as
+//
+// defer stk.restore(stk.save())
+//
+// which makes sure to pop any temporaries allocated in the current function
+// from the stack before returning.
+func (s *stack) restore(n int) {
+ s.w = s.w[:n]
+}
+
+// nat returns a nat of n words, allocated on the stack.
+func (s *stack) nat(n int) nat {
+ nr := (n + 3) &^ 3 // round up to multiple of 4
+ off := len(s.w)
+ s.w = slices.Grow(s.w, nr)
+ s.w = s.w[:off+nr]
+ x := s.w[off : off+n : off+n]
+ if n > 0 {
+ x[0] = 0xfedcb
+ }
+ return x
+}
// bitLen returns the length of x in bits.
// Unlike most methods, it works even if x is not normalized.
@@ -930,7 +972,8 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// otherwise it sets z to x**y. The result is the value of z.
-func (z nat) expNN(x, y, m nat, slow bool) nat {
+// The caller may pass stk == nil to request that expNN obtain and release one itself.
+func (z nat) expNN(stk *stack, x, y, m nat, slow bool) nat {
if alias(z, x) || alias(z, y) {
// We cannot allow in-place modification of x or y.
z = nil
@@ -961,12 +1004,17 @@ func (z nat) expNN(x, y, m nat, slow bool) nat {
// x > 1
// x**1 == x
- if len(y) == 1 && y[0] == 1 {
- if len(m) != 0 {
- return z.rem(x, m)
- }
+ if len(y) == 1 && y[0] == 1 && len(m) == 0 {
return z.set(x)
}
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
+ }
+ if len(y) == 1 && y[0] == 1 { // len(m) > 0
+ return z.rem(stk, x, m)
+ }
+
// y > 1
if len(m) != 0 {
@@ -980,12 +1028,12 @@ func (z nat) expNN(x, y, m nat, slow bool) nat {
// instance of each of the first two cases).
if len(y) > 1 && !slow {
if m[0]&1 == 1 {
- return z.expNNMontgomery(x, y, m)
+ return z.expNNMontgomery(stk, x, y, m)
}
if logM, ok := m.isPow2(); ok {
- return z.expNNWindowed(x, y, logM)
+ return z.expNNWindowed(stk, x, y, logM)
}
- return z.expNNMontgomeryEven(x, y, m)
+ return z.expNNMontgomeryEven(stk, x, y, m)
}
}
@@ -1006,16 +1054,16 @@ func (z nat) expNN(x, y, m nat, slow bool) nat {
// otherwise the arguments would alias.
var zz, r nat
for j := 0; j < w; j++ {
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
if v&mask != 0 {
- zz = zz.mul(z, x)
+ zz = zz.mul(stk, z, x)
zz, z = z, zz
}
if len(m) != 0 {
- zz, r = zz.div(r, z, m)
+ zz, r = zz.div(stk, r, z, m)
zz, r, q, z = q, z, zz, r
}
@@ -1026,16 +1074,16 @@ func (z nat) expNN(x, y, m nat, slow bool) nat {
v = y[i]
for j := 0; j < _W; j++ {
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
if v&mask != 0 {
- zz = zz.mul(z, x)
+ zz = zz.mul(stk, z, x)
zz, z = z, zz
}
if len(m) != 0 {
- zz, r = zz.div(r, z, m)
+ zz, r = zz.div(stk, r, z, m)
zz, r, q, z = q, z, zz, r
}
@@ -1054,7 +1102,7 @@ func (z nat) expNN(x, y, m nat, slow bool) nat {
// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
-func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
+func (z nat) expNNMontgomeryEven(stk *stack, x, y, m nat) nat {
// Split m = m₁ × m₂ where m₁ = 2ⁿ
n := m.trailingZeroBits()
m1 := nat(nil).shl(natOne, n)
@@ -1066,8 +1114,8 @@ func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
// (We are using the math/big convention for names here,
// where the computation is z = x**y mod m, so its parts are z1 and z2.
// The paper is computing x = a**e mod n; it refers to these as x2 and z1.)
- z1 := nat(nil).expNN(x, y, m1, false)
- z2 := nat(nil).expNN(x, y, m2, false)
+ z1 := nat(nil).expNN(stk, x, y, m1, false)
+ z2 := nat(nil).expNN(stk, x, y, m2, false)
// Reconstruct z from z₁, z₂ using CRT, using algorithm from paper,
// which uses only a single modInverse (and an easy one at that).
@@ -1086,18 +1134,18 @@ func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
// Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]).
m2inv := nat(nil).modInverse(m2, m1)
- z2 = z2.mul(z1, m2inv)
+ z2 = z2.mul(stk, z1, m2inv)
z2 = z2.trunc(z2, n)
// Reuse z1 for p * m2.
- z = z.add(z, z1.mul(z2, m2))
+ z = z.add(z, z1.mul(stk, z2, m2))
return z
}
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window,
// where m = 2**logM.
-func (z nat) expNNWindowed(x, y nat, logM uint) nat {
+func (z nat) expNNWindowed(stk *stack, x, y nat, logM uint) nat {
if len(y) <= 1 {
panic("big: misuse of expNNWindowed")
}
@@ -1112,23 +1160,23 @@ func (z nat) expNNWindowed(x, y nat, logM uint) nat {
// zz is used to avoid allocating in mul as otherwise
// the arguments would alias.
+ defer stk.restore(stk.save())
w := int((logM + _W - 1) / _W)
- zzp := getNat(w)
- zz := *zzp
+ zz := stk.nat(w)
const n = 4
// powers[i] contains x^i.
- var powers [1 << n]*nat
+ var powers [1 << n]nat
for i := range powers {
- powers[i] = getNat(w)
+ powers[i] = stk.nat(w)
}
- *powers[0] = powers[0].set(natOne)
- *powers[1] = powers[1].trunc(x, logM)
+ powers[0] = powers[0].set(natOne)
+ powers[1] = powers[1].trunc(x, logM)
for i := 2; i < 1<<n; i += 2 {
- p2, p, p1 := powers[i/2], powers[i], powers[i+1]
- *p = p.sqr(*p2)
+ p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
+ *p = p.sqr(stk, *p2)
*p = p.trunc(*p, logM)
- *p1 = p1.mul(*p, x)
+ *p1 = p1.mul(stk, *p, x)
*p1 = p1.trunc(*p1, logM)
}
@@ -1159,24 +1207,24 @@ func (z nat) expNNWindowed(x, y nat, logM uint) nat {
// Unrolled loop for significant performance
// gain. Use go test -bench=".*" in crypto/rsa
// to check performance before making changes.
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
z = z.trunc(z, logM)
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
z = z.trunc(z, logM)
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
z = z.trunc(z, logM)
- zz = zz.sqr(z)
+ zz = zz.sqr(stk, z)
zz, z = z, zz
z = z.trunc(z, logM)
}
- zz = zz.mul(z, *powers[yi>>(_W-n)])
+ zz = zz.mul(stk, z, powers[yi>>(_W-n)])
zz, z = z, zz
z = z.trunc(z, logM)
@@ -1185,24 +1233,18 @@ func (z nat) expNNWindowed(x, y nat, logM uint) nat {
}
}
- *zzp = zz
- putNat(zzp)
- for i := range powers {
- putNat(powers[i])
- }
-
return z.norm()
}
// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
-func (z nat) expNNMontgomery(x, y, m nat) nat {
+func (z nat) expNNMontgomery(stk *stack, x, y, m nat) nat {
numWords := len(m)
// We want the lengths of x and m to be equal.
// It is OK if x >= m as long as len(x) == len(m).
if len(x) > numWords {
- _, x = nat(nil).div(nil, x, m)
+ _, x = nat(nil).div(stk, nil, x, m)
// Note: now len(x) <= numWords, not guaranteed ==.
}
if len(x) < numWords {
@@ -1225,7 +1267,7 @@ func (z nat) expNNMontgomery(x, y, m nat) nat {
// RR = 2**(2*_W*len(m)) mod m
RR := nat(nil).setWord(1)
zz := nat(nil).shl(RR, uint(2*numWords*_W))
- _, RR = nat(nil).div(RR, zz, m)
+ _, RR = nat(nil).div(stk, RR, zz, m)
if len(RR) < numWords {
zz = zz.make(numWords)
copy(zz, RR)
@@ -1280,7 +1322,7 @@ func (z nat) expNNMontgomery(x, y, m nat) nat {
// The div is not expected to be reached.
zz = zz.sub(zz, m)
if zz.cmp(m) >= 0 {
- _, zz = nat(nil).div(nil, zz, m)
+ _, zz = nat(nil).div(stk, nil, zz, m)
}
}
@@ -1349,7 +1391,8 @@ func (z nat) setBytes(buf []byte) nat {
}
// sqrt sets z = ⌊√x⌋
-func (z nat) sqrt(x nat) nat {
+// The caller may pass stk == nil to request that sqrt obtain and release one itself.
+func (z nat) sqrt(stk *stack, x nat) nat {
if x.cmp(natOne) <= 0 {
return z.set(x)
}
@@ -1357,6 +1400,11 @@ func (z nat) sqrt(x nat) nat {
z = nil
}
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
+ }
+
// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
// See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
// https://members.loria.fr/PZimmermann/mca/pub226.html
@@ -1367,7 +1415,7 @@ func (z nat) sqrt(x nat) nat {
z1 = z1.setUint64(1)
z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x
for n := 0; ; n++ {
- z2, _ = z2.div(nil, x, z1)
+ z2, _ = z2.div(stk, nil, x, z1)
z2 = z2.add(z2, z1)
z2 = z2.shr(z2, 1)
if z2.cmp(z1) >= 0 {
diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go
index 46231f7976..1811dccfe3 100644
--- a/src/math/big/nat_test.go
+++ b/src/math/big/nat_test.go
@@ -42,6 +42,7 @@ func TestCmp(t *testing.T) {
}
type funNN func(z, x, y nat) nat
+type funSNN func(z nat, stk *stack, x, y nat) nat
type argNN struct {
z, x, y nat
}
@@ -112,6 +113,15 @@ func testFunNN(t *testing.T, msg string, f funNN, a argNN) {
}
}
+func testFunSNN(t *testing.T, msg string, f funSNN, a argNN) {
+ stk := getStack()
+ defer stk.free()
+ z := f(nil, stk, a.x, a.y)
+ if z.cmp(a.z) != 0 {
+ t.Errorf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z)
+ }
+}
+
func TestFunNN(t *testing.T) {
for _, a := range sumNN {
arg := a
@@ -129,10 +139,10 @@ func TestFunNN(t *testing.T) {
for _, a := range prodNN {
arg := a
- testFunNN(t, "mul", nat.mul, arg)
+ testFunSNN(t, "mul", nat.mul, arg)
arg = argNN{a.z, a.y, a.x}
- testFunNN(t, "mul symmetric", nat.mul, arg)
+ testFunSNN(t, "mul symmetric", nat.mul, arg)
}
}
@@ -163,8 +173,11 @@ var mulRangesN = []struct {
}
func TestMulRangeN(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
for i, r := range mulRangesN {
- prod := string(nat(nil).mulRange(r.a, r.b).utoa(10))
+ prod := string(nat(nil).mulRange(stk, r.a, r.b).utoa(10))
if prod != r.prod {
t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
}
@@ -185,11 +198,14 @@ func allocBytes(f func()) uint64 {
// does not cause deep recursion and in turn allocate too much memory.
// Test case for issue 3807.
func TestMulUnbalanced(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
x := rndNat(50000)
y := rndNat(40)
allocSize := allocBytes(func() {
- nat(nil).mul(x, y)
+ nat(nil).mul(stk, x, y)
})
inputSize := uint64(len(x)+len(y)) * _S
if ratio := allocSize / uint64(inputSize); ratio > 10 {
@@ -214,12 +230,15 @@ func rndNat1(n int) nat {
}
func BenchmarkMul(b *testing.B) {
+ stk := getStack()
+ defer stk.free()
+
mulx := rndNat(1e4)
muly := rndNat(1e4)
b.ResetTimer()
for i := 0; i < b.N; i++ {
var z nat
- z.mul(mulx, muly)
+ z.mul(stk, mulx, muly)
}
}
@@ -230,7 +249,7 @@ func benchmarkNatMul(b *testing.B, nwords int) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
- z.mul(x, y)
+ z.mul(nil, x, y)
}
}
@@ -444,6 +463,9 @@ var montgomeryTests = []struct {
}
func TestMontgomery(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
one := NewInt(1)
_B := new(Int).Lsh(one, _W)
for i, test := range montgomeryTests {
@@ -458,11 +480,11 @@ func TestMontgomery(t *testing.T) {
}
if x.cmp(m) > 0 {
- _, r := nat(nil).div(nil, x, m)
+ _, r := nat(nil).div(stk, nil, x, m)
t.Errorf("#%d: x > m (0x%s > 0x%s; use 0x%s)", i, x.utoa(16), m.utoa(16), r.utoa(16))
}
if y.cmp(m) > 0 {
- _, r := nat(nil).div(nil, x, m)
+ _, r := nat(nil).div(stk, nil, x, m)
t.Errorf("#%d: y > m (0x%s > 0x%s; use 0x%s)", i, y.utoa(16), m.utoa(16), r.utoa(16))
}
@@ -538,6 +560,9 @@ var expNNTests = []struct {
}
func TestExpNN(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
for i, test := range expNNTests {
x := natFromString(test.x)
y := natFromString(test.y)
@@ -548,7 +573,7 @@ func TestExpNN(t *testing.T) {
m = natFromString(test.m)
}
- z := nat(nil).expNN(x, y, m, false)
+ z := nat(nil).expNN(stk, x, y, m, false)
if z.cmp(out) != 0 {
t.Errorf("#%d got %s want %s", i, z.utoa(10), out.utoa(10))
}
@@ -572,6 +597,9 @@ func FuzzExpMont(f *testing.F) {
}
func BenchmarkExp3Power(b *testing.B) {
+ stk := getStack()
+ defer stk.free()
+
const x = 3
for _, y := range []Word{
0x10, 0x40, 0x100, 0x400, 0x1000, 0x4000, 0x10000, 0x40000, 0x100000, 0x400000,
@@ -579,7 +607,7 @@ func BenchmarkExp3Power(b *testing.B) {
b.Run(fmt.Sprintf("%#x", y), func(b *testing.B) {
var z nat
for i := 0; i < b.N; i++ {
- z.expWW(x, y)
+ z.expWW(stk, x, y)
}
})
}
@@ -712,10 +740,13 @@ func TestSticky(t *testing.T) {
}
func testSqr(t *testing.T, x nat) {
+ stk := getStack()
+ defer stk.free()
+
got := make(nat, 2*len(x))
want := make(nat, 2*len(x))
- got = got.sqr(x)
- want = want.mul(x, x)
+ got = got.sqr(stk, x)
+ want = want.mul(stk, x, x)
if got.cmp(want) != 0 {
t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
}
@@ -741,7 +772,7 @@ func benchmarkNatSqr(b *testing.B, nwords int) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
- z.sqr(x)
+ z.sqr(nil, x)
}
}
@@ -830,6 +861,9 @@ func BenchmarkNatSetBytes(b *testing.B) {
}
func TestNatDiv(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
sizes := []int{
1, 2, 5, 8, 15, 25, 40, 65, 100,
200, 500, 800, 1500, 2500, 4000, 6500, 10000,
@@ -849,11 +883,11 @@ func TestNatDiv(t *testing.T) {
c = c.norm()
}
// compute x = a*b+c
- x := nat(nil).mul(a, b)
+ x := nat(nil).mul(stk, a, b)
x = x.add(x, c)
var q, r nat
- q, r = q.div(r, x, b)
+ q, r = q.div(stk, r, x, b)
if q.cmp(a) != 0 {
t.Fatalf("wrong quotient: got %s; want %s for %s/%s", q.utoa(10), a.utoa(10), x.utoa(10), b.utoa(10))
}
@@ -868,6 +902,9 @@ func TestNatDiv(t *testing.T) {
// the inaccurate estimate of the first word's quotient
// happens at the very beginning of the loop.
func TestIssue37499(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
// Choose u and v such that v is slightly larger than u >> N.
// This tricks divBasic into choosing 1 as the first word
// of the quotient. This works in both 32-bit and 64-bit settings.
@@ -875,7 +912,7 @@ func TestIssue37499(t *testing.T) {
v := natFromString("0x2b6c385a05be027f5c22005b63c42a1165b79ff510e1706c")
q := nat(nil).make(8)
- q.divBasic(u, v)
+ q.divBasic(stk, u, v)
q = q.norm()
if s := string(q.utoa(16)); s != "fffffffffffffffffffffffffffffffffffffffffffffffb" {
t.Fatalf("incorrect quotient: %s", s)
@@ -886,8 +923,11 @@ func TestIssue37499(t *testing.T) {
// where the first division loop is never entered, and correcting
// the remainder takes exactly two iterations in the final loop.
func TestIssue42552(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
u := natFromString("0xc23b166884c3869092a520eceedeced2b00847bd256c9cf3b2c5e2227c15bd5e6ee7ef8a2f49236ad0eedf2c8a3b453cf6e0706f64285c526b372c4b1321245519d430540804a50b7ca8b6f1b34a2ec05cdbc24de7599af112d3e3c8db347e8799fe70f16e43c6566ba3aeb169463a3ecc486172deb2d9b80a3699c776e44fef20036bd946f1b4d054dd88a2c1aeb986199b0b2b7e58c42288824b74934d112fe1fc06e06b4d99fe1c5e725946b23210521e209cd507cce90b5f39a523f27e861f9e232aee50c3f585208b4573dcc0b897b6177f2ba20254fd5c50a033e849dee1b3a93bd2dc44ba8ca836cab2c2ae50e50b126284524fa0187af28628ff0face68d87709200329db1392852c8b8963fbe3d05fb1efe19f0ed5ca9fadc2f96f82187c24bb2512b2e85a66333a7e176605695211e1c8e0b9b9e82813e50654964945b1e1e66a90840396c7d10e23e47f364d2d3f660fa54598e18d1ca2ea4fe4f35a40a11f69f201c80b48eaee3e2e9b0eda63decf92bec08a70f731587d4ed0f218d5929285c8b2ccbc497e20db42de73885191fa453350335990184d8df805072f958d5354debda38f5421effaaafd6cb9b721ace74be0892d77679f62a4a126697cd35797f6858193da4ba1770c06aea2e5c59ec04b8ea26749e61b72ecdde403f3bc7e5e546cd799578cc939fa676dfd5e648576d4a06cbadb028adc2c0b461f145b2321f42e5e0f3b4fb898ecd461df07a6f5154067787bf74b5cc5c03704a1ce47494961931f0263b0aac32505102595957531a2de69dd71aac51f8a49902f81f21283dbe8e21e01e5d82517868826f86acf338d935aa6b4d5a25c8d540389b277dd9d64569d68baf0f71bd03dba45b92a7fc052601d1bd011a2fc6790a23f97c6fa5caeea040ab86841f268d39ce4f7caf01069df78bba098e04366492f0c2ac24f1bf16828752765fa523c9a4d42b71109d123e6be8c7b1ab3ccf8ea03404075fe1a9596f1bba1d267f9a7879ceece514818316c9c0583469d2367831fc42b517ea028a28df7c18d783d16ea2436cee2b15d52db68b5dfdee6b4d26f0905f9b030c911a04d078923a4136afea96eed6874462a482917353264cc9bee298f167ac65a6db4e4eda88044b39cc0b33183843eaa946564a00c3a0ab661f2c915e70bf0bb65bfbb6fa2eea20aed16bf2c1a1d00ec55fb4ff2f76b8e462ea70c19efa579c9ee78194b86708fdae66a9ce6e2cf3d366037798cfb50277ba6d2fd4866361022fd788ab7735b40b8b61d55e32243e06719e53992e9ac16c9c4b6e6933635c3c47c8f7e73e17dd54d0dd8aeba5d76de46894e7b3f9d3ec25ad78ee82297ba69905ea0fa094b8667faa2b8885e2187b3da80268aa1164761d
7b0d6de206b676777348152b8ae1d4afed753bc63c739a5ca8ce7afb2b241a226bd9e502baba391b5b13f5054f070b65a9cf3a67063bfaa803ba390732cd03888f664023f888741d04d564e0b5674b0a183ace81452001b3fbb4214c77d42ca75376742c471e58f67307726d56a1032bd236610cbcbcd03d0d7a452900136897dc55bb3ce959d10d4e6a10fb635006bd8c41cd9ded2d3dfdd8f2e229590324a7370cb2124210b2330f4c56155caa09a2564932ceded8d92c79664dcdeb87faad7d3da006cc2ea267ee3df41e9677789cc5a8cc3b83add6491561b3047919e0648b1b2e97d7ad6f6c2aa80cab8e9ae10e1f75b1fdd0246151af709d259a6a0ed0b26bd711024965ecad7c41387de45443defce53f66612948694a6032279131c257119ed876a8e805dfb49576ef5c563574115ee87050d92d191bc761ef51d966918e2ef925639400069e3959d8fe19f36136e947ff430bf74e71da0aa5923b00000000")
v := natFromString("0x838332321d443a3d30373d47301d47073847473a383d3030f25b3d3d3e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002e00000000000000000041603038331c3d32f5303441e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e01c0a5459bfc7b9be9fcbb9d2383840464319434707303030f43a32f53034411c0a5459413820878787878787878787878787878787878787878787878787878787878787878787870630303a3a30334036605b923a6101f83638413943413960204337602043323801526040523241846038414143015238604060328452413841413638523c0240384141364036605b923a6101f83638413943413960204334602043323801526040523241846038414143015238604060328452413841413638523c02403841413638433030f25a8b83838383838383838383838383838383837d838383ffffffffffffffff838383838383838383000000000000000000030000007d26e27c7c8b83838383838383838383838383838383837d838383ffffffffffffffff83838383838383838383838383838383838383838383435960f535073030f3343200000000000000011881301938343030fa398383300000002300000000000000000000f11af4600c845252904141364138383c60406032414443095238010241414303364443434132305b595a15434160b042385341ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff47476043410536613603593a6005411c437405fcfcfcfcfcfcfc0000000000005a3b075815054359000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
q := nat(nil).make(16)
- q.div(q, u, v)
+ q.div(stk, q, u, v)
}
diff --git a/src/math/big/natconv.go b/src/math/big/natconv.go
index ce94f2cf72..8a47ec9f9c 100644
--- a/src/math/big/natconv.go
+++ b/src/math/big/natconv.go
@@ -321,17 +321,20 @@ func (x nat) itoa(neg bool, base int) []byte {
}
} else {
+ stk := getStack()
+ defer stk.free()
+
bb, ndigits := maxPow(b)
// construct table of successive squares of bb*leafSize to use in subdivisions
// result (table != nil) <=> (len(x) > leafSize > 0)
- table := divisors(len(x), b, ndigits, bb)
+ table := divisors(stk, len(x), b, ndigits, bb)
// preserve x, create local copy for use by convertWords
q := nat(nil).set(x)
// convert q to string s in base b
- q.convertWords(s, b, ndigits, bb, table)
+ q.convertWords(stk, s, b, ndigits, bb, table)
// strip leading zeros
// (x != 0; thus s must contain at least one non-zero digit
@@ -365,7 +368,7 @@ func (x nat) itoa(neg bool, base int) []byte {
// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and
// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for
// specific hardware.
-func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []divisor) {
+func (q nat) convertWords(stk *stack, s []byte, b Word, ndigits int, bb Word, table []divisor) {
// split larger blocks recursively
if table != nil {
// len(q) > leafSize > 0
@@ -386,12 +389,12 @@ func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []diviso
}
// split q into the two digit number (q'*bbb + r) to form independent subblocks
- q, r = q.div(r, q, table[index].bbb)
+ q, r = q.div(stk, r, q, table[index].bbb)
// convert subblocks and collect results in s[:h] and s[h:]
h := len(s) - table[index].ndigits
- r.convertWords(s[h:], b, ndigits, bb, table[0:index])
- s = s[:h] // == q.convertWords(s, b, ndigits, bb, table[0:index+1])
+ r.convertWords(stk, s[h:], b, ndigits, bb, table[0:index])
+ s = s[:h] // == q.convertWords(stk, s, b, ndigits, bb, table[0:index+1])
}
}
@@ -451,12 +454,12 @@ var cacheBase10 struct {
}
// expWW computes x**y
-func (z nat) expWW(x, y Word) nat {
- return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil, false)
+func (z nat) expWW(stk *stack, x, y Word) nat {
+ return z.expNN(stk, nat(nil).setWord(x), nat(nil).setWord(y), nil, false)
}
// construct table of powers of bb*leafSize to use in subdivisions.
-func divisors(m int, b Word, ndigits int, bb Word) []divisor {
+func divisors(stk *stack, m int, b Word, ndigits int, bb Word) []divisor {
// only compute table when recursive conversion is enabled and x is large
if leafSize == 0 || m <= leafSize {
return nil
@@ -484,10 +487,10 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor {
for i := 0; i < k; i++ {
if table[i].ndigits == 0 {
if i == 0 {
- table[0].bbb = nat(nil).expWW(bb, Word(leafSize))
+ table[0].bbb = nat(nil).expWW(stk, bb, Word(leafSize))
table[0].ndigits = ndigits * leafSize
} else {
- table[i].bbb = nat(nil).sqr(table[i-1].bbb)
+ table[i].bbb = nat(nil).sqr(stk, table[i-1].bbb)
table[i].ndigits = 2 * table[i-1].ndigits
}
diff --git a/src/math/big/natconv_test.go b/src/math/big/natconv_test.go
index d390272108..66300e412b 100644
--- a/src/math/big/natconv_test.go
+++ b/src/math/big/natconv_test.go
@@ -350,6 +350,9 @@ func BenchmarkStringPiParallel(b *testing.B) {
}
func BenchmarkScan(b *testing.B) {
+ stk := getStack()
+ defer stk.free()
+
const x = 10
for _, base := range []int{2, 8, 10, 16} {
for _, y := range []Word{10, 100, 1000, 10000, 100000} {
@@ -359,7 +362,7 @@ func BenchmarkScan(b *testing.B) {
b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) {
b.StopTimer()
var z nat
- z = z.expWW(x, y)
+ z = z.expWW(stk, x, y)
s := z.utoa(base)
if t := itoa(z, base); !bytes.Equal(s, t) {
@@ -376,6 +379,9 @@ func BenchmarkScan(b *testing.B) {
}
func BenchmarkString(b *testing.B) {
+ stk := getStack()
+ defer stk.free()
+
const x = 10
for _, base := range []int{2, 8, 10, 16} {
for _, y := range []Word{10, 100, 1000, 10000, 100000} {
@@ -385,7 +391,7 @@ func BenchmarkString(b *testing.B) {
b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) {
b.StopTimer()
var z nat
- z = z.expWW(x, y)
+ z = z.expWW(stk, x, y)
z.utoa(base) // warm divisor cache
b.StartTimer()
@@ -416,9 +422,11 @@ func LeafSizeHelper(b *testing.B, base, size int) {
for d := 1; d <= 10000; d *= 10 {
b.StopTimer()
+ stk := getStack()
var z nat
- z = z.expWW(Word(base), Word(d)) // build target number
- _ = z.utoa(base) // warm divisor cache
+ z = z.expWW(stk, Word(base), Word(d)) // build target number
+ _ = z.utoa(base) // warm divisor cache
+ stk.free()
b.StartTimer()
for i := 0; i < b.N; i++ {
@@ -443,13 +451,16 @@ func resetTable(table []divisor) {
}
func TestStringPowers(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
var p Word
for b := 2; b <= 16; b++ {
for p = 0; p <= 512; p++ {
if testing.Short() && p > 10 {
break
}
- x := nat(nil).expWW(Word(b), p)
+ x := nat(nil).expWW(stk, Word(b), p)
xs := x.utoa(b)
xs2 := itoa(x, b)
if !bytes.Equal(xs, xs2) {
diff --git a/src/math/big/natdiv.go b/src/math/big/natdiv.go
index 2e66e3425c..b514e2ce21 100644
--- a/src/math/big/natdiv.go
+++ b/src/math/big/natdiv.go
@@ -502,30 +502,24 @@ import "math/bits"
// rem returns r such that r = u%v.
// It uses z as the storage for r.
-func (z nat) rem(u, v nat) (r nat) {
+func (z nat) rem(stk *stack, u, v nat) (r nat) {
if alias(z, u) {
z = nil
}
- qp := getNat(0)
- q, r := qp.div(z, u, v)
- *qp = q
- putNat(qp)
+ defer stk.restore(stk.save())
+ q := stk.nat(len(u) - (len(v) - 1))
+ _, r = q.div(stk, z, u, v)
return r
}
// div returns q, r such that q = ⌊u/v⌋ and r = u%v = u - q·v.
// It uses z and z2 as the storage for q and r.
-func (z nat) div(z2, u, v nat) (q, r nat) {
+// The caller may pass stk == nil to request that div obtain and release a stack itself.
+func (z nat) div(stk *stack, z2, u, v nat) (q, r nat) {
if len(v) == 0 {
panic("division by zero")
}
- if u.cmp(v) < 0 {
- q = z[:0]
- r = z2.set(u)
- return
- }
-
if len(v) == 1 {
// Short division: long optimized for a single-word divisor.
// In that case, the 2-by-1 guess is all we need at each step.
@@ -535,7 +529,18 @@ func (z nat) div(z2, u, v nat) (q, r nat) {
return
}
- q, r = z.divLarge(z2, u, v)
+ if u.cmp(v) < 0 {
+ q = z[:0]
+ r = z2.set(u)
+ return
+ }
+
+ if stk == nil {
+ stk = getStack()
+ defer stk.free()
+ }
+
+ q, r = z.divLarge(stk, z2, u, v)
return
}
@@ -589,7 +594,7 @@ func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) {
// It uses z and u as the storage for q and r.
// The caller must ensure that len(vIn) ≥ 2 (use divW otherwise)
// and that len(uIn) ≥ len(vIn) (the answer is 0, uIn otherwise).
-func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
+func (z nat) divLarge(stk *stack, u, uIn, vIn nat) (q, r nat) {
n := len(vIn)
m := len(uIn) - n
@@ -597,9 +602,9 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
// vIn is treated as a read-only input (it may be in use by another
// goroutine), so we must make a copy.
// uIn is copied to u.
+ defer stk.restore(stk.save())
shift := nlz(vIn[n-1])
- vp := getNat(n)
- v := *vp
+ v := stk.nat(n)
shlVU(v, vIn, shift)
u = u.make(len(uIn) + 1)
u[len(uIn)] = shlVU(u[:len(uIn)], uIn, shift)
@@ -613,11 +618,10 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
// Use basic or recursive long division depending on size.
if n < divRecursiveThreshold {
- q.divBasic(u, v)
+ q.divBasic(stk, u, v)
} else {
- q.divRecursive(u, v)
+ q.divRecursive(stk, u, v)
}
- putNat(vp)
q = q.norm()
@@ -631,12 +635,12 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
// divBasic implements long division as described above.
// It overwrites q with ⌊u/v⌋ and overwrites u with the remainder r.
// q must be large enough to hold ⌊u/v⌋.
-func (q nat) divBasic(u, v nat) {
+func (q nat) divBasic(stk *stack, u, v nat) {
n := len(v)
m := len(u) - n
- qhatvp := getNat(n + 1)
- qhatv := *qhatvp
+ defer stk.restore(stk.save())
+ qhatv := stk.nat(n + 1)
// Set up for divWW below, precomputing reciprocal argument.
vn1 := v[n-1]
@@ -707,8 +711,6 @@ func (q nat) divBasic(u, v nat) {
}
q[j] = qhat
}
-
- putNat(qhatvp)
}
// greaterThan reports whether the two digit numbers x1 x2 > y1 y2.
@@ -727,24 +729,9 @@ const divRecursiveThreshold = 100
// z must be large enough to hold ⌊u/v⌋.
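-// This function is just for allocating and freeing temporaries
-// around divRecursiveStep, the real implementation.
+// This function just clears z and then calls divRecursiveStep,
+// the real implementation, which takes its temporaries from stk.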
-func (z nat) divRecursive(u, v nat) {
- // Recursion depth is (much) less than 2 log₂(len(v)).
- // Allocate a slice of temporaries to be reused across recursion,
- // plus one extra temporary not live across the recursion.
- recDepth := 2 * bits.Len(uint(len(v)))
- tmp := getNat(3 * len(v))
- temps := make([]*nat, recDepth)
-
+func (z nat) divRecursive(stk *stack, u, v nat) {
clear(z)
- z.divRecursiveStep(u, v, 0, tmp, temps)
-
- // Free temporaries.
- for _, n := range temps {
- if n != nil {
- putNat(n)
- }
- }
- putNat(tmp)
+ z.divRecursiveStep(stk, u, v, 0)
}
// divRecursiveStep is the actual implementation of recursive division.
@@ -752,7 +739,7 @@ func (z nat) divRecursive(u, v nat) {
// z must be large enough to hold ⌊u/v⌋.
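-// It uses temps[depth] (allocating if needed) as a temporary live across
-// the recursive call. It also uses tmp, but not live across the recursion.
+// It allocates qhat0 from stk as a temporary live across the recursive
+// call. It also allocates qhatv from stk, but not live across the recursion.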
-func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
+func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) {
// u is a subsection of the original and may have leading zeros.
// TODO(rsc): The v = v.norm() is useless and should be removed.
// We know (and require) that v's top digit is ≥ B/2.
@@ -766,7 +753,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
// Fall back to basic division if the problem is now small enough.
n := len(v)
if n < divRecursiveThreshold {
- z.divBasic(u, v)
+ z.divBasic(stk, u, v)
return
}
@@ -785,11 +772,8 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
B := n / 2
// Allocate a nat for qhat below.
- if temps[depth] == nil {
- temps[depth] = getNat(n) // TODO(rsc): Can be just B+1.
- } else {
- *temps[depth] = temps[depth].make(B + 1)
- }
+ defer stk.restore(stk.save())
+ qhat0 := stk.nat(B + 1)
// Compute each wide digit of the quotient.
//
@@ -816,9 +800,9 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
uu := u[j-B:]
// Compute the 2-by-1 guess q̂, leaving r̂ in uu[s:B+n].
- qhat := *temps[depth]
+ qhat := qhat0
clear(qhat)
- qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps)
+ qhat.divRecursiveStep(stk, uu[s:B+n], v[s:], depth+1)
qhat = qhat.norm()
// Extend to a 3-by-2 quotient and remainder.
@@ -833,9 +817,10 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
// q̂·vₙ₋₂ and decrementing q̂ until that product is ≤ u.
// But we can do the subtraction directly, as in the comment above
// and in long division, because we know that q̂ is wrong by at most one.
- qhatv := tmp.make(3 * n)
+ mark := stk.save()
+ qhatv := stk.nat(3 * n)
clear(qhatv)
- qhatv = qhatv.mul(qhat, v[:s])
+ qhatv = qhatv.mul(stk, qhat, v[:s])
for i := 0; i < 2; i++ {
e := qhatv.cmp(uu.norm())
if e <= 0 {
@@ -857,6 +842,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
}
addAt(z, qhat, j-B)
j -= B
+ stk.restore(mark)
}
// TODO(rsc): Rewrite loop as described above and delete all this code.
@@ -864,13 +850,13 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
// Now u < (v<<B), compute lower bits in the same way.
// Choose shift = B-1 again.
s := B - 1
- qhat := *temps[depth]
+ qhat := qhat0
clear(qhat)
- qhat.divRecursiveStep(u[s:].norm(), v[s:], depth+1, tmp, temps)
+ qhat.divRecursiveStep(stk, u[s:].norm(), v[s:], depth+1)
qhat = qhat.norm()
- qhatv := tmp.make(3 * n)
+ qhatv := stk.nat(3 * n)
clear(qhatv)
- qhatv = qhatv.mul(qhat, v[:s])
+ qhatv = qhatv.mul(stk, qhat, v[:s])
// Set the correct remainder as before.
for i := 0; i < 2; i++ {
if e := qhatv.cmp(u.norm()); e > 0 {
diff --git a/src/math/big/prime.go b/src/math/big/prime.go
index 26688bbd64..bba5a07685 100644
--- a/src/math/big/prime.go
+++ b/src/math/big/prime.go
@@ -75,7 +75,9 @@ func (x *Int) ProbablyPrime(n int) bool {
return false
}
- return x.abs.probablyPrimeMillerRabin(n+1, true) && x.abs.probablyPrimeLucas()
+ stk := getStack()
+ defer stk.free()
+ return x.abs.probablyPrimeMillerRabin(stk, n+1, true) && x.abs.probablyPrimeLucas(stk)
}
// probablyPrimeMillerRabin reports whether n passes reps rounds of the
@@ -83,7 +85,7 @@ func (x *Int) ProbablyPrime(n int) bool {
// If force2 is true, one of the rounds is forced to use base 2.
// See Handbook of Applied Cryptography, p. 139, Algorithm 4.24.
// The number n is known to be non-zero.
-func (n nat) probablyPrimeMillerRabin(reps int, force2 bool) bool {
+func (n nat) probablyPrimeMillerRabin(stk *stack, reps int, force2 bool) bool {
nm1 := nat(nil).sub(n, natOne)
// determine q, k such that nm1 = q << k
k := nm1.trailingZeroBits()
@@ -103,13 +105,13 @@ NextRandom:
x = x.random(rand, nm3, nm3Len)
x = x.add(x, natTwo)
}
- y = y.expNN(x, q, n, false)
+ y = y.expNN(stk, x, q, n, false)
if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
continue
}
for j := uint(1); j < k; j++ {
- y = y.sqr(y)
- quotient, y = quotient.div(y, y, n)
+ y = y.sqr(stk, y)
+ quotient, y = quotient.div(stk, y, y, n)
if y.cmp(nm1) == 0 {
continue NextRandom
}
@@ -147,7 +149,7 @@ NextRandom:
//
// Crandall and Pomerance, Prime Numbers: A Computational Perspective, 2nd ed.
// Springer, 2005.
-func (n nat) probablyPrimeLucas() bool {
+func (n nat) probablyPrimeLucas(stk *stack) bool {
// Discard 0, 1.
if len(n) == 0 || n.cmp(natOne) == 0 {
return false
@@ -193,8 +195,8 @@ func (n nat) probablyPrimeLucas() bool {
// We'll never find (d/n) = -1 if n is a square.
// If n is a non-square we expect to find a d in just a few attempts on average.
// After 40 attempts, take a moment to check if n is indeed a square.
- t1 = t1.sqrt(n)
- t1 = t1.sqr(t1)
+ t1 = t1.sqrt(stk, n)
+ t1 = t1.sqr(stk, t1)
if t1.cmp(n) == 0 {
return false
}
@@ -254,25 +256,25 @@ func (n nat) probablyPrimeLucas() bool {
if s.bit(uint(i)) != 0 {
// k' = 2k+1
// V(k') = V(2k+1) = V(k) V(k+1) - P.
- t1 = t1.mul(vk, vk1)
+ t1 = t1.mul(stk, vk, vk1)
t1 = t1.add(t1, n)
t1 = t1.sub(t1, natP)
- t2, vk = t2.div(vk, t1, n)
+ t2, vk = t2.div(stk, vk, t1, n)
// V(k'+1) = V(2k+2) = V(k+1)² - 2.
- t1 = t1.sqr(vk1)
+ t1 = t1.sqr(stk, vk1)
t1 = t1.add(t1, nm2)
- t2, vk1 = t2.div(vk1, t1, n)
+ t2, vk1 = t2.div(stk, vk1, t1, n)
} else {
// k' = 2k
// V(k'+1) = V(2k+1) = V(k) V(k+1) - P.
- t1 = t1.mul(vk, vk1)
+ t1 = t1.mul(stk, vk, vk1)
t1 = t1.add(t1, n)
t1 = t1.sub(t1, natP)
- t2, vk1 = t2.div(vk1, t1, n)
+ t2, vk1 = t2.div(stk, vk1, t1, n)
// V(k') = V(2k) = V(k)² - 2
- t1 = t1.sqr(vk)
+ t1 = t1.sqr(stk, vk)
t1 = t1.add(t1, nm2)
- t2, vk = t2.div(vk, t1, n)
+ t2, vk = t2.div(stk, vk, t1, n)
}
}
@@ -285,7 +287,7 @@ func (n nat) probablyPrimeLucas() bool {
//
// Since we are checking for U(k) == 0 it suffices to check 2 V(k+1) == P V(k) mod n,
// or P V(k) - 2 V(k+1) == 0 mod n.
- t1 := t1.mul(vk, natP)
+ t1 := t1.mul(stk, vk, natP)
t2 := t2.shl(vk1, 1)
if t1.cmp(t2) < 0 {
t1, t2 = t2, t1
@@ -294,7 +296,7 @@ func (n nat) probablyPrimeLucas() bool {
t3 := vk1 // steal vk1, no longer needed below
vk1 = nil
_ = vk1
- t2, t3 = t2.div(t3, t1, n)
+ t2, t3 = t2.div(stk, t3, t1, n)
if len(t3) == 0 {
return true
}
@@ -312,9 +314,9 @@ func (n nat) probablyPrimeLucas() bool {
}
// k' = 2k
// V(k') = V(2k) = V(k)² - 2
- t1 = t1.sqr(vk)
+ t1 = t1.sqr(stk, vk)
t1 = t1.sub(t1, natTwo)
- t2, vk = t2.div(vk, t1, n)
+ t2, vk = t2.div(stk, vk, t1, n)
}
return false
}
diff --git a/src/math/big/prime_test.go b/src/math/big/prime_test.go
index 8596e33a13..2b1995bcb2 100644
--- a/src/math/big/prime_test.go
+++ b/src/math/big/prime_test.go
@@ -159,6 +159,9 @@ func TestProbablyPrime(t *testing.T) {
}
func BenchmarkProbablyPrime(b *testing.B) {
+ stk := getStack()
+ defer stk.free()
+
p, _ := new(Int).SetString("203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123", 10)
for _, n := range []int{0, 1, 5, 10, 20} {
b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) {
@@ -170,26 +173,32 @@ func BenchmarkProbablyPrime(b *testing.B) {
b.Run("Lucas", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- p.abs.probablyPrimeLucas()
+ p.abs.probablyPrimeLucas(stk)
}
})
b.Run("MillerRabinBase2", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- p.abs.probablyPrimeMillerRabin(1, true)
+ p.abs.probablyPrimeMillerRabin(stk, 1, true)
}
})
}
func TestMillerRabinPseudoprimes(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
testPseudoprimes(t, "probablyPrimeMillerRabin",
- func(n nat) bool { return n.probablyPrimeMillerRabin(1, true) && !n.probablyPrimeLucas() },
+ func(n nat) bool { return n.probablyPrimeMillerRabin(stk, 1, true) && !n.probablyPrimeLucas(stk) },
// https://oeis.org/A001262
[]int{2047, 3277, 4033, 4681, 8321, 15841, 29341, 42799, 49141, 52633, 65281, 74665, 80581, 85489, 88357, 90751})
}
func TestLucasPseudoprimes(t *testing.T) {
+ stk := getStack()
+ defer stk.free()
+
testPseudoprimes(t, "probablyPrimeLucas",
- func(n nat) bool { return n.probablyPrimeLucas() && !n.probablyPrimeMillerRabin(1, true) },
+ func(n nat) bool { return n.probablyPrimeLucas(stk) && !n.probablyPrimeMillerRabin(stk, 1, true) },
// https://oeis.org/A217719
[]int{989, 3239, 5777, 10877, 27971, 29681, 30739, 31631, 39059, 72389, 73919, 75077})
}
diff --git a/src/math/big/rat.go b/src/math/big/rat.go
index e58433ecea..ac94056a83 100644
--- a/src/math/big/rat.go
+++ b/src/math/big/rat.go
@@ -74,7 +74,7 @@ func (z *Rat) SetFloat64(f float64) *Rat {
// nearest to the quotient a/b, using round-to-even in
// halfway cases. It does not mutate its arguments.
// Preconditions: b is non-zero; a and b have no common factors.
-func quotToFloat32(a, b nat) (f float32, exact bool) {
+func quotToFloat32(stk *stack, a, b nat) (f float32, exact bool) {
const (
// float size in bits
Fsize = 32
@@ -121,7 +121,7 @@ func quotToFloat32(a, b nat) (f float32, exact bool) {
// extra shift, the low-order bit of q is logically the
// high-order bit of r.
var q nat
- q, r := q.div(a2, a2, b2) // (recycle a2)
+ q, r := q.div(stk, a2, a2, b2) // (recycle a2)
mantissa := low32(q)
haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
@@ -172,7 +172,7 @@ func quotToFloat32(a, b nat) (f float32, exact bool) {
// nearest to the quotient a/b, using round-to-even in
// halfway cases. It does not mutate its arguments.
// Preconditions: b is non-zero; a and b have no common factors.
-func quotToFloat64(a, b nat) (f float64, exact bool) {
+func quotToFloat64(stk *stack, a, b nat) (f float64, exact bool) {
const (
// float size in bits
Fsize = 64
@@ -219,7 +219,7 @@ func quotToFloat64(a, b nat) (f float64, exact bool) {
// extra shift, the low-order bit of q is logically the
// high-order bit of r.
var q nat
- q, r := q.div(a2, a2, b2) // (recycle a2)
+ q, r := q.div(stk, a2, a2, b2) // (recycle a2)
mantissa := low64(q)
haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
@@ -275,7 +275,9 @@ func (x *Rat) Float32() (f float32, exact bool) {
if len(b) == 0 {
b = natOne
}
- f, exact = quotToFloat32(x.a.abs, b)
+ stk := getStack()
+ defer stk.free()
+ f, exact = quotToFloat32(stk, x.a.abs, b)
if x.a.neg {
f = -f
}
@@ -291,7 +293,9 @@ func (x *Rat) Float64() (f float64, exact bool) {
if len(b) == 0 {
b = natOne
}
- f, exact = quotToFloat64(x.a.abs, b)
+ stk := getStack()
+ defer stk.free()
+ f, exact = quotToFloat64(stk, x.a.abs, b)
if x.a.neg {
f = -f
}
@@ -437,12 +441,14 @@ func (z *Rat) norm() *Rat {
z.b.abs = z.b.abs.setWord(1)
default:
// z is fraction; normalize numerator and denominator
+ stk := getStack()
+ defer stk.free()
neg := z.a.neg
z.a.neg = false
z.b.neg = false
if f := NewInt(0).lehmerGCD(nil, nil, &z.a, &z.b); f.Cmp(intOne) != 0 {
- z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs)
- z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs)
+ z.a.abs, _ = z.a.abs.div(stk, nil, z.a.abs, f.abs)
+ z.b.abs, _ = z.b.abs.div(stk, nil, z.b.abs, f.abs)
}
z.a.neg = neg
}
@@ -452,7 +458,7 @@ func (z *Rat) norm() *Rat {
// mulDenom sets z to the denominator product x*y (by taking into
// account that 0 values for x or y must be interpreted as 1) and
// returns z.
-func mulDenom(z, x, y nat) nat {
+func mulDenom(stk *stack, z, x, y nat) nat {
switch {
case len(x) == 0 && len(y) == 0:
return z.setWord(1)
@@ -461,17 +467,17 @@ func mulDenom(z, x, y nat) nat {
case len(y) == 0:
return z.set(x)
}
- return z.mul(x, y)
+ return z.mul(stk, x, y)
}
// scaleDenom sets z to the product x*f.
// If f == 0 (zero value of denominator), z is set to (a copy of) x.
-func (z *Int) scaleDenom(x *Int, f nat) {
+func (z *Int) scaleDenom(stk *stack, x *Int, f nat) {
if len(f) == 0 {
z.Set(x)
return
}
- z.abs = z.abs.mul(x.abs, f)
+ z.abs = z.abs.mul(stk, x.abs, f)
z.neg = x.neg
}
@@ -481,58 +487,73 @@ func (z *Int) scaleDenom(x *Int, f nat) {
// - +1 if x > y.
func (x *Rat) Cmp(y *Rat) int {
var a, b Int
- a.scaleDenom(&x.a, y.b.abs)
- b.scaleDenom(&y.a, x.b.abs)
+ stk := getStack()
+ defer stk.free()
+ a.scaleDenom(stk, &x.a, y.b.abs)
+ b.scaleDenom(stk, &y.a, x.b.abs)
return a.Cmp(&b)
}
// Add sets z to the sum x+y and returns z.
func (z *Rat) Add(x, y *Rat) *Rat {
+ stk := getStack()
+ defer stk.free()
+
var a1, a2 Int
- a1.scaleDenom(&x.a, y.b.abs)
- a2.scaleDenom(&y.a, x.b.abs)
+ a1.scaleDenom(stk, &x.a, y.b.abs)
+ a2.scaleDenom(stk, &y.a, x.b.abs)
z.a.Add(&a1, &a2)
- z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
+ z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Sub sets z to the difference x-y and returns z.
func (z *Rat) Sub(x, y *Rat) *Rat {
+ stk := getStack()
+ defer stk.free()
+
var a1, a2 Int
- a1.scaleDenom(&x.a, y.b.abs)
- a2.scaleDenom(&y.a, x.b.abs)
+ a1.scaleDenom(stk, &x.a, y.b.abs)
+ a2.scaleDenom(stk, &y.a, x.b.abs)
z.a.Sub(&a1, &a2)
- z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
+ z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Mul sets z to the product x*y and returns z.
func (z *Rat) Mul(x, y *Rat) *Rat {
+ stk := getStack()
+ defer stk.free()
+
if x == y {
// a squared Rat is positive and can't be reduced (no need to call norm())
z.a.neg = false
- z.a.abs = z.a.abs.sqr(x.a.abs)
+ z.a.abs = z.a.abs.sqr(stk, x.a.abs)
if len(x.b.abs) == 0 {
z.b.abs = z.b.abs.setWord(1)
} else {
- z.b.abs = z.b.abs.sqr(x.b.abs)
+ z.b.abs = z.b.abs.sqr(stk, x.b.abs)
}
return z
}
- z.a.Mul(&x.a, &y.a)
- z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
+
+ z.a.mul(stk, &x.a, &y.a)
+ z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Quo sets z to the quotient x/y and returns z.
// If y == 0, Quo panics.
func (z *Rat) Quo(x, y *Rat) *Rat {
+ stk := getStack()
+ defer stk.free()
+
if len(y.a.abs) == 0 {
panic("division by zero")
}
var a, b Int
- a.scaleDenom(&x.a, y.b.abs)
- b.scaleDenom(&y.a, x.b.abs)
+ a.scaleDenom(stk, &x.a, y.b.abs)
+ b.scaleDenom(stk, &y.a, x.b.abs)
z.a.abs = a.abs
z.b.abs = b.abs
z.a.neg = a.neg != b.neg
diff --git a/src/math/big/ratconv.go b/src/math/big/ratconv.go
index 12f9888c37..84602ff455 100644
--- a/src/math/big/ratconv.go
+++ b/src/math/big/ratconv.go
@@ -163,6 +163,9 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
}
// exp consumed - not needed anymore
+ stk := getStack()
+ defer stk.free()
+
// apply exp5 contributions
// (start with exp5 so the numbers to multiply are smaller)
if exp5 != 0 {
@@ -178,9 +181,9 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
if n > 1e6 {
return nil, false // avoid excessively large exponents
}
- pow5 := z.b.abs.expNN(natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs
+ pow5 := z.b.abs.expNN(stk, natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs
if exp5 > 0 {
- z.a.abs = z.a.abs.mul(z.a.abs, pow5)
+ z.a.abs = z.a.abs.mul(stk, z.a.abs, pow5)
z.b.abs = z.b.abs.setWord(1)
} else {
z.b.abs = pow5
@@ -343,15 +346,17 @@ func (x *Rat) FloatString(prec int) string {
}
// x.b.abs != 0
- q, r := nat(nil).div(nat(nil), x.a.abs, x.b.abs)
+ stk := getStack()
+ defer stk.free()
+ q, r := nat(nil).div(stk, nat(nil), x.a.abs, x.b.abs)
p := natOne
if prec > 0 {
- p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil, false)
+ p = nat(nil).expNN(stk, natTen, nat(nil).setUint64(uint64(prec)), nil, false)
}
- r = r.mul(r, p)
- r, r2 := r.div(nat(nil), r, x.b.abs)
+ r = r.mul(stk, r, p)
+ r, r2 := r.div(stk, nat(nil), r, x.b.abs)
// see if we need to round up
r2 = r2.add(r2, r2)
@@ -398,6 +403,9 @@ func (x *Rat) FloatString(prec int) string {
// 1/4 2 true 0.25
// 1/6 1 false 0.2 (0.166... rounded)
func (x *Rat) FloatPrec() (n int, exact bool) {
+ stk := getStack()
+ defer stk.free()
+
// Determine q and largest p2, p5 such that d = q·2^p2·5^p5.
// The results n, exact are:
//
@@ -425,11 +433,11 @@ func (x *Rat) FloatPrec() (n int, exact bool) {
f := nat{1220703125} // == 5^fp (must fit into a uint32 Word)
var t, r nat // temporaries
for {
- if _, r = t.div(r, q, f); len(r) != 0 {
+ if _, r = t.div(stk, r, q, f); len(r) != 0 {
break // f doesn't divide q evenly
}
tab = append(tab, f)
- f = nat(nil).sqr(f) // nat(nil) to ensure a new f for each table entry
+ f = nat(nil).sqr(stk, f) // nat(nil) to ensure a new f for each table entry
}
// Factor q using the table entries, if any.
@@ -441,7 +449,7 @@ func (x *Rat) FloatPrec() (n int, exact bool) {
// The same reasoning applies to the subsequent factors.
var p5 uint
for i := len(tab) - 1; i >= 0; i-- {
- if t, r = t.div(r, q, tab[i]); len(r) == 0 {
+ if t, r = t.div(stk, r, q, tab[i]); len(r) == 0 {
p5 += fp * (1 << i) // tab[i] == 5^(fp·2^i)
q = q.set(t)
}
@@ -449,7 +457,7 @@ func (x *Rat) FloatPrec() (n int, exact bool) {
// If fp != 1, we may still have multiples of 5 left.
for {
- if t, r = t.div(r, q, natFive); len(r) != 0 {
+ if t, r = t.div(stk, r, q, natFive); len(r) != 0 {
break
}
p5++