diff --git a/CHANGELOG.md b/CHANGELOG.md index ae8d9030..1488390a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ This document outlines major changes between releases. New features: Behaviour changes: + * simplify PublicKey interface (#114) + * remove WithKeyPair callback from dBFT (#114) Improvements: diff --git a/config.go b/config.go index 5de4e58f..c34ac35b 100644 --- a/config.go +++ b/config.go @@ -1,7 +1,6 @@ package dbft import ( - "bytes" "errors" "time" @@ -137,30 +136,6 @@ func checkConfig[H Hash](cfg *Config[H]) error { return nil } -// WithKeyPair sets GetKeyPair to a function returning default key pair -// if it is present in a list of validators. -func WithKeyPair[H Hash](priv PrivateKey, pub PublicKey) func(config *Config[H]) { - myPub, err := pub.MarshalBinary() - if err != nil { - return nil - } - - return func(cfg *Config[H]) { - cfg.GetKeyPair = func(ps []PublicKey) (int, PrivateKey, PublicKey) { - for i := range ps { - pi, err := ps[i].MarshalBinary() - if err != nil { - continue - } else if bytes.Equal(myPub, pi) { - return i, priv, pub - } - } - - return -1, nil, nil - } - } -} - // WithGetKeyPair sets GetKeyPair. func WithGetKeyPair[H Hash](f func(pubs []PublicKey) (int, PrivateKey, PublicKey)) func(config *Config[H]) { return func(cfg *Config[H]) { diff --git a/dbft_test.go b/dbft_test.go index 2adb1d52..86fe3c89 100644 --- a/dbft_test.go +++ b/dbft_test.go @@ -418,7 +418,9 @@ func TestDBFT_Invalid(t *testing.T) { require.NotNil(t, priv) require.NotNil(t, pub) - opts := []func(*dbft.Config[crypto.Uint256]){dbft.WithKeyPair[crypto.Uint256](priv, pub)} + opts := []func(*dbft.Config[crypto.Uint256]){dbft.WithGetKeyPair[crypto.Uint256](func(_ []dbft.PublicKey) (int, dbft.PrivateKey, dbft.PublicKey) { + return -1, nil, nil + })} t.Run("without Timer", func(t *testing.T) { _, err := dbft.New(opts...) require.Error(t, err) @@ -829,7 +831,9 @@ func (s *testState) getOptions() []func(*dbft.Config[crypto.Uint256]) { dbft.WithCurrentHeight[crypto.Uint256](func() uint32 { return s.currHeight }), dbft.WithCurrentBlockHash[crypto.Uint256](func() crypto.Uint256 { return s.currHash }), dbft.WithGetValidators[crypto.Uint256](func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey { return s.pubs }), - dbft.WithKeyPair[crypto.Uint256](s.privs[s.myIndex], s.pubs[s.myIndex]), + dbft.WithGetKeyPair[crypto.Uint256](func(_ []dbft.PublicKey) (int, dbft.PrivateKey, dbft.PublicKey) { + return s.myIndex, s.privs[s.myIndex], s.pubs[s.myIndex] + }), dbft.WithBroadcast[crypto.Uint256](func(p Payload) { s.ch = append(s.ch, p) }), dbft.WithGetTx[crypto.Uint256](s.pool.Get), dbft.WithProcessBlock[crypto.Uint256](func(b dbft.Block[crypto.Uint256]) { s.blocks = append(s.blocks, b) }), diff --git a/identity.go b/identity.go index 8ab67b81..9b41e952 100644 --- a/identity.go +++ b/identity.go @@ -1,19 +1,12 @@ package dbft import ( - "encoding" "fmt" ) type ( // PublicKey is a generic public key interface used by dbft. - PublicKey interface { - encoding.BinaryMarshaler - encoding.BinaryUnmarshaler - - // Verify verifies if sig is indeed msg's signature. - Verify(msg, sig []byte) error - } + PublicKey any // PrivateKey is a generic private key interface used by dbft. PrivateKey interface { diff --git a/internal/consensus/block.go b/internal/consensus/block.go index b24d7ccb..81a7fbe3 100644 --- a/internal/consensus/block.go +++ b/internal/consensus/block.go @@ -114,7 +114,7 @@ func (b *neoBlock) Sign(key dbft.PrivateKey) error { // Verify implements Block interface. func (b *neoBlock) Verify(pub dbft.PublicKey, sign []byte) error { data := b.GetHashData() - return pub.Verify(data, sign) + return pub.(*crypto.ECDSAPub).Verify(data, sign) } // Hash implements Block interface. diff --git a/internal/consensus/consensus.go b/internal/consensus/consensus.go index e264d884..c4cd7b69 100644 --- a/internal/consensus/consensus.go +++ b/internal/consensus/consensus.go @@ -22,7 +22,15 @@ func New(logger *zap.Logger, key dbft.PrivateKey, pub dbft.PublicKey, dbft.WithTimer[crypto.Uint256](timer.New()), dbft.WithLogger[crypto.Uint256](logger), dbft.WithSecondsPerBlock[crypto.Uint256](time.Second*5), - dbft.WithKeyPair[crypto.Uint256](key, pub), + dbft.WithGetKeyPair[crypto.Uint256](func(pubs []dbft.PublicKey) (int, dbft.PrivateKey, dbft.PublicKey) { + for i := range pubs { + if pub.(*crypto.ECDSAPub).Equals(pubs[i]) { + return i, key, pub + } + } + + return -1, nil, nil + }), dbft.WithGetTx[crypto.Uint256](getTx), dbft.WithGetVerified[crypto.Uint256](getVerified), dbft.WithBroadcast[crypto.Uint256](broadcast), diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go index 52ca81dc..9655e001 100644 --- a/internal/crypto/crypto_test.go +++ b/internal/crypto/crypto_test.go @@ -19,7 +19,7 @@ func TestVerifySignature(t *testing.T) { require.NoError(t, err) require.Equal(t, 64, len(sign)) - err = pub.Verify(data, sign) + err = pub.(*ECDSAPub).Verify(data, sign) require.NoError(t, err) } diff --git a/internal/crypto/ecdsa.go b/internal/crypto/ecdsa.go index 275221c0..85c4cf04 100644 --- a/internal/crypto/ecdsa.go +++ b/internal/crypto/ecdsa.go @@ -62,6 +62,11 @@ func (e ECDSAPriv) Sign(msg []byte) ([]byte, error) { return sig, nil } +// Equals implements dbft.PublicKey interface. +func (e *ECDSAPub) Equals(other dbft.PublicKey) bool { + return e.Equal(other.(*ECDSAPub).PublicKey) +} + // MarshalBinary implements encoding.BinaryMarshaler interface. func (e ECDSAPub) MarshalBinary() ([]byte, error) { return elliptic.MarshalCompressed(e.PublicKey.Curve, e.PublicKey.X, e.PublicKey.Y), nil diff --git a/internal/crypto/ecdsa_test.go b/internal/crypto/ecdsa_test.go index 5623716f..d792cdb6 100644 --- a/internal/crypto/ecdsa_test.go +++ b/internal/crypto/ecdsa_test.go @@ -13,7 +13,7 @@ func TestECDSA_MarshalUnmarshal(t *testing.T) { require.NotNil(t, priv) require.NotNil(t, pub) - data, err := pub.MarshalBinary() + data, err := pub.(*ECDSAPub).MarshalBinary() require.NoError(t, err) pub1 := new(ECDSAPub) diff --git a/internal/simulation/main.go b/internal/simulation/main.go index 4587b358..5c42aa09 100644 --- a/internal/simulation/main.go +++ b/internal/simulation/main.go @@ -155,8 +155,8 @@ func updatePublicKeys(nodes []*simNode, n int) { func sortValidators(pubs []dbft.PublicKey) { sort.Slice(pubs, func(i, j int) bool { - p1, _ := pubs[i].MarshalBinary() - p2, _ := pubs[j].MarshalBinary() + p1, _ := pubs[i].(*crypto.ECDSAPub).MarshalBinary() + p2, _ := pubs[j].(*crypto.ECDSAPub).MarshalBinary() return murmur3.Sum64(p1) < murmur3.Sum64(p2) }) }