Skip to content

Commit

Permalink
Merge pull request #107 from nspcc-dev/improve-interface
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-khimov authored Mar 21, 2024
2 parents 88b2f18 + 10b4f0e commit 332ff86
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 46 deletions.
14 changes: 7 additions & 7 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func WithKeyPair[H Hash](priv PrivateKey, pub PublicKey) func(config *Config[H])
}

// WithGetKeyPair sets GetKeyPair.
func WithGetKeyPair[H Hash](f func([]PublicKey) (int, PrivateKey, PublicKey)) func(config *Config[H]) {
func WithGetKeyPair[H Hash](f func(pubs []PublicKey) (int, PrivateKey, PublicKey)) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.GetKeyPair = f
}
Expand Down Expand Up @@ -281,14 +281,14 @@ func WithCurrentBlockHash[H Hash](f func() H) func(config *Config[H]) {
}

// WithGetValidators sets GetValidators.
func WithGetValidators[H Hash](f func(...Transaction[H]) []PublicKey) func(config *Config[H]) {
func WithGetValidators[H Hash](f func(txs ...Transaction[H]) []PublicKey) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.GetValidators = f
}
}

// WithNewConsensusPayload sets NewConsensusPayload.
func WithNewConsensusPayload[H Hash](f func(*Context[H], MessageType, any) ConsensusPayload[H]) func(config *Config[H]) {
func WithNewConsensusPayload[H Hash](f func(ctx *Context[H], typ MessageType, msg any) ConsensusPayload[H]) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.NewConsensusPayload = f
}
Expand All @@ -309,14 +309,14 @@ func WithNewPrepareResponse[H Hash](f func(preparationHash H) PrepareResponse[H]
}

// WithNewChangeView sets NewChangeView.
func WithNewChangeView[H Hash](f func(byte, ChangeViewReason, uint64) ChangeView) func(config *Config[H]) {
func WithNewChangeView[H Hash](f func(newViewNumber byte, reason ChangeViewReason, ts uint64) ChangeView) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.NewChangeView = f
}
}

// WithNewCommit sets NewCommit.
func WithNewCommit[H Hash](f func([]byte) Commit) func(config *Config[H]) {
func WithNewCommit[H Hash](f func(signature []byte) Commit) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.NewCommit = f
}
Expand All @@ -337,14 +337,14 @@ func WithNewRecoveryMessage[H Hash](f func() RecoveryMessage[H]) func(config *Co
}

// WithVerifyPrepareRequest sets VerifyPrepareRequest.
func WithVerifyPrepareRequest[H Hash](f func(ConsensusPayload[H]) error) func(config *Config[H]) {
func WithVerifyPrepareRequest[H Hash](f func(prepareReq ConsensusPayload[H]) error) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.VerifyPrepareRequest = f
}
}

// WithVerifyPrepareResponse sets VerifyPrepareResponse.
func WithVerifyPrepareResponse[H Hash](f func(ConsensusPayload[H]) error) func(config *Config[H]) {
func WithVerifyPrepareResponse[H Hash](f func(prepareResp ConsensusPayload[H]) error) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.VerifyPrepareResponse = f
}
Expand Down
9 changes: 5 additions & 4 deletions dbft.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dbft

