Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Final user-facing interface improvements #107

Merged
merged 2 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
}

// WithGetKeyPair sets GetKeyPair.
func WithGetKeyPair[H Hash](f func([]PublicKey) (int, PrivateKey, PublicKey)) func(config *Config[H]) {
func WithGetKeyPair[H Hash](f func(pubs []PublicKey) (int, PrivateKey, PublicKey)) func(config *Config[H]) {

Check warning on line 165 in config.go

View check run for this annotation

Codecov / codecov/patch

config.go#L165

Added line #L165 was not covered by tests
return func(cfg *Config[H]) {
cfg.GetKeyPair = f
}
Expand Down Expand Up @@ -281,14 +281,14 @@
}

// WithGetValidators sets GetValidators.
func WithGetValidators[H Hash](f func(...Transaction[H]) []PublicKey) func(config *Config[H]) {
func WithGetValidators[H Hash](f func(txs ...Transaction[H]) []PublicKey) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.GetValidators = f
}
}

// WithNewConsensusPayload sets NewConsensusPayload.
func WithNewConsensusPayload[H Hash](f func(*Context[H], MessageType, any) ConsensusPayload[H]) func(config *Config[H]) {
func WithNewConsensusPayload[H Hash](f func(ctx *Context[H], typ MessageType, msg any) ConsensusPayload[H]) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.NewConsensusPayload = f
}
Expand All @@ -309,14 +309,14 @@
}

// WithNewChangeView sets NewChangeView.
func WithNewChangeView[H Hash](f func(byte, ChangeViewReason, uint64) ChangeView) func(config *Config[H]) {
func WithNewChangeView[H Hash](f func(newViewNumber byte, reason ChangeViewReason, ts uint64) ChangeView) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.NewChangeView = f
}
}

