diff options
| author | Rémy Oudompheng <remyoudompheng@gmail.com> | 2019-04-14 08:16:13 +0200 |
|---|---|---|
| committer | Robert Griesemer <gri@golang.org> | 2019-11-12 05:18:25 +0000 |
| commit | 194ae3236d81cf16dc39b955efc1b9202b59d067 (patch) | |
| tree | 246ff4032a9c72642e78782c018abc18d8a5e749 /src/math | |
| parent | b8cb75fb17511538524ac304abe68e62699c4e23 (diff) | |
| download | go-194ae3236d81cf16dc39b955efc1b9202b59d067.tar.xz | |
math/big: implement recursive algorithm for division
The current division algorithm produces one word of result at a time,
using 2-word division to compute the top word and mulAddVWW to compute
the remainder. The top word may need to be adjusted by 1 or 2 units.
The recursive version, based on Burnikel, Ziegler, "Fast Recursive Division",
uses the same principles, but in a multi-word setting, so that
multiplication benefits from the Karatsuba algorithm (and possibly later
improvements).
benchmark old ns/op new ns/op delta
BenchmarkDiv/20/10-4 38.2 38.3 +0.26%
BenchmarkDiv/40/20-4 38.7 38.5 -0.52%
BenchmarkDiv/100/50-4 62.5 62.6 +0.16%
BenchmarkDiv/200/100-4 238 259 +8.82%
BenchmarkDiv/400/200-4 311 338 +8.68%
BenchmarkDiv/1000/500-4 604 649 +7.45%
BenchmarkDiv/2000/1000-4 1214 1278 +5.27%
BenchmarkDiv/20000/10000-4 38279 36510 -4.62%
BenchmarkDiv/200000/100000-4 3022057 1359615 -55.01%
BenchmarkDiv/2000000/1000000-4 310827664 54012939 -82.62%
BenchmarkDiv/20000000/10000000-4 33272829421 1965401359 -94.09%
BenchmarkString/10/Base10-4 158 156 -1.27%
BenchmarkString/100/Base10-4 797 792 -0.63%
BenchmarkString/1000/Base10-4 3677 3814 +3.73%
BenchmarkString/10000/Base10-4 16633 17116 +2.90%
BenchmarkString/100000/Base10-4 5779029 1793808 -68.96%
BenchmarkString/1000000/Base10-4 889840820 85524031 -90.39%
BenchmarkString/10000000/Base10-4 134338236860 4935657026 -96.33%
Fixes #21960
Updates #30943
Change-Id: I134c6f81a47870c688ca95b6081eb9211def15a2
Reviewed-on: https://go-review.googlesource.com/c/go/+/172018
Reviewed-by: Robert Griesemer <gri@golang.org>
Run-TryBot: Robert Griesemer <gri@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Diffstat (limited to 'src/math')
| -rw-r--r-- | src/math/big/int_test.go | 7 | ||||
| -rw-r--r-- | src/math/big/nat.go | 196 | ||||
| -rw-r--r-- | src/math/big/nat_test.go | 24 |
3 files changed, 216 insertions, 11 deletions
diff --git a/src/math/big/int_test.go b/src/math/big/int_test.go index a4285f3239..e3a1587b3f 100644 --- a/src/math/big/int_test.go +++ b/src/math/big/int_test.go @@ -1829,8 +1829,11 @@ func benchmarkDiv(b *testing.B, aSize, bSize int) { } func BenchmarkDiv(b *testing.B) { - min, max, step := 10, 100000, 10 - for i := min; i <= max; i *= step { + sizes := []int{ + 10, 20, 50, 100, 200, 500, 1000, + 1e4, 1e5, 1e6, 1e7, + } + for _, i := range sizes { j := 2 * i b.Run(fmt.Sprintf("%d/%d", j, i), func(b *testing.B) { benchmarkDiv(b, j, i) diff --git a/src/math/big/nat.go b/src/math/big/nat.go index 3b60232075..6667319100 100644 --- a/src/math/big/nat.go +++ b/src/math/big/nat.go @@ -693,7 +693,7 @@ func putNat(x *nat) { var natPool sync.Pool -// q = (uIn-r)/vIn, with 0 <= r < y +// q = (uIn-r)/vIn, with 0 <= r < vIn // Uses z as storage for q, and u as storage for r if possible. // See Knuth, Volume 2, section 4.3.1, Algorithm D. // Preconditions: @@ -721,6 +721,30 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { } q = z.make(m + 1) + if n < divRecursiveThreshold { + q.divBasic(u, v) + } else { + q.divRecursive(u, v) + } + putNat(vp) + + q = q.norm() + shrVU(u, u, shift) + r = u.norm() + + return q, r +} + +// divBasic performs word-by-word division of u by v. +// The quotient is written in pre-allocated q. +// The remainder overwrites input u. +// +// Precondition: +// - len(q) >= len(u)-len(v) +func (q nat) divBasic(u, v nat) { + n := len(v) + m := len(u) - n + qhatvp := getNat(n + 1) qhatv := *qhatvp @@ -729,7 +753,11 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { for j := m; j >= 0; j-- { // D3. qhat := Word(_M) - if ujn := u[j+n]; ujn != vn1 { + var ujn Word + if j+n < len(u) { + ujn = u[j+n] + } + if ujn != vn1 { var rhat Word qhat, rhat = divWW(ujn, u[j+n-1], vn1) @@ -752,25 +780,175 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { // D4. qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0) - - c := subVV(u[j:j+len(qhatv)], u[j:], qhatv) + qhl := len(qhatv) + if j+qhl > len(u) && qhatv[n] == 0 { + qhl-- + } + c := subVV(u[j:j+qhl], u[j:], qhatv) if c != 0 { c := addVV(u[j:j+n], u[j:], v) u[j+n] += c qhat-- } + if j == m && m == len(q) && qhat == 0 { + continue + } q[j] = qhat } - putNat(vp) putNat(qhatvp) +} - q = q.norm() - shrVU(u, u, shift) - r = u.norm() +const divRecursiveThreshold = 100 - return q, r +// divRecursive performs word-by-word division of u by v. +// The quotient is written in pre-allocated z. +// The remainder overwrites input u. +// +// Precondition: +// - len(z) >= len(u)-len(v) +// +// See Burnikel, Ziegler, "Fast Recursive Division", Algorithm 1 and 2. +func (z nat) divRecursive(u, v nat) { + // Recursion depth is less than 2 log2(len(v)) + // Allocate a slice of temporaries to be reused across recursion. + recDepth := 2 * bits.Len(uint(len(v))) + // large enough to perform Karatsuba on operands as large as v + tmp := getNat(3 * len(v)) + temps := make([]*nat, recDepth) + z.clear() + z.divRecursiveStep(u, v, 0, tmp, temps) + for _, n := range temps { + if n != nil { + putNat(n) + } + } + putNat(tmp) +} + +func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { + u = u.norm() + v = v.norm() + + if len(u) == 0 { + z.clear() + return + } + n := len(v) + if n < divRecursiveThreshold { + z.divBasic(u, v) + return + } + m := len(u) - n + if m < 0 { + return + } + + // Produce the quotient by blocks of B words. + // Division by v (length n) is done using a length n/2 division + // and a length n/2 multiplication for each block. The final + // complexity is driven by multiplication complexity. + B := n / 2 + + // Allocate a nat for qhat below. + if temps[depth] == nil { + temps[depth] = getNat(n) + } else { + *temps[depth] = temps[depth].make(B + 1) + } + + j := m + for j > B { + // Divide u[j-B:j+n] by vIn. Keep remainder in u + // for next block. + // + // The following property will be used (Lemma 2): + // if u = u1 << s + u0 + // v = v1 << s + v0 + // then floor(u1/v1) >= floor(u/v) + // + // Moreover, the difference is at most 2 if len(v1) >= len(u/v) + // We choose s = B-1 since len(v)-B >= B+1 >= len(u/v) + s := (B - 1) + // Except for the first step, the top bits are always + // a division remainder, so the quotient length is <= n. + uu := u[j-B:] + + qhat := *temps[depth] + qhat.clear() + qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps) + qhat = qhat.norm() + // Adjust the quotient: + // u = u_h << s + u_l + // v = v_h << s + v_l + // u_h = q̂ v_h + rh + // u = q̂ (v - v_l) + rh << s + u_l + // After the above step, u contains a remainder: + // u = rh << s + u_l + // and we need to substract q̂ v_l + // + // But it may be a bit too large, in which case q̂ needs to be smaller. + qhatv := tmp.make(3 * n) + qhatv.clear() + qhatv = qhatv.mul(qhat, v[:s]) + for i := 0; i < 2; i++ { + e := qhatv.cmp(uu.norm()) + if e <= 0 { + break + } + subVW(qhat, qhat, 1) + c := subVV(qhatv[:s], qhatv[:s], v[:s]) + if len(qhatv) > s { + subVW(qhatv[s:], qhatv[s:], c) + } + addAt(uu[s:], v[s:], 0) + } + if qhatv.cmp(uu.norm()) > 0 { + panic("impossible") + } + c := subVV(uu[:len(qhatv)], uu[:len(qhatv)], qhatv) + if c > 0 { + subVW(uu[len(qhatv):], uu[len(qhatv):], c) + } + addAt(z, qhat, j-B) + j -= B + } + + // Now u < (v<<B), compute lower bits in the same way. + // Choose shift = B-1 again. + s := B + qhat := *temps[depth] + qhat.clear() + qhat.divRecursiveStep(u[s:].norm(), v[s:], depth+1, tmp, temps) + qhat = qhat.norm() + qhatv := tmp.make(3 * n) + qhatv.clear() + qhatv = qhatv.mul(qhat, v[:s]) + // Set the correct remainder as before. + for i := 0; i < 2; i++ { + if e := qhatv.cmp(u.norm()); e > 0 { + subVW(qhat, qhat, 1) + c := subVV(qhatv[:s], qhatv[:s], v[:s]) + if len(qhatv) > s { + subVW(qhatv[s:], qhatv[s:], c) + } + addAt(u[s:], v[s:], 0) + } + } + if qhatv.cmp(u.norm()) > 0 { + panic("impossible") + } + c := subVV(u[0:len(qhatv)], u[0:len(qhatv)], qhatv) + if c > 0 { + c = subVW(u[len(qhatv):B+n], u[len(qhatv):B+n], c) + } + if c > 0 { + panic("impossible") + } + + // Done! + addAt(z, qhat.norm(), 0) } // Length of x in bits. x must be normalized. diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go index bb5e14b5fa..da34e95c1f 100644 --- a/src/math/big/nat_test.go +++ b/src/math/big/nat_test.go @@ -739,3 +739,27 @@ func BenchmarkNatSetBytes(b *testing.B) { }) } } + +func TestNatDiv(t *testing.T) { + sizes := []int{ + 1, 2, 5, 8, 15, 25, 40, 65, 100, + 200, 500, 800, 1500, 2500, 4000, 6500, 10000, + } + for _, i := range sizes { + for _, j := range sizes { + a := rndNat(i) + b := rndNat(j) + x := nat(nil).mul(a, b) + addVW(x, x, 1) + + var q, r nat + q, r = q.div(r, x, b) + if q.cmp(a) != 0 { + t.Fatal("wrong quotient", i, j) + } + if len(r) != 1 || r[0] != 1 { + t.Fatal("wrong remainder") + } + } + } +} |
