aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFilippo Valsorda <filippo@golang.org>2024-09-30 13:39:09 +0200
committerGopher Robot <gobot@golang.org>2024-10-22 19:51:02 +0000
commit750a45fe5e473d5afa193e9088f3d135e64eca26 (patch)
tree07f3db62d69e9976a25db118beb56055e710ee1c
parent36b172546bd03a74c79e109ec84c599b672ea9e4 (diff)
downloadgo-x-crypto-750a45fe5e473d5afa193e9088f3d135e64eca26.tar.xz
sha3: add MarshalBinary, AppendBinary, and UnmarshalBinary
Fixes golang/go#24617 Change-Id: I1d9d529950aa8a5953435e8d3412cda44b075d55 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/616635 Reviewed-by: Roland Shoemaker <roland@golang.org> Auto-Submit: Filippo Valsorda <filippo@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Daniel McCarney <daniel@binaryparadox.net> Reviewed-by: Michael Pratt <mpratt@google.com>
-rw-r--r--sha3/doc.go4
-rw-r--r--sha3/hashes.go31
-rw-r--r--sha3/sha3.go72
-rw-r--r--sha3/sha3_test.go42
-rw-r--r--sha3/shake.go42
5 files changed, 171 insertions, 20 deletions
diff --git a/sha3/doc.go b/sha3/doc.go
index 7e02309..bbf391f 100644
--- a/sha3/doc.go
+++ b/sha3/doc.go
@@ -5,6 +5,10 @@
// Package sha3 implements the SHA-3 fixed-output-length hash functions and
// the SHAKE variable-output-length hash functions defined by FIPS-202.
//
+// All types in this package also implement [encoding.BinaryMarshaler],
+// [encoding.BinaryAppender] and [encoding.BinaryUnmarshaler] to marshal and
+// unmarshal the internal state of the hash.
+//
// Both types of hash function use the "sponge" construction and the Keccak
// permutation. For a detailed specification see http://keccak.noekeon.org/
//
diff --git a/sha3/hashes.go b/sha3/hashes.go
index c544b29..31fffbe 100644
--- a/sha3/hashes.go
+++ b/sha3/hashes.go
@@ -48,33 +48,52 @@ func init() {
crypto.RegisterHash(crypto.SHA3_512, New512)
}
+const (
+ dsbyteSHA3 = 0b00000110
+ dsbyteKeccak = 0b00000001
+ dsbyteShake = 0b00011111
+ dsbyteCShake = 0b00000100
+
+ // rateK[c] is the rate in bytes for Keccak[c] where c is the capacity in
+ // bits. Given the sponge size is 1600 bits, the rate is 1600 - c bits.
+ rateK256 = (1600 - 256) / 8
+ rateK448 = (1600 - 448) / 8
+ rateK512 = (1600 - 512) / 8
+ rateK768 = (1600 - 768) / 8
+ rateK1024 = (1600 - 1024) / 8
+)
+
func new224Generic() *state {
- return &state{rate: 144, outputLen: 28, dsbyte: 0x06}
+ return &state{rate: rateK448, outputLen: 28, dsbyte: dsbyteSHA3}
}
func new256Generic() *state {
- return &state{rate: 136, outputLen: 32, dsbyte: 0x06}
+ return &state{rate: rateK512, outputLen: 32, dsbyte: dsbyteSHA3}
}
func new384Generic() *state {
- return &state{rate: 104, outputLen: 48, dsbyte: 0x06}
+ return &state{rate: rateK768, outputLen: 48, dsbyte: dsbyteSHA3}
}
func new512Generic() *state {
- return &state{rate: 72, outputLen: 64, dsbyte: 0x06}
+ return &state{rate: rateK1024, outputLen: 64, dsbyte: dsbyteSHA3}
}
// NewLegacyKeccak256 creates a new Keccak-256 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New256 instead.
-func NewLegacyKeccak256() hash.Hash { return &state{rate: 136, outputLen: 32, dsbyte: 0x01} }
+func NewLegacyKeccak256() hash.Hash {
+ return &state{rate: rateK512, outputLen: 32, dsbyte: dsbyteKeccak}
+}
// NewLegacyKeccak512 creates a new Keccak-512 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New512 instead.
-func NewLegacyKeccak512() hash.Hash { return &state{rate: 72, outputLen: 64, dsbyte: 0x01} }
+func NewLegacyKeccak512() hash.Hash {
+ return &state{rate: rateK1024, outputLen: 64, dsbyte: dsbyteKeccak}
+}
// Sum224 returns the SHA3-224 digest of the data.
func Sum224(data []byte) (digest [28]byte) {
diff --git a/sha3/sha3.go b/sha3/sha3.go
index 4f5cadd..6658c44 100644
--- a/sha3/sha3.go
+++ b/sha3/sha3.go
@@ -7,6 +7,7 @@ package sha3
import (
"crypto/subtle"
"encoding/binary"
+ "errors"
"unsafe"
"golang.org/x/sys/cpu"
@@ -170,3 +171,74 @@ func (d *state) Sum(in []byte) []byte {
dup.Read(hash)
return append(in, hash...)
}
+
+const (
+ magicSHA3 = "sha\x08"
+ magicShake = "sha\x09"
+ magicCShake = "sha\x0a"
+ magicKeccak = "sha\x0b"
+ // magic || rate || main state || n || sponge direction
+ marshaledSize = len(magicSHA3) + 1 + 200 + 1 + 1
+)
+
+func (d *state) MarshalBinary() ([]byte, error) {
+ return d.AppendBinary(make([]byte, 0, marshaledSize))
+}
+
+func (d *state) AppendBinary(b []byte) ([]byte, error) {
+ switch d.dsbyte {
+ case dsbyteSHA3:
+ b = append(b, magicSHA3...)
+ case dsbyteShake:
+ b = append(b, magicShake...)
+ case dsbyteCShake:
+ b = append(b, magicCShake...)
+ case dsbyteKeccak:
+ b = append(b, magicKeccak...)
+ default:
+ panic("unknown dsbyte")
+ }
+ // rate is at most 168, and n is at most rate.
+ b = append(b, byte(d.rate))
+ b = append(b, d.a[:]...)
+ b = append(b, byte(d.n), byte(d.state))
+ return b, nil
+}
+
+func (d *state) UnmarshalBinary(b []byte) error {
+ if len(b) != marshaledSize {
+ return errors.New("sha3: invalid hash state")
+ }
+
+ magic := string(b[:len(magicSHA3)])
+ b = b[len(magicSHA3):]
+ switch {
+ case magic == magicSHA3 && d.dsbyte == dsbyteSHA3:
+ case magic == magicShake && d.dsbyte == dsbyteShake:
+ case magic == magicCShake && d.dsbyte == dsbyteCShake:
+ case magic == magicKeccak && d.dsbyte == dsbyteKeccak:
+ default:
+ return errors.New("sha3: invalid hash state identifier")
+ }
+
+ rate := int(b[0])
+ b = b[1:]
+ if rate != d.rate {
+ return errors.New("sha3: invalid hash state function")
+ }
+
+ copy(d.a[:], b)
+ b = b[len(d.a):]
+
+ n, state := int(b[0]), spongeDirection(b[1])
+ if n > d.rate {
+ return errors.New("sha3: invalid hash state")
+ }
+ d.n = n
+ if state != spongeAbsorbing && state != spongeSqueezing {
+ return errors.New("sha3: invalid hash state")
+ }
+ d.state = state
+
+ return nil
+}
diff --git a/sha3/sha3_test.go b/sha3/sha3_test.go
index 8347b60..d97a970 100644
--- a/sha3/sha3_test.go
+++ b/sha3/sha3_test.go
@@ -13,6 +13,7 @@ package sha3
import (
"bytes"
"compress/flate"
+ "encoding"
"encoding/hex"
"encoding/json"
"fmt"
@@ -421,11 +422,11 @@ func TestCSHAKEAccumulated(t *testing.T) {
// console.log(bytesToHex(acc.xof(32)));
//
t.Run("cSHAKE128", func(t *testing.T) {
- testCSHAKEAccumulated(t, NewCShake128, rate128,
+ testCSHAKEAccumulated(t, NewCShake128, rateK256,
"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
})
t.Run("cSHAKE256", func(t *testing.T) {
- testCSHAKEAccumulated(t, NewCShake256, rate256,
+ testCSHAKEAccumulated(t, NewCShake256, rateK512,
"0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef")
})
}
@@ -486,6 +487,43 @@ func TestCSHAKELargeS(t *testing.T) {
}
}
+func TestMarshalUnmarshal(t *testing.T) {
+ t.Run("SHA3-224", func(t *testing.T) { testMarshalUnmarshal(t, New224()) })
+ t.Run("SHA3-256", func(t *testing.T) { testMarshalUnmarshal(t, New256()) })
+ t.Run("SHA3-384", func(t *testing.T) { testMarshalUnmarshal(t, New384()) })
+ t.Run("SHA3-512", func(t *testing.T) { testMarshalUnmarshal(t, New512()) })
+ t.Run("SHAKE128", func(t *testing.T) { testMarshalUnmarshal(t, NewShake128()) })
+ t.Run("SHAKE256", func(t *testing.T) { testMarshalUnmarshal(t, NewShake256()) })
+ t.Run("cSHAKE128", func(t *testing.T) { testMarshalUnmarshal(t, NewCShake128([]byte("N"), []byte("S"))) })
+ t.Run("cSHAKE256", func(t *testing.T) { testMarshalUnmarshal(t, NewCShake256([]byte("N"), []byte("S"))) })
+ t.Run("Keccak-256", func(t *testing.T) { testMarshalUnmarshal(t, NewLegacyKeccak256()) })
+ t.Run("Keccak-512", func(t *testing.T) { testMarshalUnmarshal(t, NewLegacyKeccak512()) })
+}
+
+// TODO(filippo): move this to crypto/internal/cryptotest.
+func testMarshalUnmarshal(t *testing.T, h hash.Hash) {
+ buf := make([]byte, 200)
+ rand.Read(buf)
+ n := rand.Intn(200)
+ h.Write(buf)
+ want := h.Sum(nil)
+ h.Reset()
+ h.Write(buf[:n])
+ b, err := h.(encoding.BinaryMarshaler).MarshalBinary()
+ if err != nil {
+ t.Errorf("MarshalBinary: %v", err)
+ }
+ h.Write(bytes.Repeat([]byte{0}, 200))
+ if err := h.(encoding.BinaryUnmarshaler).UnmarshalBinary(b); err != nil {
+ t.Errorf("UnmarshalBinary: %v", err)
+ }
+ h.Write(buf[n:])
+ got := h.Sum(nil)
+ if !bytes.Equal(got, want) {
+ t.Errorf("got %x, want %x", got, want)
+ }
+}
+
// BenchmarkPermutationFunction measures the speed of the permutation function
// with no input data.
func BenchmarkPermutationFunction(b *testing.B) {
diff --git a/sha3/shake.go b/sha3/shake.go
index 6d75811..a6b3a42 100644
--- a/sha3/shake.go
+++ b/sha3/shake.go
@@ -16,7 +16,9 @@ package sha3
// [2] https://doi.org/10.6028/NIST.SP.800-185
import (
+ "bytes"
"encoding/binary"
+ "errors"
"hash"
"io"
"math/bits"
@@ -51,14 +53,6 @@ type cshakeState struct {
initBlock []byte
}
-// Consts for configuring initial SHA-3 state
-const (
- dsbyteShake = 0x1f
- dsbyteCShake = 0x04
- rate128 = 168
- rate256 = 136
-)
-
func bytepad(data []byte, rate int) []byte {
out := make([]byte, 0, 9+len(data)+rate-1)
out = append(out, leftEncode(uint64(rate))...)
@@ -112,6 +106,30 @@ func (c *state) Clone() ShakeHash {
return c.clone()
}
+func (c *cshakeState) MarshalBinary() ([]byte, error) {
+ return c.AppendBinary(make([]byte, 0, marshaledSize+len(c.initBlock)))
+}
+
+func (c *cshakeState) AppendBinary(b []byte) ([]byte, error) {
+ b, err := c.state.AppendBinary(b)
+ if err != nil {
+ return nil, err
+ }
+ b = append(b, c.initBlock...)
+ return b, nil
+}
+
+func (c *cshakeState) UnmarshalBinary(b []byte) error {
+ if len(b) <= marshaledSize {
+ return errors.New("sha3: invalid hash state")
+ }
+ if err := c.state.UnmarshalBinary(b[:marshaledSize]); err != nil {
+ return err
+ }
+ c.initBlock = bytes.Clone(b[marshaledSize:])
+ return nil
+}
+
// NewShake128 creates a new SHAKE128 variable-output-length ShakeHash.
// Its generic security strength is 128 bits against all attacks if at
// least 32 bytes of its output are used.
@@ -127,11 +145,11 @@ func NewShake256() ShakeHash {
}
func newShake128Generic() *state {
- return &state{rate: rate128, outputLen: 32, dsbyte: dsbyteShake}
+ return &state{rate: rateK256, outputLen: 32, dsbyte: dsbyteShake}
}
func newShake256Generic() *state {
- return &state{rate: rate256, outputLen: 64, dsbyte: dsbyteShake}
+ return &state{rate: rateK512, outputLen: 64, dsbyte: dsbyteShake}
}
// NewCShake128 creates a new instance of cSHAKE128 variable-output-length ShakeHash,
@@ -144,7 +162,7 @@ func NewCShake128(N, S []byte) ShakeHash {
if len(N) == 0 && len(S) == 0 {
return NewShake128()
}
- return newCShake(N, S, rate128, 32, dsbyteCShake)
+ return newCShake(N, S, rateK256, 32, dsbyteCShake)
}
// NewCShake256 creates a new instance of cSHAKE256 variable-output-length ShakeHash,
@@ -157,7 +175,7 @@ func NewCShake256(N, S []byte) ShakeHash {
if len(N) == 0 && len(S) == 0 {
return NewShake256()
}
- return newCShake(N, S, rate256, 64, dsbyteCShake)
+ return newCShake(N, S, rateK512, 64, dsbyteCShake)
}
// ShakeSum128 writes an arbitrary-length digest of data into hash.