// WithNewCommit sets NewCommit.
func WithNewCommit[H Hash](f func([]byte) Commit) func(config *Config[H]) {
func WithNewCommit[H Hash](f func(signature []byte) Commit) func(config *Config[H]) {
return func(cfg *Config[H]) {
cfg.NewCommit = f
}
Expand All @@ -337,14 +337,14 @@
}

// WithVerifyPrepareRequest sets VerifyPrepareRequest.
func WithVerifyPrepareRequest[H Hash](f func(ConsensusPayload[H]) error) func(config *Config[H]) {
func WithVerifyPrepareRequest[H Hash](f func(prepareReq ConsensusPayload[H]) error) func(config *Config[H]) {

Check warning on line 340 in config.go

View check run for this annotation

Codecov / codecov/patch

config.go#L340

Added line #L340 was not covered by tests
return func(cfg *Config[H]) {
cfg.VerifyPrepareRequest = f
}
}

// WithVerifyPrepareResponse sets VerifyPrepareResponse.
func WithVerifyPrepareResponse[H Hash](f func(ConsensusPayload[H]) error) func(config *Config[H]) {
func WithVerifyPrepareResponse[H Hash](f func(prepareResp ConsensusPayload[H]) error) func(config *Config[H]) {

Check warning on line 347 in config.go

View check run for this annotation

Codecov / codecov/patch

config.go#L347

Added line #L347 was not covered by tests
AnnaShaleva marked this conversation as resolved.
Show resolved Hide resolved
return func(cfg *Config[H]) {
cfg.VerifyPrepareResponse = f
}
Expand Down
9 changes: 5 additions & 4 deletions dbft.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dbft

import (
"fmt"
"sync"
"time"

Expand All @@ -23,18 +24,18 @@ type (
)

// New returns new DBFT instance with specified H and A generic parameters
// using provided options or nil if some of the options are missing or invalid.
// using provided options or nil and error if some of the options are missing or invalid.
// H and A generic parameters are used as hash and address representation for
// dBFT consensus messages, blocks and transactions.
func New[H Hash](options ...func(config *Config[H])) *DBFT[H] {
func New[H Hash](options ...func(config *Config[H])) (*DBFT[H], error) {
cfg := defaultConfig[H]()

for _, option := range options {
option(cfg)
}

if err := checkConfig(cfg); err != nil {
return nil
return nil, fmt.Errorf("invalid config: %w", err)
}

d := &DBFT[H]{
Expand All @@ -45,7 +46,7 @@ func New[H Hash](options ...func(config *Config[H])) *DBFT[H] {
},
}

return d
return d, nil
}

func (d *DBFT[H]) addTransaction(tx Transaction[H]) {
Expand Down
74 changes: 44 additions & 30 deletions dbft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) {

t.Run("backup sends nothing on start", func(t *testing.T) {
s.currHeight = 0
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)

service.Start(0)
require.Nil(t, s.tryRecv())
})

t.Run("primary send PrepareRequest on start", func(t *testing.T) {
s.currHeight = 1
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)

service.Start(0)
p := s.tryRecv()
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestDBFT_SingleNode(t *testing.T) {
s := newTestState(0, 1)

s.currHeight = 2
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)

service.Start(0)
p := s.tryRecv()
Expand Down Expand Up @@ -128,7 +128,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) {

t.Run("receive request from primary", func(t *testing.T) {
s.currHeight = 4
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
txs := []testTx{1}
s.pool.Add(txs[0])

Expand Down Expand Up @@ -160,7 +160,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) {

t.Run("change view on invalid tx", func(t *testing.T) {
s.currHeight = 4
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
txs := []testTx{10}

service.Start(0)
Expand Down Expand Up @@ -188,7 +188,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) {

t.Run("receive invalid prepare request", func(t *testing.T) {
s.currHeight = 4
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
txs := []testTx{1, 2}
s.pool.Add(txs[0])

Expand Down Expand Up @@ -236,7 +236,7 @@ func TestDBFT_CommitOnTransaction(t *testing.T) {
s := newTestState(0, 4)
s.currHeight = 1

srv := dbft.New[crypto.Uint256](s.getOptions()...)
srv, _ := dbft.New[crypto.Uint256](s.getOptions()...)
srv.Start(0)
require.Nil(t, s.tryRecv())

Expand All @@ -256,7 +256,7 @@ func TestDBFT_CommitOnTransaction(t *testing.T) {
privs: s.privs,
}
s1.pool.Add(tx)
srv1 := dbft.New[crypto.Uint256](s1.getOptions()...)
srv1, _ := dbft.New[crypto.Uint256](s1.getOptions()...)
srv1.Start(0)
srv1.OnReceive(req)
srv1.OnReceive(s1.getPrepareResponse(1, req.Hash()))
Expand All @@ -278,7 +278,7 @@ func TestDBFT_OnReceiveCommit(t *testing.T) {
s := newTestState(2, 4)
t.Run("send commit after enough responses", func(t *testing.T) {
s.currHeight = 1
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(0)

req := s.tryRecv()
Expand Down Expand Up @@ -338,7 +338,7 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) {
s := newTestState(2, 4)
t.Run("send recovery message", func(t *testing.T) {
s.currHeight = 1
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(0)

req := s.tryRecv()
Expand All @@ -360,7 +360,7 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) {
require.Equal(t, dbft.RecoveryMessageType, rm.Type())

other := s.copyWithIndex(3)
srv2 := dbft.New[crypto.Uint256](other.getOptions()...)
srv2, _ := dbft.New[crypto.Uint256](other.getOptions()...)
srv2.Start(0)
srv2.OnReceive(rm)

Expand All @@ -383,7 +383,7 @@ func TestDBFT_OnReceiveChangeView(t *testing.T) {
s := newTestState(2, 4)
t.Run("change view correctly", func(t *testing.T) {
s.currHeight = 6
service := dbft.New[crypto.Uint256](s.getOptions()...)
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(0)

resp := s.getChangeView(1, 1)
Expand All @@ -410,7 +410,8 @@ func TestDBFT_OnReceiveChangeView(t *testing.T) {

func TestDBFT_Invalid(t *testing.T) {
t.Run("without keys", func(t *testing.T) {
require.Nil(t, dbft.New[crypto.Uint256]())
_, err := dbft.New[crypto.Uint256]()
require.Error(t, err)
})

priv, pub := crypto.Generate(rand.Reader)
Expand All @@ -419,85 +420,98 @@ func TestDBFT_Invalid(t *testing.T) {

opts := []func(*dbft.Config[crypto.Uint256]){dbft.WithKeyPair[crypto.Uint256](priv, pub)}
t.Run("without Timer", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithTimer[crypto.Uint256](timer.New()))
t.Run("without CurrentHeight", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithCurrentHeight[crypto.Uint256](func() uint32 { return 0 }))
t.Run("without CurrentBlockHash", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithCurrentBlockHash[crypto.Uint256](func() crypto.Uint256 { return crypto.Uint256{} }))
t.Run("without GetValidators", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithGetValidators[crypto.Uint256](func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey {
return []dbft.PublicKey{pub}
}))
t.Run("without NewBlockFromContext", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewBlockFromContext[crypto.Uint256](func(_ *dbft.Context[crypto.Uint256]) dbft.Block[crypto.Uint256] {
return nil
}))
t.Run("without NewConsensusPayload", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewConsensusPayload[crypto.Uint256](func(_ *dbft.Context[crypto.Uint256], _ dbft.MessageType, _ any) dbft.ConsensusPayload[crypto.Uint256] {
return nil
}))
t.Run("without NewPrepareRequest", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewPrepareRequest[crypto.Uint256](func(uint64, uint64, []crypto.Uint256) dbft.PrepareRequest[crypto.Uint256] {
return nil
}))
t.Run("without NewPrepareResponse", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewPrepareResponse[crypto.Uint256](func(crypto.Uint256) dbft.PrepareResponse[crypto.Uint256] {
return nil
}))
t.Run("without NewChangeView", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewChangeView[crypto.Uint256](func(byte, dbft.ChangeViewReason, uint64) dbft.ChangeView {
return nil
}))
t.Run("without NewCommit", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewCommit[crypto.Uint256](func([]byte) dbft.Commit {
return nil
}))
t.Run("without NewRecoveryRequest", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewRecoveryRequest[crypto.Uint256](func(uint64) dbft.RecoveryRequest {
return nil
}))
t.Run("without NewRecoveryMessage", func(t *testing.T) {
require.Nil(t, dbft.New(opts...))
_, err := dbft.New(opts...)
require.Error(t, err)
})

opts = append(opts, dbft.WithNewRecoveryMessage[crypto.Uint256](func() dbft.RecoveryMessage[crypto.Uint256] {
return nil
}))
t.Run("with all defaults", func(t *testing.T) {
d := dbft.New(opts...)
d, err := dbft.New(opts...)
require.NoError(t, err)
require.NotNil(t, d)
require.NotNil(t, d.Config.RequestTx)
require.NotNil(t, d.Config.GetTx)
Expand Down Expand Up @@ -526,19 +540,19 @@ func TestDBFT_Invalid(t *testing.T) {
func TestDBFT_FourGoodNodesDeadlock(t *testing.T) {
r0 := newTestState(0, 4)
r0.currHeight = 4
s0 := dbft.New[crypto.Uint256](r0.getOptions()...)
s0, _ := dbft.New[crypto.Uint256](r0.getOptions()...)
s0.Start(0)

r1 := r0.copyWithIndex(1)
s1 := dbft.New[crypto.Uint256](r1.getOptions()...)
s1, _ := dbft.New[crypto.Uint256](r1.getOptions()...)
s1.Start(0)

r2 := r0.copyWithIndex(2)
s2 := dbft.New[crypto.Uint256](r2.getOptions()...)
s2, _ := dbft.New[crypto.Uint256](r2.getOptions()...)
s2.Start(0)

r3 := r0.copyWithIndex(3)
s3 := dbft.New[crypto.Uint256](r3.getOptions()...)
s3, _ := dbft.New[crypto.Uint256](r3.getOptions()...)
s3.Start(0)

// Step 1. The primary (at view 0) replica 1 sends the PrepareRequest message.
Expand Down
2 changes: 1 addition & 1 deletion internal/consensus/consensus.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
currentHeight func() uint32,
currentBlockHash func() crypto.Uint256,
getValidators func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey,
verifyPayload func(consensusPayload dbft.ConsensusPayload[crypto.Uint256]) error) *dbft.DBFT[crypto.Uint256] {
verifyPayload func(consensusPayload dbft.ConsensusPayload[crypto.Uint256]) error) (*dbft.DBFT[crypto.Uint256], error) {

Check warning on line 20 in internal/consensus/consensus.go

View check run for this annotation

Codecov / codecov/patch

internal/consensus/consensus.go#L20

Added line #L20 was not covered by tests
return dbft.New[crypto.Uint256](
dbft.WithTimer[crypto.Uint256](timer.New()),
dbft.WithLogger[crypto.Uint256](logger),
Expand Down
8 changes: 4 additions & 4 deletions internal/simulation/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@
cluster: nodes,
}

nodes[i].d = consensus.New(nodes[i].log, key, pub, nodes[i].pool.Get,
var err error
nodes[i].d, err = consensus.New(nodes[i].log, key, pub, nodes[i].pool.Get,

Check warning on line 125 in internal/simulation/main.go

View check run for this annotation

Codecov / codecov/patch

internal/simulation/main.go#L124-L125

Added lines #L124 - L125 were not covered by tests
nodes[i].pool.GetVerified,
nodes[i].Broadcast,
nodes[i].ProcessBlock,
Expand All @@ -130,9 +131,8 @@
nodes[i].GetValidators,
nodes[i].VerifyPayload,
)

if nodes[i].d == nil {
return errors.New("can't initialize dBFT")
if err != nil {
return fmt.Errorf("failed to initialize dBFT: %w", err)

Check warning on line 135 in internal/simulation/main.go

View check run for this annotation

Codecov / codecov/patch

internal/simulation/main.go#L134-L135

Added lines #L134 - L135 were not covered by tests
}

nodes[i].addTx(*txCount)
Expand Down
Loading