about summary refs log tree commit diff
path: root/src/crypto/internal/bigmod/nat.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/crypto/internal/bigmod/nat.go')
-rw-r--r--  src/crypto/internal/bigmod/nat.go  75
1 file changed, 52 insertions(+), 23 deletions(-)
diff --git a/src/crypto/internal/bigmod/nat.go b/src/crypto/internal/bigmod/nat.go
index 5605e9f1c3..7fdd8ef177 100644
--- a/src/crypto/internal/bigmod/nat.go
+++ b/src/crypto/internal/bigmod/nat.go
@@ -318,14 +318,48 @@ type Modulus struct {
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
func rr(m *Modulus) *Nat {
rr := NewNat().ExpandFor(m)
- // R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
- // most significant limb to 1. We then get to R*R by shifting left by _W
- // n + 1 times.
- n := len(rr.limbs)
- rr.limbs[n-1] = 1
- for i := n - 1; i < 2*n; i++ {
- rr.shiftIn(0, m) // x = x * 2^_W mod m
+ n := uint(len(rr.limbs))
+ mLen := uint(m.BitLen())
+ logR := _W * n
+
+ // We start by computing R = 2^(_W * n) mod m. We can get pretty close, to
+ // 2^⌊log₂m⌋, by setting the highest bit we can without having to reduce.
+ rr.limbs[n-1] = 1 << ((mLen - 1) % _W)
+ // Then we double until we reach 2^(_W * n).
+ for i := mLen - 1; i < logR; i++ {
+ rr.Add(rr, m)
+ }
+
+ // Next we need to get from R to 2^(_W * n) R mod m (aka from one to R in
+ // the Montgomery domain, meaning we can use Montgomery multiplication now).
+ // We could do that by doubling _W * n times, or with a square-and-double
+ // chain log2(_W * n) long. Turns out the fastest thing is to start out with
+ // doublings, and switch to square-and-double once the exponent is large
+ // enough to justify the cost of the multiplications.
+
+ // The threshold is selected experimentally as a linear function of n.
+ threshold := n / 4
+
+ // We calculate how many of the most-significant bits of the exponent we can
+ // compute before crossing the threshold, and we do it with doublings.
+ i := bits.UintSize
+ for logR>>i <= threshold {
+ i--
+ }
+ for k := uint(0); k < logR>>i; k++ {
+ rr.Add(rr, m)
+ }
+
+ // Then we process the remaining bits of the exponent with a
+ // square-and-double chain.
+ for i > 0 {
+ rr.montgomeryMul(rr, rr, m)
+ i--
+ if logR>>i&1 != 0 {
+ rr.Add(rr, m)
+ }
}
+
return rr
}
@@ -745,26 +779,21 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
return out.montgomeryReduction(m)
}
-// ExpShort calculates out = x^e mod m.
+// ExpShortVarTime calculates out = x^e mod m.
//
// The output will be resized to the size of m and overwritten. x must already
-// be reduced modulo m. This leaks the exact bit size of the exponent.
-func (out *Nat) ExpShort(x *Nat, e uint, m *Modulus) *Nat {
- xR := NewNat().set(x).montgomeryRepresentation(m)
-
- out.resetFor(m)
- out.limbs[0] = 1
- out.montgomeryRepresentation(m)
-
+// be reduced modulo m. This leaks the exponent through timing side-channels.
+func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
// For short exponents, precomputing a table and using a window like in Exp
- // doesn't pay off. Instead, we do a simple constant-time conditional
- // square-and-multiply chain, skipping the initial run of zeroes.
- tmp := NewNat().ExpandFor(m)
- for i := bits.UintSize - bitLen(e); i < bits.UintSize; i++ {
+ // doesn't pay off. Instead, we do a simple conditional square-and-multiply
+ // chain, skipping the initial run of zeroes.
+ xR := NewNat().set(x).montgomeryRepresentation(m)
+ out.set(xR)
+ for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ {
out.montgomeryMul(out, out, m)
- k := (e >> (bits.UintSize - i - 1)) & 1
- tmp.montgomeryMul(out, xR, m)
- out.assign(ctEq(k, 1), tmp)
+ if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
+ out.montgomeryMul(out, xR, m)
+ }
}
return out.montgomeryReduction(m)
}