import (
"fmt"
"sync"
"time"

Expand All @@ -23,18 +24,18 @@ type (
)

// New returns new DBFT instance with specified H and A generic parameters
// using provided options or nil if some of the options are missing or invalid.
// using provided options or nil and error if some of the options are missing or invalid.
// H and A generic parameters are used as hash and address representation for
// dBFT consensus messages, blocks and transactions.
func New[H Hash](options ...func(config *Config[H])) *DBFT[H] {
func New[H Hash](options ...func(config *Config[H])) (*DBFT[H], error) {
cfg := defaultConfig[H]()

for _, option := range options {
option(cfg)
}

if err := checkConfig(cfg); err != nil {
return nil
return nil, fmt.Errorf("invalid config: %w", err)
}

d := &DBFT[H]{
Expand All @@ -45,7 +46,7 @@ func New[H Hash](options ...func(config *Config[H])) *DBFT[H] {
},
}

return d
return d, nil
}

func (d *DBFT[H]) addTransaction(tx Transaction[H]) {
Expand Down
74 changes: 44 additions & 30 deletions dbft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) {

t.Run("backup sends nothing on start", func(t *testing.T) {
s.currHeight = 0
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)

service.Start(0)
require.Nil(t, s.tryRecv())
})

t.Run("primary send PrepareRequest on start", func(t *testing.T) {
s.currHeight = 1
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)

service.Start(0)
p := s.tryRecv()
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestDBFT_SingleNode(t *testing.T) {
s := newTestState(0, 1)

s.currHeight = 2
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)

service.Start(0)
p := s.tryRecv()
Expand Down Expand Up @@ -128,7 +128,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) {

t.Run("receive request from primary", func(t *testing.T) {
s.currHeight = 4
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
txs := []testTx{1}
s.pool.Add(txs[0])

Expand Down Expand Up @@ -160,7 +160,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) {

t.Run("change view on invalid tx", func(t *testing.T) {
s.currHeight = 4
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
txs := []testTx{10}

service.Start(0)
Expand Down Expand Up @@ -188,7 +188,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) {

t.Run("receive invalid prepare request", func(t *testing.T) {
s.currHeight = 4
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
txs := []testTx{1, 2}
s.pool.Add(txs[0])

Expand Down Expand Up @@ -236,7 +236,7 @@ func TestDBFT_CommitOnTransaction(t *testing.T) {
s := newTestState(0, 4)
s.currHeight = 1

srv := dbft.New[crypto.Uint256](s.getOptions()...)
srv, _ := dbft.New[crypto.Uint256](s.getOptions()...)
srv.Start(0)
require.Nil(t, s.tryRecv())

Expand All @@ -256,7 +256,7 @@ func TestDBFT_CommitOnTransaction(t *testing.T) {
privs: s.privs,
}
s1.pool.Add(tx)
srv1 := dbft.New[crypto.Uint256](s1.getOptions()...)
srv1, _ := dbft.New[crypto.Uint256](s1.getOptions()...)
srv1.Start(0)
srv1.OnReceive(req)
srv1.OnReceive(s1.getPrepareResponse(1, req.Hash()))
Expand All @@ -278,7 +278,7 @@ func TestDBFT_OnReceiveCommit(t *testing.T) {
s := newTestState(2, 4)
t.Run("send commit after enough responses", func(t *testing.T) {
s.currHeight = 1
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(0)

req := s.tryRecv()
Expand Down Expand Up @@ -338,7 +338,7 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) {
s := newTestState(2, 4)
t.Run("send recovery message", func(t *testing.T) {
s.currHeight = 1
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(0)

req := s.tryRecv()
Expand All @@ -360,7 +360,7 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) {
require.Equal(t, dbft.RecoveryMessageType, rm.Type())

other := s.copyWithIndex(3)
srv2 := dbft.New[crypto.Uint256](other.getOptions()...)
srv2, _ := dbft.New[crypto.Uint256](other.getOptions()...)
srv2.Start(0)
srv2.OnReceive(rm)

Expand All @@ -383,7 +383,7 @@ func TestDBFT_OnReceiveChangeView(t *testing.T) {
s := newTestState(2, 4)
t.Run("change view correctly", func(t *testing.T) {
s.currHeight = 6
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(0)

resp := s.getChangeView(1, 1)
Expand All @@ -410,7 +410,8 @@ func TestDBFT_OnReceiveChangeView(t *testing.T) {

func TestDBFT_Invalid(t *testing.T) {
t.Run("without keys", func(t *testing.T) {
require.Nil(t, dbft.New[crypto.Uint256]())
_, err := dbft.New[crypto.Uint256]()
require.Error(t, err)
})

priv, pub := crypto.Generate(rand.Reader)
Expand All @@ -419,85 +420,98 @@ func TestDBFT_Invalid(t *testing.T) {

opts := []func(*dbft.Config[crypto.Uint256]){dbft.WithKeyPair[crypto.Uint256](priv, pub)}
t.Run("without Timer", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithTimer[crypto.Uint256](timer.New()))
t.Run("without CurrentHeight", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithCurrentHeight[crypto.Uint256](func() uint32 { return 0 }))
t.Run("without CurrentBlockHash", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithCurrentBlockHash[crypto.Uint256](func() crypto.Uint256 { return crypto.Uint256{} }))
t.Run("without GetValidators", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithGetValidators[crypto.Uint256](func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey {
return []dbft.PublicKey{pub}
}))
t.Run("without NewBlockFromContext", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewBlockFromContext[crypto.Uint256](func(_ *dbft.Context[crypto.Uint256]) dbft.Block[crypto.Uint256] {
return nil
}))
t.Run("without NewConsensusPayload", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewConsensusPayload[crypto.Uint256](func(_ *dbft.Context[crypto.Uint256], _ dbft.MessageType, _ any) dbft.ConsensusPayload[crypto.Uint256] {
return nil
}))
t.Run("without NewPrepareRequest", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewPrepareRequest[crypto.Uint256](func(uint64, uint64, []crypto.Uint256) dbft.PrepareRequest[crypto.Uint256] {
return nil
}))
t.Run("without NewPrepareResponse", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewPrepareResponse[crypto.Uint256](func(crypto.Uint256) dbft.PrepareResponse[crypto.Uint256] {
return nil
}))
t.Run("without NewChangeView", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewChangeView[crypto.Uint256](func(byte, dbft.ChangeViewReason, uint64) dbft.ChangeView {
return nil
}))
t.Run("without NewCommit", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewCommit[crypto.Uint256](func([]byte) dbft.Commit {
return nil
}))
t.Run("without NewRecoveryRequest", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewRecoveryRequest[crypto.Uint256](func(uint64) dbft.RecoveryRequest {
return nil
}))
t.Run("without NewRecoveryMessage", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewRecoveryMessage[crypto.Uint256](func() dbft.RecoveryMessage[crypto.Uint256] {
return nil
}))
t.Run("with all defaults", func(t *testing.T) {
d := dbft.New(opts...)
d, err := dbft.New(opts...)
require.NoError(t, err)
require.NotNil(t, d)
require.NotNil(t, d.Config.RequestTx)
require.NotNil(t, d.Config.GetTx)
Expand Down Expand Up @@ -526,19 +540,19 @@ func TestDBFT_Invalid(t *testing.T) {
func TestDBFT_FourGoodNodesDeadlock(t *testing.T) {
r0 := newTestState(0, 4)
r0.currHeight = 4
s0 := dbft.New[crypto.Uint256](r0.getOptions()...)
s0, _ := dbft.New[crypto.Uint256](r0.getOptions()...)
s0.Start(0)

r1 := r0.copyWithIndex(1)
s1 := dbft.New[crypto.Uint256](r1.getOptions()...)
s1, _ := dbft.New[crypto.Uint256](r1.getOptions()...)
s1.Start(0)

r2 := r0.copyWithIndex(2)
s2 := dbft.New[crypto.Uint256](r2.getOptions()...)
s2, _ := dbft.New[crypto.Uint256](r2.getOptions()...)
s2.Start(0)

r3 := r0.copyWithIndex(3)
s3 := dbft.New[crypto.Uint256](r3.getOptions()...)
s3, _ := dbft.New[crypto.Uint256](r3.getOptions()...)
s3.Start(0)

// Step 1. The primary (at view 0) replica 1 sends the PrepareRequest message.
Expand Down
2 changes: 1 addition & 1 deletion internal/consensus/consensus.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func New(logger *zap.Logger, key dbft.PrivateKey, pub dbft.PublicKey,
currentHeight func() uint32,
currentBlockHash func() crypto.Uint256,
getValidators func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey,
verifyPayload func(consensusPayload dbft.ConsensusPayload[crypto.Uint256]) error) *dbft.DBFT[crypto.Uint256] {
verifyPayload func(consensusPayload dbft.ConsensusPayload[crypto.Uint256]) error) (*dbft.DBFT[crypto.Uint256], error) {
return dbft.New[crypto.Uint256](
dbft.WithTimer[crypto.Uint256](timer.New()),
dbft.WithLogger[crypto.Uint256](logger),
Expand Down
8 changes: 4 additions & 4 deletions internal/simulation/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ func initSimNode(nodes []*simNode, i int, log *zap.Logger) error {
cluster: nodes,
}

nodes[i].d = consensus.New(nodes[i].log, key, pub, nodes[i].pool.Get,
var err error
nodes[i].d, err = consensus.New(nodes[i].log, key, pub, nodes[i].pool.Get,
nodes[i].pool.GetVerified,
nodes[i].Broadcast,
nodes[i].ProcessBlock,
Expand All @@ -130,9 +131,8 @@ func initSimNode(nodes []*simNode, i int, log *zap.Logger) error {
nodes[i].GetValidators,
nodes[i].VerifyPayload,
)

if nodes[i].d == nil {
return errors.New("can't initialize dBFT")
if err != nil {
return fmt.Errorf("failed to initialize dBFT: %w", err)
}

nodes[i].addTx(*txCount)
Expand Down

0 comments on commit 332ff86

Please sign in to comment.