diff --git a/config.go b/config.go index a65962de..5de4e58f 100644 --- a/config.go +++ b/config.go @@ -162,7 +162,7 @@ func WithKeyPair[H Hash](priv PrivateKey, pub PublicKey) func(config *Config[H]) } // WithGetKeyPair sets GetKeyPair. -func WithGetKeyPair[H Hash](f func([]PublicKey) (int, PrivateKey, PublicKey)) func(config *Config[H]) { +func WithGetKeyPair[H Hash](f func(pubs []PublicKey) (int, PrivateKey, PublicKey)) func(config *Config[H]) { return func(cfg *Config[H]) { cfg.GetKeyPair = f } @@ -281,14 +281,14 @@ func WithCurrentBlockHash[H Hash](f func() H) func(config *Config[H]) { } // WithGetValidators sets GetValidators. -func WithGetValidators[H Hash](f func(...Transaction[H]) []PublicKey) func(config *Config[H]) { +func WithGetValidators[H Hash](f func(txs ...Transaction[H]) []PublicKey) func(config *Config[H]) { return func(cfg *Config[H]) { cfg.GetValidators = f } } // WithNewConsensusPayload sets NewConsensusPayload. -func WithNewConsensusPayload[H Hash](f func(*Context[H], MessageType, any) ConsensusPayload[H]) func(config *Config[H]) { +func WithNewConsensusPayload[H Hash](f func(ctx *Context[H], typ MessageType, msg any) ConsensusPayload[H]) func(config *Config[H]) { return func(cfg *Config[H]) { cfg.NewConsensusPayload = f } @@ -309,14 +309,14 @@ func WithNewPrepareResponse[H Hash](f func(preparationHash H) PrepareResponse[H] } // WithNewChangeView sets NewChangeView. -func WithNewChangeView[H Hash](f func(byte, ChangeViewReason, uint64) ChangeView) func(config *Config[H]) { +func WithNewChangeView[H Hash](f func(newViewNumber byte, reason ChangeViewReason, ts uint64) ChangeView) func(config *Config[H]) { return func(cfg *Config[H]) { cfg.NewChangeView = f } } // WithNewCommit sets NewCommit. -func WithNewCommit[H Hash](f func([]byte) Commit) func(config *Config[H]) { +func WithNewCommit[H Hash](f func(signature []byte) Commit) func(config *Config[H]) { return func(cfg *Config[H]) { cfg.NewCommit = f } @@ -337,14 +337,14 @@ func WithNewRecoveryMessage[H Hash](f func() RecoveryMessage[H]) func(config *Co } // WithVerifyPrepareRequest sets VerifyPrepareRequest. -func WithVerifyPrepareRequest[H Hash](f func(ConsensusPayload[H]) error) func(config *Config[H]) { +func WithVerifyPrepareRequest[H Hash](f func(prepareReq ConsensusPayload[H]) error) func(config *Config[H]) { return func(cfg *Config[H]) { cfg.VerifyPrepareRequest = f } } // WithVerifyPrepareResponse sets VerifyPrepareResponse. -func WithVerifyPrepareResponse[H Hash](f func(ConsensusPayload[H]) error) func(config *Config[H]) { +func WithVerifyPrepareResponse[H Hash](f func(prepareResp ConsensusPayload[H]) error) func(config *Config[H]) { return func(cfg *Config[H]) { cfg.VerifyPrepareResponse = f } diff --git a/dbft.go b/dbft.go index 39f72db9..db6cf851 100644 --- a/dbft.go +++ b/dbft.go @@ -1,6 +1,7 @@ package dbft import ( + "fmt" "sync" "time" @@ -23,10 +24,10 @@ type ( ) // New returns new DBFT instance with specified H and A generic parameters -// using provided options or nil if some of the options are missing or invalid. +// using provided options or nil and error if some of the options are missing or invalid. // H and A generic parameters are used as hash and address representation for // dBFT consensus messages, blocks and transactions. -func New[H Hash](options ...func(config *Config[H])) *DBFT[H] { +func New[H Hash](options ...func(config *Config[H])) (*DBFT[H], error) { cfg := defaultConfig[H]() for _, option := range options { @@ -34,7 +35,7 @@ func New[H Hash](options ...func(config *Config[H])) *DBFT[H] { } if err := checkConfig(cfg); err != nil { - return nil + return nil, fmt.Errorf("invalid config: %w", err) } d := &DBFT[H]{ @@ -45,7 +46,7 @@ func New[H Hash](options ...func(config *Config[H])) *DBFT[H] { }, } - return d + return d, nil } func (d *DBFT[H]) addTransaction(tx Transaction[H]) { diff --git a/dbft_test.go b/dbft_test.go index 16ac328e..2adb1d52 100644 --- a/dbft_test.go +++ b/dbft_test.go @@ -43,7 +43,7 @@ func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) { t.Run("backup sends nothing on start", func(t *testing.T) { s.currHeight = 0 - service := dbft.New[crypto.Uint256](s.getOptions()...) + service, _ := dbft.New[crypto.Uint256](s.getOptions()...) service.Start(0) require.Nil(t, s.tryRecv()) @@ -51,7 +51,7 @@ func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) { t.Run("primary send PrepareRequest on start", func(t *testing.T) { s.currHeight = 1 - service := dbft.New[crypto.Uint256](s.getOptions()...) + service, _ := dbft.New[crypto.Uint256](s.getOptions()...) service.Start(0) p := s.tryRecv() @@ -90,7 +90,7 @@ func TestDBFT_SingleNode(t *testing.T) { s := newTestState(0, 1) s.currHeight = 2 - service := dbft.New[crypto.Uint256](s.getOptions()...) + service, _ := dbft.New[crypto.Uint256](s.getOptions()...) service.Start(0) p := s.tryRecv() @@ -128,7 +128,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) { t.Run("receive request from primary", func(t *testing.T) { s.currHeight = 4 - service := dbft.New[crypto.Uint256](s.getOptions()...) + service, _ := dbft.New[crypto.Uint256](s.getOptions()...) txs := []testTx{1} s.pool.Add(txs[0]) @@ -160,7 +160,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) { t.Run("change view on invalid tx", func(t *testing.T) { s.currHeight = 4 - service := dbft.New[crypto.Uint256](s.getOptions()...) + service, _ := dbft.New[crypto.Uint256](s.getOptions()...) txs := []testTx{10} service.Start(0) @@ -188,7 +188,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) { t.Run("receive invalid prepare request", func(t *testing.T) { s.currHeight = 4 - service := dbft.New[crypto.Uint256](s.getOptions()...) + service, _ := dbft.New[crypto.Uint256](s.getOptions()...) txs := []testTx{1, 2} s.pool.Add(txs[0]) @@ -236,7 +236,7 @@ func TestDBFT_CommitOnTransaction(t *testing.T) { s := newTestState(0, 4) s.currHeight = 1 - srv := dbft.New[crypto.Uint256](s.getOptions()...) + srv, _ := dbft.New[crypto.Uint256](s.getOptions()...) srv.Start(0) require.Nil(t, s.tryRecv()) @@ -256,7 +256,7 @@ func TestDBFT_CommitOnTransaction(t *testing.T) { privs: s.privs, } s1.pool.Add(tx) - srv1 := dbft.New[crypto.Uint256](s1.getOptions()...) + srv1, _ := dbft.New[crypto.Uint256](s1.getOptions()...) srv1.Start(0) srv1.OnReceive(req) srv1.OnReceive(s1.getPrepareResponse(1, req.Hash())) @@ -278,7 +278,7 @@ func TestDBFT_OnReceiveCommit(t *testing.T) { s := newTestState(2, 4) t.Run("send commit after enough responses", func(t *testing.T) { s.currHeight = 1 - service := dbft.New[crypto.Uint256](s.getOptions()...) + service, _ := dbft.New[crypto.Uint256](s.getOptions()...) service.Start(0) req := s.tryRecv() @@ -338,7 +338,7 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) { s := newTestState(2, 4) t.Run("send recovery message", func(t *testing.T) { s.currHeight = 1 - service := dbft.New[crypto.Uint256](s.getOptions()...) + service, _ := dbft.New[crypto.Uint256](s.getOptions()...) service.Start(0) req := s.tryRecv() @@ -360,7 +360,7 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) { require.Equal(t, dbft.RecoveryMessageType, rm.Type()) other := s.copyWithIndex(3) - srv2 := dbft.New[crypto.Uint256](other.getOptions()...) + srv2, _ := dbft.New[crypto.Uint256](other.getOptions()...) srv2.Start(0) srv2.OnReceive(rm) @@ -383,7 +383,7 @@ func TestDBFT_OnReceiveChangeView(t *testing.T) { s := newTestState(2, 4) t.Run("change view correctly", func(t *testing.T) { s.currHeight = 6 - service := dbft.New[crypto.Uint256](s.getOptions()...) + service, _ := dbft.New[crypto.Uint256](s.getOptions()...) service.Start(0) resp := s.getChangeView(1, 1) @@ -410,7 +410,8 @@ func TestDBFT_OnReceiveChangeView(t *testing.T) { func TestDBFT_Invalid(t *testing.T) { t.Run("without keys", func(t *testing.T) { - require.Nil(t, dbft.New[crypto.Uint256]()) + _, err := dbft.New[crypto.Uint256]() + require.Error(t, err) }) priv, pub := crypto.Generate(rand.Reader) @@ -419,85 +420,98 @@ func TestDBFT_Invalid(t *testing.T) { opts := []func(*dbft.Config[crypto.Uint256]){dbft.WithKeyPair[crypto.Uint256](priv, pub)} t.Run("without Timer", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithTimer[crypto.Uint256](timer.New())) t.Run("without CurrentHeight", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithCurrentHeight[crypto.Uint256](func() uint32 { return 0 })) t.Run("without CurrentBlockHash", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithCurrentBlockHash[crypto.Uint256](func() crypto.Uint256 { return crypto.Uint256{} })) t.Run("without GetValidators", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithGetValidators[crypto.Uint256](func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey { return []dbft.PublicKey{pub} })) t.Run("without NewBlockFromContext", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithNewBlockFromContext[crypto.Uint256](func(_ *dbft.Context[crypto.Uint256]) dbft.Block[crypto.Uint256] { return nil })) t.Run("without NewConsensusPayload", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithNewConsensusPayload[crypto.Uint256](func(_ *dbft.Context[crypto.Uint256], _ dbft.MessageType, _ any) dbft.ConsensusPayload[crypto.Uint256] { return nil })) t.Run("without NewPrepareRequest", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithNewPrepareRequest[crypto.Uint256](func(uint64, uint64, []crypto.Uint256) dbft.PrepareRequest[crypto.Uint256] { return nil })) t.Run("without NewPrepareResponse", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithNewPrepareResponse[crypto.Uint256](func(crypto.Uint256) dbft.PrepareResponse[crypto.Uint256] { return nil })) t.Run("without NewChangeView", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithNewChangeView[crypto.Uint256](func(byte, dbft.ChangeViewReason, uint64) dbft.ChangeView { return nil })) t.Run("without NewCommit", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithNewCommit[crypto.Uint256](func([]byte) dbft.Commit { return nil })) t.Run("without NewRecoveryRequest", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithNewRecoveryRequest[crypto.Uint256](func(uint64) dbft.RecoveryRequest { return nil })) t.Run("without NewRecoveryMessage", func(t *testing.T) { - require.Nil(t, dbft.New(opts...)) + _, err := dbft.New(opts...) + require.Error(t, err) }) opts = append(opts, dbft.WithNewRecoveryMessage[crypto.Uint256](func() dbft.RecoveryMessage[crypto.Uint256] { return nil })) t.Run("with all defaults", func(t *testing.T) { - d := dbft.New(opts...) + d, err := dbft.New(opts...) + require.NoError(t, err) require.NotNil(t, d) require.NotNil(t, d.Config.RequestTx) require.NotNil(t, d.Config.GetTx) @@ -526,19 +540,19 @@ func TestDBFT_Invalid(t *testing.T) { func TestDBFT_FourGoodNodesDeadlock(t *testing.T) { r0 := newTestState(0, 4) r0.currHeight = 4 - s0 := dbft.New[crypto.Uint256](r0.getOptions()...) + s0, _ := dbft.New[crypto.Uint256](r0.getOptions()...) s0.Start(0) r1 := r0.copyWithIndex(1) - s1 := dbft.New[crypto.Uint256](r1.getOptions()...) + s1, _ := dbft.New[crypto.Uint256](r1.getOptions()...) s1.Start(0) r2 := r0.copyWithIndex(2) - s2 := dbft.New[crypto.Uint256](r2.getOptions()...) + s2, _ := dbft.New[crypto.Uint256](r2.getOptions()...) s2.Start(0) r3 := r0.copyWithIndex(3) - s3 := dbft.New[crypto.Uint256](r3.getOptions()...) + s3, _ := dbft.New[crypto.Uint256](r3.getOptions()...) s3.Start(0) // Step 1. The primary (at view 0) replica 1 sends the PrepareRequest message. diff --git a/internal/consensus/consensus.go b/internal/consensus/consensus.go index 2db7c88c..e264d884 100644 --- a/internal/consensus/consensus.go +++ b/internal/consensus/consensus.go @@ -17,7 +17,7 @@ func New(logger *zap.Logger, key dbft.PrivateKey, pub dbft.PublicKey, currentHeight func() uint32, currentBlockHash func() crypto.Uint256, getValidators func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey, - verifyPayload func(consensusPayload dbft.ConsensusPayload[crypto.Uint256]) error) *dbft.DBFT[crypto.Uint256] { + verifyPayload func(consensusPayload dbft.ConsensusPayload[crypto.Uint256]) error) (*dbft.DBFT[crypto.Uint256], error) { return dbft.New[crypto.Uint256]( dbft.WithTimer[crypto.Uint256](timer.New()), dbft.WithLogger[crypto.Uint256](logger), diff --git a/internal/simulation/main.go b/internal/simulation/main.go index 3c857727..4587b358 100644 --- a/internal/simulation/main.go +++ b/internal/simulation/main.go @@ -121,7 +121,8 @@ func initSimNode(nodes []*simNode, i int, log *zap.Logger) error { cluster: nodes, } - nodes[i].d = consensus.New(nodes[i].log, key, pub, nodes[i].pool.Get, + var err error + nodes[i].d, err = consensus.New(nodes[i].log, key, pub, nodes[i].pool.Get, nodes[i].pool.GetVerified, nodes[i].Broadcast, nodes[i].ProcessBlock, @@ -130,9 +131,8 @@ func initSimNode(nodes []*simNode, i int, log *zap.Logger) error { nodes[i].GetValidators, nodes[i].VerifyPayload, ) - - if nodes[i].d == nil { - return errors.New("can't initialize dBFT") + if err != nil { + return fmt.Errorf("failed to initialize dBFT: %w", err) } nodes[i].addTx(*txCount)