diff --git a/hpke/kem_test.go b/hpke/kem_test.go new file mode 100644 index 000000000..c536133ac --- /dev/null +++ b/hpke/kem_test.go @@ -0,0 +1,63 @@ +package hpke_test + +import ( + "fmt" + "testing" + + "github.com/cloudflare/circl/hpke" + "github.com/cloudflare/circl/internal/test" +) + +func TestKemKeysMarshal(t *testing.T) { + for _, kem := range []hpke.KEM{ + hpke.KEM_P256_HKDF_SHA256, + hpke.KEM_P384_HKDF_SHA384, + hpke.KEM_P521_HKDF_SHA512, + hpke.KEM_X25519_HKDF_SHA256, + hpke.KEM_X448_HKDF_SHA512, + hpke.KEM_X25519_KYBER768_DRAFT00, + } { + checkIssue488(t, kem) + } +} + +func checkIssue488(t *testing.T, kem hpke.KEM) { + scheme := kem.Scheme() + pk, sk, err := scheme.GenerateKeyPair() + if err != nil { + t.Fatal(err) + } + skBytes, err := sk.MarshalBinary() + test.CheckNoErr(t, err, "marshal private key") + pkBytes, err := pk.MarshalBinary() + test.CheckNoErr(t, err, "marshal public key") + + t.Run(fmt.Sprintf("%v/PrivateKey", scheme.Name()), func(t *testing.T) { + N := scheme.PrivateKeySize() + buffer := make([]byte, N+1) + copy(buffer, skBytes) + + // passing a buffer larger than the private key size should error (but no panic). + _, err := scheme.UnmarshalBinaryPrivateKey(buffer[:N+1]) + test.CheckIsErr(t, err, "unmarshal private key should failed") + + // passing a buffer of the exact size must be correct. + gotSk, err := scheme.UnmarshalBinaryPrivateKey(buffer[:N]) + test.CheckNoErr(t, err, "unmarshal private key shouldn't fail") + test.CheckOk(sk.Equal(gotSk), "private keys are not equal", t) + }) + + t.Run(fmt.Sprintf("%v/PublicKey", scheme.Name()), func(t *testing.T) { + N := scheme.PublicKeySize() + buffer := make([]byte, N+1) + copy(buffer, pkBytes) + + // passing a buffer larger than the public key size should error (but no panic). + _, err := scheme.UnmarshalBinaryPublicKey(buffer[:N+1]) + test.CheckIsErr(t, err, "unmarshal public key should failed") + + gotPk, err := scheme.UnmarshalBinaryPublicKey(buffer[:N]) + test.CheckNoErr(t, err, "unmarshal public key shouldn't fail") + test.CheckOk(pk.Equal(gotPk), "public keys are not equal", t) + }) +} diff --git a/hpke/shortkem.go b/hpke/shortkem.go index e5c55e991..cea17a976 100644 --- a/hpke/shortkem.go +++ b/hpke/shortkem.go @@ -53,6 +53,7 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { bitmask = 0x01 } + Nsk := s.PrivateKeySize() dkpPrk := s.labeledExtract([]byte(""), []byte("dkp_prk"), seed) var bytes []byte ctr := 0 @@ -64,14 +65,12 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { dkpPrk, []byte("candidate"), []byte{byte(ctr)}, - uint16(s.byteSize()), + uint16(Nsk), ) bytes[0] &= bitmask skBig.SetBytes(bytes) } - l := s.PrivateKeySize() - sk := &shortKEMPrivKey{s, make([]byte, l), nil} - copy(sk.priv[l-len(bytes):], bytes) + sk := &shortKEMPrivKey{s, bytes, nil} return sk.Public(), sk } @@ -83,11 +82,11 @@ func (s shortKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) { l := s.PrivateKeySize() - if len(data) < l { - return nil, ErrInvalidKEMPrivateKey + if len(data) != l { + return nil, kem.ErrPrivKeySize } sk := &shortKEMPrivKey{s, make([]byte, l), nil} - copy(sk.priv[l-len(data):l], data[:l]) + copy(sk.priv, data[:l]) if !sk.validate() { return nil, ErrInvalidKEMPrivateKey } @@ -96,7 +95,11 @@ func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) } func (s shortKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) { - x, y := elliptic.Unmarshal(s, data) + l := s.PublicKeySize() + if len(data) != l { + return nil, kem.ErrPubKeySize + } + x, y := elliptic.Unmarshal(s, data[:l]) if x == nil { return nil, ErrInvalidKEMPublicKey } diff --git a/hpke/xkem.go b/hpke/xkem.go index f11ab6b37..19d896145 100644 --- a/hpke/xkem.go +++ b/hpke/xkem.go @@ -58,13 +58,14 @@ func (x xKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { if len(seed) != x.SeedSize() { panic(kem.ErrSeedSize) } - sk := &xKEMPrivKey{scheme: x, priv: make([]byte, x.size)} + Nsk := x.PrivateKeySize() + sk := &xKEMPrivKey{scheme: x, priv: make([]byte, Nsk)} dkpPrk := x.labeledExtract([]byte(""), []byte("dkp_prk"), seed) bytes := x.labeledExpand( dkpPrk, []byte("sk"), nil, - uint16(x.PrivateKeySize()), + uint16(Nsk), ) copy(sk.priv, bytes) return sk.Public(), sk @@ -81,8 +82,8 @@ func (x xKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) { l := x.PrivateKeySize() - if len(data) < l { - return nil, ErrInvalidKEMPrivateKey + if len(data) != l { + return nil, kem.ErrPrivKeySize } sk := &xKEMPrivKey{x, make([]byte, l), nil} copy(sk.priv, data[:l]) @@ -94,8 +95,8 @@ func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) { func (x xKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) { l := x.PublicKeySize() - if len(data) < l { - return nil, ErrInvalidKEMPublicKey + if len(data) != l { + return nil, kem.ErrPubKeySize } pk := &xKEMPubKey{x, make([]byte, l)} copy(pk.pub, data[:l])