From ef1e5c67eea91238212b1cf0ad99b473b10d6a82 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Tue, 13 Feb 2024 12:31:57 +0300 Subject: [PATCH] *: use generic hashes implementation for DBFT A part of #2. Use generics instead of util.Uint160 and util.Uint256 types for DBFT and related components. Keep util.Uint160 and util.Uint256 only for specific DBFT implementation in testing code. The following regressions/behaviour changes were made to properly apply generics: 1. `dbft.Option` alias is removed since type parameters can't be defined on aliases (generic type aliases are prohibited). Ref. https://github.com/golang/go/issues/46477. 2. Default dBFT configuration is reduced: all payload-specific defaults are removed, as described in https://github.com/nspcc-dev/dbft/issues/91. It is done because default dBFT configuration should not depend on any implementation-specific hash type. 3. DBFT configuration validation check is extended wrt point 2. 4. The check if generic `dbft.DBFT` type implements generic `dbft.Service` interface is removed since such check should be performed on particular (non-generic) DBFT implementation. Signed-off-by: Anna Shaleva --- README.md | 9 +- block/block.go | 22 ++-- block/block_test.go | 2 +- block/transaction.go | 6 +- check.go | 6 +- config.go | 235 ++++++++++++++++++----------------- context.go | 86 ++++++------- crypto/crypto.go | 22 ++++ dbft.go | 83 ++++++------- dbft_test.go | 186 +++++++++++++++++++-------- helpers.go | 35 +++--- helpers_test.go | 3 +- payload/consensus_message.go | 26 ++-- payload/constructors.go | 12 +- payload/message.go | 8 +- payload/message_test.go | 4 +- payload/prepare_request.go | 13 +- payload/prepare_response.go | 9 +- payload/recovery_message.go | 38 +++--- send.go | 37 ++---- simulation/main.go | 77 ++++++++---- 21 files changed, 526 insertions(+), 393 deletions(-) diff --git a/README.md b/README.md index a0e7c65c..03177de9 100644 --- a/README.md +++ b/README.md @@ -8,12 +8,17 @@ written in [TLA⁺](https://lamport.azurewebsites.net/tla/tla.html) language. ## Design and structure 1. All control flow is done in main package. Most of the code which communicates with external -world (event time events) is hidden behind interfaces and callbacks. As a consequence it is -highly flexible and extendable. Description of config options can be found in `config.go`. +world (event time events) is hidden behind interfaces, callbacks and generic parameters. As a +consequence it is highly flexible and extendable. Description of config options can be found +in `config.go`. 2. `crypto` package contains `PrivateKey`/`PublicKey` interfaces which permits usage of one's own cryptography for signing blocks on `Commit` stage. Default implementation with ECDSA signatures is provided, BLS multisignatures could be added in the nearest future. +3. `crypto` package contains `Hash`/`Address` interfaces which permits usage of one's own +hash/address implementation without additional overhead on conversions. Instantiate dBFT with +custom hash/address implementation that matches requirements specified in the corresponding +documentation. 3. `block` package contains `Block` and `Transaction` abstractions. Every block must be able to be signed and verified as well as implement setters and getters for main fields. Minimal default implementation is provided. diff --git a/block/block.go b/block/block.go index 3235c567..2ec26a01 100644 --- a/block/block.go +++ b/block/block.go @@ -23,15 +23,15 @@ type ( } // Block is a generic interface for a block used by dbft. - Block interface { + Block[H crypto.Hash, A crypto.Address] interface { // Hash returns block hash. - Hash() util.Uint256 + Hash() H Version() uint32 // PrevHash returns previous block hash. - PrevHash() util.Uint256 + PrevHash() H // MerkleRoot returns a merkle root of the transaction hashes. - MerkleRoot() util.Uint256 + MerkleRoot() H // Timestamp returns block's proposal timestamp. Timestamp() uint64 // Index returns block index. @@ -39,7 +39,7 @@ type ( // ConsensusData is a random nonce. ConsensusData() uint64 // NextConsensus returns hash of the validators of the next block. - NextConsensus() util.Uint160 + NextConsensus() A // Signature returns block's signature. Signature() []byte @@ -49,16 +49,16 @@ type ( Verify(key crypto.PublicKey, sign []byte) error // Transactions returns block's transaction list. - Transactions() []Transaction + Transactions() []Transaction[H] // SetTransactions sets block's transaction list. - SetTransactions([]Transaction) + SetTransactions([]Transaction[H]) } neoBlock struct { base consensusData uint64 - transactions []Transaction + transactions []Transaction[util.Uint256] signature []byte hash *util.Uint256 } @@ -100,17 +100,17 @@ func (b *neoBlock) ConsensusData() uint64 { } // Transactions implements Block interface. -func (b *neoBlock) Transactions() []Transaction { +func (b *neoBlock) Transactions() []Transaction[util.Uint256] { return b.transactions } // SetTransactions implements Block interface. -func (b *neoBlock) SetTransactions(txx []Transaction) { +func (b *neoBlock) SetTransactions(txx []Transaction[util.Uint256]) { b.transactions = txx } // NewBlock returns new block. -func NewBlock(timestamp uint64, index uint32, nextConsensus util.Uint160, prevHash util.Uint256, version uint32, nonce uint64, txHashes []util.Uint256) Block { +func NewBlock(timestamp uint64, index uint32, nextConsensus util.Uint160, prevHash util.Uint256, version uint32, nonce uint64, txHashes []util.Uint256) Block[util.Uint256, util.Uint160] { block := new(neoBlock) block.base.Timestamp = uint32(timestamp / 1000000000) block.base.Index = index diff --git a/block/block_test.go b/block/block_test.go index 96e14098..a3516f4f 100644 --- a/block/block_test.go +++ b/block/block_test.go @@ -19,7 +19,7 @@ func TestNeoBlock_Setters(t *testing.T) { require.Equal(t, util.Uint256{}, b.Hash()) - txs := []Transaction{testTx(1), testTx(2)} + txs := []Transaction[util.Uint256]{testTx(1), testTx(2)} b.SetTransactions(txs) assert.Equal(t, txs, b.Transactions()) diff --git a/block/transaction.go b/block/transaction.go index 39d26ad2..cb15ceeb 100644 --- a/block/transaction.go +++ b/block/transaction.go @@ -1,12 +1,12 @@ package block import ( - "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/dbft/crypto" ) // Transaction is a generic transaction interface. -type Transaction interface { +type Transaction[H crypto.Hash] interface { // Hash must return cryptographic hash of the transaction. // Transactions which have equal hashes are considered equal. - Hash() util.Uint256 + Hash() H } diff --git a/check.go b/check.go index c71b0106..14b2bb46 100644 --- a/check.go +++ b/check.go @@ -5,7 +5,7 @@ import ( "go.uber.org/zap" ) -func (d *DBFT) checkPrepare() { +func (d *DBFT[H, A]) checkPrepare() { if !d.hasAllTransactions() { d.Logger.Debug("check prepare: some transactions are missing", zap.Any("hashes", d.MissingTransactions)) return @@ -37,7 +37,7 @@ func (d *DBFT) checkPrepare() { } } -func (d *DBFT) checkCommit() { +func (d *DBFT[H, A]) checkCommit() { if !d.hasAllTransactions() { d.Logger.Debug("check commit: some transactions are missing", zap.Any("hashes", d.MissingTransactions)) return @@ -78,7 +78,7 @@ func (d *DBFT) checkCommit() { // new height. } -func (d *DBFT) checkChangeView(view byte) { +func (d *DBFT[H, A]) checkChangeView(view byte) { if d.ViewNumber >= view { return } diff --git a/config.go b/config.go index bdd803b2..13399da2 100644 --- a/config.go +++ b/config.go @@ -9,12 +9,11 @@ import ( "github.com/nspcc-dev/dbft/crypto" "github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/dbft/timer" - "github.com/nspcc-dev/neo-go/pkg/util" "go.uber.org/zap" ) // Config contains initialization and working parameters for dBFT. -type Config struct { +type Config[H crypto.Hash, A crypto.Address] struct { // Logger Logger *zap.Logger // Timer @@ -30,45 +29,45 @@ type Config struct { // together with it's key pair. GetKeyPair func([]crypto.PublicKey) (int, crypto.PrivateKey, crypto.PublicKey) // NewBlockFromContext should allocate, fill from Context and return new block.Block. - NewBlockFromContext func(ctx *Context) block.Block + NewBlockFromContext func(ctx *Context[H, A]) block.Block[H, A] // RequestTx is a callback which is called when transaction contained // in current block can't be found in memory pool. - RequestTx func(h ...util.Uint256) + RequestTx func(h ...H) // StopTxFlow is a callback which is called when the process no longer needs // any transactions. StopTxFlow func() // GetTx returns a transaction from memory pool. - GetTx func(h util.Uint256) block.Transaction + GetTx func(h H) block.Transaction[H] // GetVerified returns a slice of verified transactions // to be proposed in a new block. - GetVerified func() []block.Transaction + GetVerified func() []block.Transaction[H] // VerifyBlock verifies if block is valid. - VerifyBlock func(b block.Block) bool + VerifyBlock func(b block.Block[H, A]) bool // Broadcast should broadcast payload m to the consensus nodes. - Broadcast func(m payload.ConsensusPayload) + Broadcast func(m payload.ConsensusPayload[H, A]) // ProcessBlock is called every time new block is accepted. - ProcessBlock func(b block.Block) + ProcessBlock func(b block.Block[H, A]) // GetBlock should return block with hash. - GetBlock func(h util.Uint256) block.Block + GetBlock func(h H) block.Block[H, A] // WatchOnly tells if a node should only watch. WatchOnly func() bool // CurrentHeight returns index of the last accepted block. CurrentHeight func() uint32 // CurrentBlockHash returns hash of the last accepted block. - CurrentBlockHash func() util.Uint256 + CurrentBlockHash func() H // GetValidators returns list of the validators. // When called with a transaction list it must return // list of the validators of the next block. // If this function ever returns 0-length slice, dbft will panic. - GetValidators func(...block.Transaction) []crypto.PublicKey + GetValidators func(...block.Transaction[H]) []crypto.PublicKey // GetConsensusAddress returns hash of the validator list. - GetConsensusAddress func(...crypto.PublicKey) util.Uint160 + GetConsensusAddress func(...crypto.PublicKey) A // NewConsensusPayload is a constructor for payload.ConsensusPayload. - NewConsensusPayload func(*Context, payload.MessageType, any) payload.ConsensusPayload + NewConsensusPayload func(*Context[H, A], payload.MessageType, any) payload.ConsensusPayload[H, A] // NewPrepareRequest is a constructor for payload.PrepareRequest. - NewPrepareRequest func() payload.PrepareRequest + NewPrepareRequest func() payload.PrepareRequest[H, A] // NewPrepareResponse is a constructor for payload.PrepareResponse. - NewPrepareResponse func() payload.PrepareResponse + NewPrepareResponse func() payload.PrepareResponse[H] // NewChangeView is a constructor for payload.ChangeView. NewChangeView func() payload.ChangeView // NewCommit is a constructor for payload.Commit. @@ -76,56 +75,44 @@ type Config struct { // NewRecoveryRequest is a constructor for payload.RecoveryRequest. NewRecoveryRequest func() payload.RecoveryRequest // NewRecoveryMessage is a constructor for payload.RecoveryMessage. - NewRecoveryMessage func() payload.RecoveryMessage + NewRecoveryMessage func() payload.RecoveryMessage[H, A] // VerifyPrepareRequest can perform external payload verification and returns true iff it was successful. - VerifyPrepareRequest func(p payload.ConsensusPayload) error + VerifyPrepareRequest func(p payload.ConsensusPayload[H, A]) error // VerifyPrepareResponse performs external PrepareResponse verification and returns nil if it's successful. - VerifyPrepareResponse func(p payload.ConsensusPayload) error + VerifyPrepareResponse func(p payload.ConsensusPayload[H, A]) error } const defaultSecondsPerBlock = time.Second * 15 const defaultTimestampIncrement = uint64(time.Millisecond / time.Nanosecond) -// Option is a generic options type. It can modify config in any way it wants. -type Option = func(*Config) - -func defaultConfig() *Config { +func defaultConfig[H crypto.Hash, A crypto.Address]() *Config[H, A] { // fields which are set to nil must be provided from client - return &Config{ - Logger: zap.NewNop(), - Timer: timer.New(), - SecondsPerBlock: defaultSecondsPerBlock, - TimestampIncrement: defaultTimestampIncrement, - GetKeyPair: nil, - NewBlockFromContext: NewBlockFromContext, - RequestTx: func(...util.Uint256) {}, - StopTxFlow: func() {}, - GetTx: func(util.Uint256) block.Transaction { return nil }, - GetVerified: func() []block.Transaction { return make([]block.Transaction, 0) }, - VerifyBlock: func(block.Block) bool { return true }, - Broadcast: func(payload.ConsensusPayload) {}, - ProcessBlock: func(block.Block) {}, - GetBlock: func(util.Uint256) block.Block { return nil }, - WatchOnly: func() bool { return false }, - CurrentHeight: nil, - CurrentBlockHash: nil, - GetValidators: nil, - GetConsensusAddress: func(...crypto.PublicKey) util.Uint160 { return util.Uint160{} }, - NewConsensusPayload: defaultNewConsensusPayload, - NewPrepareRequest: payload.NewPrepareRequest, - NewPrepareResponse: payload.NewPrepareResponse, - NewChangeView: payload.NewChangeView, - NewCommit: payload.NewCommit, - NewRecoveryRequest: payload.NewRecoveryRequest, - NewRecoveryMessage: payload.NewRecoveryMessage, - - VerifyPrepareRequest: func(payload.ConsensusPayload) error { return nil }, - VerifyPrepareResponse: func(payload.ConsensusPayload) error { return nil }, - } -} - -func checkConfig(cfg *Config) error { + return &Config[H, A]{ + Logger: zap.NewNop(), + Timer: timer.New(), + SecondsPerBlock: defaultSecondsPerBlock, + TimestampIncrement: defaultTimestampIncrement, + GetKeyPair: nil, + RequestTx: func(...H) {}, + StopTxFlow: func() {}, + GetTx: func(H) block.Transaction[H] { return nil }, + GetVerified: func() []block.Transaction[H] { return make([]block.Transaction[H], 0) }, + VerifyBlock: func(block.Block[H, A]) bool { return true }, + Broadcast: func(payload.ConsensusPayload[H, A]) {}, + ProcessBlock: func(block.Block[H, A]) {}, + GetBlock: func(H) block.Block[H, A] { return nil }, + WatchOnly: func() bool { return false }, + CurrentHeight: nil, + CurrentBlockHash: nil, + GetValidators: nil, + + VerifyPrepareRequest: func(payload.ConsensusPayload[H, A]) error { return nil }, + VerifyPrepareResponse: func(payload.ConsensusPayload[H, A]) error { return nil }, + } +} + +func checkConfig[H crypto.Hash, A crypto.Address](cfg *Config[H, A]) error { if cfg.GetKeyPair == nil { return errors.New("private key is nil") } else if cfg.CurrentHeight == nil { @@ -134,6 +121,24 @@ func checkConfig(cfg *Config) error { return errors.New("CurrentBlockHash is nil") } else if cfg.GetValidators == nil { return errors.New("GetValidators is nil") + } else if cfg.NewBlockFromContext == nil { + return errors.New("NewBlockFromContext is nil") + } else if cfg.GetConsensusAddress == nil { + return errors.New("GetConsensusAddress is nil") + } else if cfg.NewConsensusPayload == nil { + return errors.New("NewConsensusPayload is nil") + } else if cfg.NewPrepareRequest == nil { + return errors.New("NewPrepareRequest is nil") + } else if cfg.NewPrepareResponse == nil { + return errors.New("NewPrepareResponse is nil") + } else if cfg.NewChangeView == nil { + return errors.New("NewChangeView is nil") + } else if cfg.NewCommit == nil { + return errors.New("NewCommit is nil") + } else if cfg.NewRecoveryRequest == nil { + return errors.New("NewRecoveryRequest is nil") + } else if cfg.NewRecoveryMessage == nil { + return errors.New("NewRecoveryMessage is nil") } return nil @@ -141,13 +146,13 @@ func checkConfig(cfg *Config) error { // WithKeyPair sets GetKeyPair to a function returning default key pair // if it is present in a list of validators. -func WithKeyPair(priv crypto.PrivateKey, pub crypto.PublicKey) Option { +func WithKeyPair[H crypto.Hash, A crypto.Address](priv crypto.PrivateKey, pub crypto.PublicKey) func(config *Config[H, A]) { myPub, err := pub.MarshalBinary() if err != nil { return nil } - return func(cfg *Config) { + return func(cfg *Config[H, A]) { cfg.GetKeyPair = func(ps []crypto.PublicKey) (int, crypto.PrivateKey, crypto.PublicKey) { for i := range ps { pi, err := ps[i].MarshalBinary() @@ -164,197 +169,197 @@ func WithKeyPair(priv crypto.PrivateKey, pub crypto.PublicKey) Option { } // WithGetKeyPair sets GetKeyPair. -func WithGetKeyPair(f func([]crypto.PublicKey) (int, crypto.PrivateKey, crypto.PublicKey)) Option { - return func(cfg *Config) { +func WithGetKeyPair[H crypto.Hash, A crypto.Address](f func([]crypto.PublicKey) (int, crypto.PrivateKey, crypto.PublicKey)) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.GetKeyPair = f } } // WithLogger sets Logger. -func WithLogger(log *zap.Logger) Option { - return func(cfg *Config) { +func WithLogger[H crypto.Hash, A crypto.Address](log *zap.Logger) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.Logger = log } } // WithTimer sets Timer. -func WithTimer(t timer.Timer) Option { - return func(cfg *Config) { +func WithTimer[H crypto.Hash, A crypto.Address](t timer.Timer) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.Timer = t } } // WithSecondsPerBlock sets SecondsPerBlock. -func WithSecondsPerBlock(d time.Duration) Option { - return func(cfg *Config) { +func WithSecondsPerBlock[H crypto.Hash, A crypto.Address](d time.Duration) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.SecondsPerBlock = d } } // WithTimestampIncrement sets TimestampIncrement. -func WithTimestampIncrement(u uint64) Option { - return func(cfg *Config) { +func WithTimestampIncrement[H crypto.Hash, A crypto.Address](u uint64) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.TimestampIncrement = u } } // WithNewBlockFromContext sets NewBlockFromContext. -func WithNewBlockFromContext(f func(ctx *Context) block.Block) Option { - return func(cfg *Config) { +func WithNewBlockFromContext[H crypto.Hash, A crypto.Address](f func(ctx *Context[H, A]) block.Block[H, A]) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.NewBlockFromContext = f } } // WithRequestTx sets RequestTx. -func WithRequestTx(f func(h ...util.Uint256)) Option { - return func(cfg *Config) { +func WithRequestTx[H crypto.Hash, A crypto.Address](f func(h ...H)) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.RequestTx = f } } // WithStopTxFlow sets StopTxFlow. -func WithStopTxFlow(f func()) Option { - return func(cfg *Config) { +func WithStopTxFlow[H crypto.Hash, A crypto.Address](f func()) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.StopTxFlow = f } } // WithGetTx sets GetTx. -func WithGetTx(f func(h util.Uint256) block.Transaction) Option { - return func(cfg *Config) { +func WithGetTx[H crypto.Hash, A crypto.Address](f func(h H) block.Transaction[H]) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.GetTx = f } } // WithGetVerified sets GetVerified. -func WithGetVerified(f func() []block.Transaction) Option { - return func(cfg *Config) { +func WithGetVerified[H crypto.Hash, A crypto.Address](f func() []block.Transaction[H]) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.GetVerified = f } } // WithVerifyBlock sets VerifyBlock. -func WithVerifyBlock(f func(b block.Block) bool) Option { - return func(cfg *Config) { +func WithVerifyBlock[H crypto.Hash, A crypto.Address](f func(b block.Block[H, A]) bool) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.VerifyBlock = f } } // WithBroadcast sets Broadcast. -func WithBroadcast(f func(m payload.ConsensusPayload)) Option { - return func(cfg *Config) { +func WithBroadcast[H crypto.Hash, A crypto.Address](f func(m payload.ConsensusPayload[H, A])) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.Broadcast = f } } // WithProcessBlock sets ProcessBlock. -func WithProcessBlock(f func(b block.Block)) Option { - return func(cfg *Config) { +func WithProcessBlock[H crypto.Hash, A crypto.Address](f func(b block.Block[H, A])) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.ProcessBlock = f } } // WithGetBlock sets GetBlock. -func WithGetBlock(f func(h util.Uint256) block.Block) Option { - return func(cfg *Config) { +func WithGetBlock[H crypto.Hash, A crypto.Address](f func(h H) block.Block[H, A]) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.GetBlock = f } } // WithWatchOnly sets WatchOnly. -func WithWatchOnly(f func() bool) Option { - return func(cfg *Config) { +func WithWatchOnly[H crypto.Hash, A crypto.Address](f func() bool) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.WatchOnly = f } } // WithCurrentHeight sets CurrentHeight. -func WithCurrentHeight(f func() uint32) Option { - return func(cfg *Config) { +func WithCurrentHeight[H crypto.Hash, A crypto.Address](f func() uint32) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.CurrentHeight = f } } // WithCurrentBlockHash sets CurrentBlockHash. -func WithCurrentBlockHash(f func() util.Uint256) Option { - return func(cfg *Config) { +func WithCurrentBlockHash[H crypto.Hash, A crypto.Address](f func() H) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.CurrentBlockHash = f } } // WithGetValidators sets GetValidators. -func WithGetValidators(f func(...block.Transaction) []crypto.PublicKey) Option { - return func(cfg *Config) { +func WithGetValidators[H crypto.Hash, A crypto.Address](f func(...block.Transaction[H]) []crypto.PublicKey) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.GetValidators = f } } // WithGetConsensusAddress sets GetConsensusAddress. -func WithGetConsensusAddress(f func(keys ...crypto.PublicKey) util.Uint160) Option { - return func(cfg *Config) { +func WithGetConsensusAddress[H crypto.Hash, A crypto.Address](f func(keys ...crypto.PublicKey) A) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.GetConsensusAddress = f } } // WithNewConsensusPayload sets NewConsensusPayload. -func WithNewConsensusPayload(f func(*Context, payload.MessageType, any) payload.ConsensusPayload) Option { - return func(cfg *Config) { +func WithNewConsensusPayload[H crypto.Hash, A crypto.Address](f func(*Context[H, A], payload.MessageType, any) payload.ConsensusPayload[H, A]) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.NewConsensusPayload = f } } // WithNewPrepareRequest sets NewPrepareRequest. -func WithNewPrepareRequest(f func() payload.PrepareRequest) Option { - return func(cfg *Config) { +func WithNewPrepareRequest[H crypto.Hash, A crypto.Address](f func() payload.PrepareRequest[H, A]) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.NewPrepareRequest = f } } // WithNewPrepareResponse sets NewPrepareResponse. -func WithNewPrepareResponse(f func() payload.PrepareResponse) Option { - return func(cfg *Config) { +func WithNewPrepareResponse[H crypto.Hash, A crypto.Address](f func() payload.PrepareResponse[H]) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.NewPrepareResponse = f } } // WithNewChangeView sets NewChangeView. -func WithNewChangeView(f func() payload.ChangeView) Option { - return func(cfg *Config) { +func WithNewChangeView[H crypto.Hash, A crypto.Address](f func() payload.ChangeView) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.NewChangeView = f } } // WithNewCommit sets NewCommit. -func WithNewCommit(f func() payload.Commit) Option { - return func(cfg *Config) { +func WithNewCommit[H crypto.Hash, A crypto.Address](f func() payload.Commit) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.NewCommit = f } } // WithNewRecoveryRequest sets NewRecoveryRequest. -func WithNewRecoveryRequest(f func() payload.RecoveryRequest) Option { - return func(cfg *Config) { +func WithNewRecoveryRequest[H crypto.Hash, A crypto.Address](f func() payload.RecoveryRequest) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.NewRecoveryRequest = f } } // WithNewRecoveryMessage sets NewRecoveryMessage. -func WithNewRecoveryMessage(f func() payload.RecoveryMessage) Option { - return func(cfg *Config) { +func WithNewRecoveryMessage[H crypto.Hash, A crypto.Address](f func() payload.RecoveryMessage[H, A]) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.NewRecoveryMessage = f } } // WithVerifyPrepareRequest sets VerifyPrepareRequest. -func WithVerifyPrepareRequest(f func(payload.ConsensusPayload) error) Option { - return func(cfg *Config) { +func WithVerifyPrepareRequest[H crypto.Hash, A crypto.Address](f func(payload.ConsensusPayload[H, A]) error) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.VerifyPrepareRequest = f } } // WithVerifyPrepareResponse sets VerifyPrepareResponse. -func WithVerifyPrepareResponse(f func(payload.ConsensusPayload) error) Option { - return func(cfg *Config) { +func WithVerifyPrepareResponse[H crypto.Hash, A crypto.Address](f func(payload.ConsensusPayload[H, A]) error) func(config *Config[H, A]) { + return func(cfg *Config[H, A]) { cfg.VerifyPrepareResponse = f } } diff --git a/context.go b/context.go index 70a84d94..d8e56677 100644 --- a/context.go +++ b/context.go @@ -14,17 +14,17 @@ import ( // Context is a main dBFT structure which // contains all information needed for performing transitions. -type Context struct { +type Context[H crypto.Hash, A crypto.Address] struct { // Config is dBFT's Config instance. - Config *Config + Config *Config[H, A] // Priv is node's private key. Priv crypto.PrivateKey // Pub is node's public key. Pub crypto.PublicKey - block block.Block - header block.Block + block block.Block[H, A] + header block.Block[H, A] // blockProcessed denotes whether Config.ProcessBlock callback was called for the current // height. If so, then no second call must happen. After new block is received by the user, // dBFT stops any new transaction or messages processing as far as timeouts handling till @@ -45,33 +45,33 @@ type Context struct { Version uint32 // NextConsensus is a hash of the validators which will be accepting the next block. - NextConsensus util.Uint160 + NextConsensus A // PrevHash is a hash of the previous block. - PrevHash util.Uint256 + PrevHash H // Timestamp is a nanosecond-precision timestamp Timestamp uint64 Nonce uint64 // TransactionHashes is a slice of hashes of proposed transactions in the current block. - TransactionHashes []util.Uint256 + TransactionHashes []H // MissingTransactions is a slice of hashes containing missing transactions for the current block. - MissingTransactions []util.Uint256 + MissingTransactions []H // Transactions is a map containing actual transactions for the current block. - Transactions map[util.Uint256]block.Transaction + Transactions map[H]block.Transaction[H] // PreparationPayloads stores consensus Prepare* payloads for the current epoch. - PreparationPayloads []payload.ConsensusPayload + PreparationPayloads []payload.ConsensusPayload[H, A] // CommitPayloads stores consensus Commit payloads sent throughout all epochs. It // is assumed that valid Commit payload can only be sent once by a single node per // the whole set of consensus epochs for particular block. Invalid commit payloads // are kicked off this list immediately (if PrepareRequest was received for the // current round, so it's possible to verify Commit against it) or stored till // the corresponding PrepareRequest receiving. - CommitPayloads []payload.ConsensusPayload + CommitPayloads []payload.ConsensusPayload[H, A] // ChangeViewPayloads stores consensus ChangeView payloads for the current epoch. - ChangeViewPayloads []payload.ConsensusPayload + ChangeViewPayloads []payload.ConsensusPayload[H, A] // LastChangeViewPayloads stores consensus ChangeView payloads for the last epoch. - LastChangeViewPayloads []payload.ConsensusPayload + LastChangeViewPayloads []payload.ConsensusPayload[H, A] // LastSeenMessage array stores the height of the last seen message, for each validator. // if this node never heard from validator i, LastSeenMessage[i] will be -1. LastSeenMessage []*timer.HV @@ -82,16 +82,16 @@ type Context struct { } // N returns total number of validators. -func (c *Context) N() int { return len(c.Validators) } +func (c *Context[H, A]) N() int { return len(c.Validators) } // F returns number of validators which can be faulty. -func (c *Context) F() int { return (len(c.Validators) - 1) / 3 } +func (c *Context[H, A]) F() int { return (len(c.Validators) - 1) / 3 } // M returns number of validators which must function correctly. -func (c *Context) M() int { return len(c.Validators) - c.F() } +func (c *Context[H, A]) M() int { return len(c.Validators) - c.F() } // GetPrimaryIndex returns index of a primary node for the specified view. -func (c *Context) GetPrimaryIndex(viewNumber byte) uint { +func (c *Context[H, A]) GetPrimaryIndex(viewNumber byte) uint { p := (int(c.BlockIndex) - int(viewNumber)) % len(c.Validators) if p >= 0 { return uint(p) @@ -101,19 +101,19 @@ func (c *Context) GetPrimaryIndex(viewNumber byte) uint { } // IsPrimary returns true iff node is primary for current height and view. -func (c *Context) IsPrimary() bool { return c.MyIndex == int(c.PrimaryIndex) } +func (c *Context[H, A]) IsPrimary() bool { return c.MyIndex == int(c.PrimaryIndex) } // IsBackup returns true iff node is backup for current height and view. -func (c *Context) IsBackup() bool { +func (c *Context[H, A]) IsBackup() bool { return c.MyIndex >= 0 && !c.IsPrimary() } // WatchOnly returns true iff node takes no active part in consensus. -func (c *Context) WatchOnly() bool { return c.MyIndex < 0 || c.Config.WatchOnly() } +func (c *Context[H, A]) WatchOnly() bool { return c.MyIndex < 0 || c.Config.WatchOnly() } // CountCommitted returns number of received Commit messages not only for the current // epoch but also for any other epoch. -func (c *Context) CountCommitted() (count int) { +func (c *Context[H, A]) CountCommitted() (count int) { for i := range c.CommitPayloads { if c.CommitPayloads[i] != nil { count++ @@ -125,7 +125,7 @@ func (c *Context) CountCommitted() (count int) { // CountFailed returns number of nodes with which no communication was performed // for this view and that hasn't sent the Commit message at the previous views. -func (c *Context) CountFailed() (count int) { +func (c *Context[H, A]) CountFailed() (count int) { for i, hv := range c.LastSeenMessage { if c.CommitPayloads[i] == nil && (hv == nil || hv.Height < c.BlockIndex || hv.View < c.ViewNumber) { count++ @@ -137,18 +137,18 @@ func (c *Context) CountFailed() (count int) { // RequestSentOrReceived returns true iff PrepareRequest // was sent or received for the current epoch. -func (c *Context) RequestSentOrReceived() bool { +func (c *Context[H, A]) RequestSentOrReceived() bool { return c.PreparationPayloads[c.PrimaryIndex] != nil } // ResponseSent returns true iff Prepare* message was sent for the current epoch. -func (c *Context) ResponseSent() bool { +func (c *Context[H, A]) ResponseSent() bool { return !c.WatchOnly() && c.PreparationPayloads[c.MyIndex] != nil } // CommitSent returns true iff Commit message was sent for the current epoch // assuming that the node can't go further than current epoch after commit was sent. -func (c *Context) CommitSent() bool { +func (c *Context[H, A]) CommitSent() bool { return !c.WatchOnly() && c.CommitPayloads[c.MyIndex] != nil } @@ -165,10 +165,10 @@ func (c *Context) CommitSent() bool { // several places where the call to CreateBlock happens (one of them is right after // PrepareRequest receiving). Thus, we have a separate Context.blockProcessed field // for the described purpose. -func (c *Context) BlockSent() bool { return c.blockProcessed } +func (c *Context[H, A]) BlockSent() bool { return c.blockProcessed } // ViewChanging returns true iff node is in a process of changing view. -func (c *Context) ViewChanging() bool { +func (c *Context[H, A]) ViewChanging() bool { if c.WatchOnly() { return false } @@ -179,7 +179,7 @@ func (c *Context) ViewChanging() bool { } // NotAcceptingPayloadsDueToViewChanging returns true if node should not accept new payloads. -func (c *Context) NotAcceptingPayloadsDueToViewChanging() bool { +func (c *Context[H, A]) NotAcceptingPayloadsDueToViewChanging() bool { return c.ViewChanging() && !c.MoreThanFNodesCommittedOrLost() } @@ -190,11 +190,11 @@ func (c *Context) NotAcceptingPayloadsDueToViewChanging() bool { // asking change views loses network or crashes and comes back when nodes are committed in more than one higher // numbered view, it is possible for the node accepting recovery to commit in any of the higher views, thus // potentially splitting nodes among views and stalling the network. -func (c *Context) MoreThanFNodesCommittedOrLost() bool { +func (c *Context[H, A]) MoreThanFNodesCommittedOrLost() bool { return c.CountCommitted()+c.CountFailed() > c.F() } -func (c *Context) reset(view byte, ts uint64) { +func (c *Context[H, A]) reset(view byte, ts uint64) { c.MyIndex = -1 c.lastBlockTimestamp = ts @@ -204,7 +204,7 @@ func (c *Context) reset(view byte, ts uint64) { c.Validators = c.Config.GetValidators() n := len(c.Validators) - c.LastChangeViewPayloads = make([]payload.ConsensusPayload, n) + c.LastChangeViewPayloads = make([]payload.ConsensusPayload[H, A], n) if c.LastSeenMessage == nil { c.LastSeenMessage = make([]*timer.HV, n) @@ -227,13 +227,13 @@ func (c *Context) reset(view byte, ts uint64) { c.header = nil n := len(c.Validators) - c.ChangeViewPayloads = make([]payload.ConsensusPayload, n) + c.ChangeViewPayloads = make([]payload.ConsensusPayload[H, A], n) if view == 0 { - c.CommitPayloads = make([]payload.ConsensusPayload, n) + c.CommitPayloads = make([]payload.ConsensusPayload[H, A], n) } - c.PreparationPayloads = make([]payload.ConsensusPayload, n) + c.PreparationPayloads = make([]payload.ConsensusPayload[H, A], n) - c.Transactions = make(map[util.Uint256]block.Transaction) + c.Transactions = make(map[H]block.Transaction[H]) c.TransactionHashes = nil c.MissingTransactions = nil c.PrimaryIndex = c.GetPrimaryIndex(view) @@ -248,7 +248,7 @@ func (c *Context) reset(view byte, ts uint64) { } // Fill initializes consensus when node is a speaker. -func (c *Context) Fill() { +func (c *Context[H, A]) Fill() { b := make([]byte, 8) _, err := rand.Read(b) if err != nil { @@ -257,7 +257,7 @@ func (c *Context) Fill() { txx := c.Config.GetVerified() c.Nonce = binary.LittleEndian.Uint64(b) - c.TransactionHashes = make([]util.Uint256, len(txx)) + c.TransactionHashes = make([]H, len(txx)) for i := range txx { h := txx[i].Hash() @@ -276,18 +276,18 @@ func (c *Context) Fill() { // getTimestamp returns nanoseconds-precision timestamp using // current context config. -func (c *Context) getTimestamp() uint64 { +func (c *Context[H, A]) getTimestamp() uint64 { return uint64(c.Config.Timer.Now().UnixNano()) / c.Config.TimestampIncrement * c.Config.TimestampIncrement } // CreateBlock returns resulting block for the current epoch. -func (c *Context) CreateBlock() block.Block { +func (c *Context[H, A]) CreateBlock() block.Block[H, A] { if c.block == nil { if c.block = c.MakeHeader(); c.block == nil { return nil } - txx := make([]block.Transaction, len(c.TransactionHashes)) + txx := make([]block.Transaction[H], len(c.TransactionHashes)) for i, h := range c.TransactionHashes { txx[i] = c.Transactions[h] @@ -301,7 +301,7 @@ func (c *Context) CreateBlock() block.Block { // MakeHeader returns half-filled block for the current epoch. // All hashable fields will be filled. -func (c *Context) MakeHeader() block.Block { +func (c *Context[H, A]) MakeHeader() block.Block[H, A] { if c.header == nil { if !c.RequestSentOrReceived() { return nil @@ -313,7 +313,7 @@ func (c *Context) MakeHeader() block.Block { } // NewBlockFromContext returns new block filled with given contexet. -func NewBlockFromContext(ctx *Context) block.Block { +func NewBlockFromContext(ctx *Context[util.Uint256, util.Uint160]) block.Block[util.Uint256, util.Uint160] { if ctx.TransactionHashes == nil { return nil } @@ -323,6 +323,6 @@ func NewBlockFromContext(ctx *Context) block.Block { // hasAllTransactions returns true iff all transactions were received // for the proposed block. -func (c *Context) hasAllTransactions() bool { +func (c *Context[H, A]) hasAllTransactions() bool { return len(c.TransactionHashes) == len(c.Transactions) } diff --git a/crypto/crypto.go b/crypto/crypto.go index 839813e8..3c1d5021 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -2,6 +2,7 @@ package crypto import ( "encoding" + "fmt" "io" ) @@ -20,6 +21,27 @@ type ( // Sign returns msg's signature and error on failure. Sign(msg []byte) (sig []byte, err error) } + + // Hash is a generic hash interface used by dbft for payloads, blocks and + // transactions identification. It is recommended to implement this interface + // using hash functions with low hash collision probability. The following + // requirements must be met: + // 1. Hashes of two equal payloads/blocks/transactions are equal. + // 2. Hashes of two different payloads/blocks/transactions are different. + Hash interface { + comparable + fmt.Stringer + } + + // Address is a generic address interface used by dbft for operations related + // to consensus address. It is recommended to implement this interface + // using hash functions with low hash collision probability. The following + // requirements must be met: + // 1. Addresses of two equal sets of consensus members are equal. + // 2. Addresses of two different sets of consensus members are different. + Address interface { + comparable + } ) type suiteType byte diff --git a/dbft.go b/dbft.go index 7f061335..979cbdb9 100644 --- a/dbft.go +++ b/dbft.go @@ -5,39 +5,38 @@ import ( "time" "github.com/nspcc-dev/dbft/block" + "github.com/nspcc-dev/dbft/crypto" "github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/dbft/timer" - "github.com/nspcc-dev/neo-go/pkg/util" "go.uber.org/zap" ) type ( // DBFT is a wrapper over Context containing service configuration and // some other parameters not directly related to dBFT's state machine. - DBFT struct { - Context - Config + DBFT[H crypto.Hash, A crypto.Address] struct { + Context[H, A] + Config[H, A] *sync.Mutex - cache cache + cache cache[H, A] recovering bool } - // Service is an interface for dBFT consensus. - Service interface { + Service[H crypto.Hash, A crypto.Address] interface { Start(uint64) - OnTransaction(block.Transaction) - OnReceive(payload.ConsensusPayload) + OnTransaction(block.Transaction[H]) + OnReceive(payload.ConsensusPayload[H, A]) OnTimeout(timer.HV) } ) -var _ Service = (*DBFT)(nil) - -// New returns new DBFT instance with provided options -// and nil if some of the options are missing or invalid. -func New(options ...Option) *DBFT { - cfg := defaultConfig() +// New returns new DBFT instance with specified H and A generic parameters +// using provided options or nil if some of the options are missing or invalid. +// H and A generic parameters are used as hash and address representation for +// dBFT consensus messages, blocks and transactions. +func New[H crypto.Hash, A crypto.Address](options ...func(config *Config[H, A])) *DBFT[H, A] { + cfg := defaultConfig[H, A]() for _, option := range options { option(cfg) @@ -47,10 +46,10 @@ func New(options ...Option) *DBFT { return nil } - d := &DBFT{ + d := &DBFT[H, A]{ Mutex: new(sync.Mutex), Config: *cfg, - Context: Context{ + Context: Context[H, A]{ Config: cfg, }, } @@ -58,7 +57,7 @@ func New(options ...Option) *DBFT { return d } -func (d *DBFT) addTransaction(tx block.Transaction) { +func (d *DBFT[H, A]) addTransaction(tx block.Transaction[H]) { d.Transactions[tx.Hash()] = tx if d.hasAllTransactions() { if d.IsPrimary() || d.Context.WatchOnly() { @@ -77,14 +76,14 @@ func (d *DBFT) addTransaction(tx block.Transaction) { // Start initializes dBFT instance and starts protocol if node is primary. It // accepts a timestamp of the previous block. -func (d *DBFT) Start(ts uint64) { - d.cache = newCache() +func (d *DBFT[H, A]) Start(ts uint64) { + d.cache = newCache[H, A]() d.InitializeConsensus(0, ts) d.start() } // InitializeConsensus initializes dBFT instance. -func (d *DBFT) InitializeConsensus(view byte, ts uint64) { +func (d *DBFT[H, A]) InitializeConsensus(view byte, ts uint64) { d.reset(view, ts) var role string @@ -137,7 +136,7 @@ func (d *DBFT) InitializeConsensus(view byte, ts uint64) { } // OnTransaction notifies service about receiving new transaction. -func (d *DBFT) OnTransaction(tx block.Transaction) { +func (d *DBFT[H, A]) OnTransaction(tx block.Transaction[H]) { // d.Logger.Debug("OnTransaction", // zap.Bool("backup", d.IsBackup()), // zap.Bool("not_accepting", d.NotAcceptingPayloadsDueToViewChanging()), @@ -169,7 +168,7 @@ func (d *DBFT) OnTransaction(tx block.Transaction) { } // OnTimeout advances state machine as if timeout was fired. -func (d *DBFT) OnTimeout(hv timer.HV) { +func (d *DBFT[H, A]) OnTimeout(hv timer.HV) { if d.Context.WatchOnly() || d.BlockSent() { return } @@ -200,7 +199,7 @@ func (d *DBFT) OnTimeout(hv timer.HV) { } // OnReceive advances state machine in accordance with msg. -func (d *DBFT) OnReceive(msg payload.ConsensusPayload) { +func (d *DBFT[H, A]) OnReceive(msg payload.ConsensusPayload[H, A]) { if int(msg.ValidatorIndex()) >= len(d.Validators) { d.Logger.Error("too big validator index", zap.Uint16("from", msg.ValidatorIndex())) return @@ -269,7 +268,7 @@ func (d *DBFT) OnReceive(msg payload.ConsensusPayload) { // start performs initial operations and returns messages to be sent. // It must be called after every height or view increment. -func (d *DBFT) start() { +func (d *DBFT[H, A]) start() { if !d.IsPrimary() { if msgs := d.cache.getHeight(d.BlockIndex); msgs != nil { for _, m := range msgs.prepare { @@ -291,10 +290,10 @@ func (d *DBFT) start() { d.sendPrepareRequest() } -func (d *DBFT) onPrepareRequest(msg payload.ConsensusPayload) { +func (d *DBFT[H, A]) onPrepareRequest(msg payload.ConsensusPayload[H, A]) { // ignore prepareRequest if we had already received it or // are in process of changing view - if d.RequestSentOrReceived() { //|| (d.ViewChanging() && !d.MoreThanFNodesCommittedOrLost()) { + if d.RequestSentOrReceived() { // || (d.ViewChanging() && !d.MoreThanFNodesCommittedOrLost()) { d.Logger.Debug("ignoring PrepareRequest", zap.Bool("sor", d.RequestSentOrReceived()), zap.Bool("viewChanging", d.ViewChanging()), @@ -343,8 +342,8 @@ func (d *DBFT) onPrepareRequest(msg payload.ConsensusPayload) { d.checkPrepare() } -func (d *DBFT) processMissingTx() { - missing := make([]util.Uint256, 0, len(d.TransactionHashes)/2) +func (d *DBFT[H, A]) processMissingTx() { + missing := make([]H, 0, len(d.TransactionHashes)/2) for _, h := range d.TransactionHashes { if _, ok := d.Transactions[h]; ok { @@ -369,8 +368,8 @@ func (d *DBFT) processMissingTx() { // the new proposed block, if it's fine it returns true, if something is wrong // with it, it sends a changeView request and returns false. It's only valid to // call it when all transactions for this block are already collected. -func (d *DBFT) createAndCheckBlock() bool { - txx := make([]block.Transaction, 0, len(d.TransactionHashes)) +func (d *DBFT[H, A]) createAndCheckBlock() bool { + txx := make([]block.Transaction[H], 0, len(d.TransactionHashes)) for _, h := range d.TransactionHashes { txx = append(txx, d.Transactions[h]) } @@ -387,7 +386,7 @@ func (d *DBFT) createAndCheckBlock() bool { return true } -func (d *DBFT) updateExistingPayloads(msg payload.ConsensusPayload) { +func (d *DBFT[H, A]) updateExistingPayloads(msg payload.ConsensusPayload[H, A]) { for i, m := range d.PreparationPayloads { if m != nil && m.Type() == payload.PrepareResponseType { resp := m.GetPrepareResponse() @@ -410,7 +409,7 @@ func (d *DBFT) updateExistingPayloads(msg payload.ConsensusPayload) { } } -func (d *DBFT) onPrepareResponse(msg payload.ConsensusPayload) { +func (d *DBFT[H, A]) onPrepareResponse(msg payload.ConsensusPayload[H, A]) { if d.ViewNumber != msg.ViewNumber() { d.Logger.Debug("ignoring wrong view number", zap.Uint("view", uint(msg.ViewNumber()))) return @@ -462,7 +461,7 @@ func (d *DBFT) onPrepareResponse(msg payload.ConsensusPayload) { } } -func (d *DBFT) onChangeView(msg payload.ConsensusPayload) { +func (d *DBFT[H, A]) onChangeView(msg payload.ConsensusPayload[H, A]) { p := msg.GetChangeView() if p.NewViewNumber() <= d.ViewNumber { @@ -493,16 +492,16 @@ func (d *DBFT) onChangeView(msg payload.ConsensusPayload) { d.checkChangeView(p.NewViewNumber()) } -func (d *DBFT) onCommit(msg payload.ConsensusPayload) { +func (d *DBFT[H, A]) onCommit(msg payload.ConsensusPayload[H, A]) { existing := d.CommitPayloads[msg.ValidatorIndex()] if existing != nil { - if !existing.Hash().Equals(msg.Hash()) { + if existing.Hash() != msg.Hash() { d.Logger.Warn("rejecting commit due to existing", zap.Uint("validator", uint(msg.ValidatorIndex())), zap.Uint("existing view", uint(existing.ViewNumber())), zap.Uint("view", uint(msg.ViewNumber())), - zap.String("existing hash", existing.Hash().StringLE()), - zap.String("hash", msg.Hash().StringLE()), + zap.Stringer("existing hash", existing.Hash()), + zap.Stringer("hash", msg.Hash()), ) } return @@ -535,7 +534,7 @@ func (d *DBFT) onCommit(msg payload.ConsensusPayload) { d.CommitPayloads[msg.ValidatorIndex()] = msg } -func (d *DBFT) onRecoveryRequest(msg payload.ConsensusPayload) { +func (d *DBFT[H, A]) onRecoveryRequest(msg payload.ConsensusPayload[H, A]) { if !d.CommitSent() { // Limit recoveries to be sent from no more than F nodes // TODO replace loop with a single if @@ -557,7 +556,7 @@ func (d *DBFT) onRecoveryRequest(msg payload.ConsensusPayload) { d.sendRecoveryMessage() } -func (d *DBFT) onRecoveryMessage(msg payload.ConsensusPayload) { +func (d *DBFT[H, A]) onRecoveryMessage(msg payload.ConsensusPayload[H, A]) { d.Logger.Debug("recovery message received", zap.Any("dump", msg)) var ( @@ -617,7 +616,7 @@ func (d *DBFT) onRecoveryMessage(msg payload.ConsensusPayload) { } } -func (d *DBFT) changeTimer(delay time.Duration) { +func (d *DBFT[H, A]) changeTimer(delay time.Duration) { d.Logger.Debug("reset timer", zap.Uint32("h", d.BlockIndex), zap.Int("v", int(d.ViewNumber)), @@ -625,7 +624,7 @@ func (d *DBFT) changeTimer(delay time.Duration) { d.Timer.Reset(timer.HV{Height: d.BlockIndex, View: d.ViewNumber}, delay) } -func (d *DBFT) extendTimer(count time.Duration) { +func (d *DBFT[H, A]) extendTimer(count time.Duration) { if !d.CommitSent() && !d.ViewChanging() { d.Timer.Extend(count * d.SecondsPerBlock / time.Duration(d.M())) } diff --git a/dbft_test.go b/dbft_test.go index 9ec9eddf..78a9844d 100644 --- a/dbft_test.go +++ b/dbft_test.go @@ -15,7 +15,7 @@ import ( "go.uber.org/zap" ) -type Payload = payload.ConsensusPayload +type Payload = payload.ConsensusPayload[util.Uint256, util.Uint160] type testState struct { myIndex int @@ -26,8 +26,8 @@ type testState struct { currHeight uint32 currHash util.Uint256 pool *testPool - blocks []block.Block - verify func(b block.Block) bool + blocks []block.Block[util.Uint256, util.Uint160] + verify func(b block.Block[util.Uint256, util.Uint160]) bool } type ( @@ -44,7 +44,7 @@ func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) { t.Run("backup sends nothing on start", func(t *testing.T) { s.currHeight = 0 - service := New(s.getOptions()...) + service := New[util.Uint256, util.Uint160](s.getOptions()...) service.Start(0) require.Nil(t, s.tryRecv()) @@ -52,7 +52,7 @@ func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) { t.Run("primary send PrepareRequest on start", func(t *testing.T) { s.currHeight = 1 - service := New(s.getOptions()...) + service := New[util.Uint256, util.Uint160](s.getOptions()...) service.Start(0) p := s.tryRecv() @@ -91,7 +91,7 @@ func TestDBFT_SingleNode(t *testing.T) { s := newTestState(0, 1) s.currHeight = 2 - service := New(s.getOptions()...) + service := New[util.Uint256, util.Uint160](s.getOptions()...) service.Start(0) p := s.tryRecv() @@ -117,7 +117,7 @@ func TestDBFT_SingleNode(t *testing.T) { func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) { s := newTestState(2, 7) - s.verify = func(b block.Block) bool { + s.verify = func(b block.Block[util.Uint256, util.Uint160]) bool { for _, tx := range b.Transactions() { if tx.(testTx)%10 == 0 { return false @@ -129,7 +129,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) { t.Run("receive request from primary", func(t *testing.T) { s.currHeight = 4 - service := New(s.getOptions()...) + service := New[util.Uint256, util.Uint160](s.getOptions()...) txs := []testTx{1} s.pool.Add(txs[0]) @@ -161,7 +161,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) { t.Run("change view on invalid tx", func(t *testing.T) { s.currHeight = 4 - service := New(s.getOptions()...) + service := New[util.Uint256, util.Uint160](s.getOptions()...) txs := []testTx{10} service.Start(0) @@ -189,7 +189,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) { t.Run("receive invalid prepare request", func(t *testing.T) { s.currHeight = 4 - service := New(s.getOptions()...) + service := New[util.Uint256, util.Uint160](s.getOptions()...) txs := []testTx{1, 2} s.pool.Add(txs[0]) @@ -238,7 +238,7 @@ func TestDBFT_CommitOnTransaction(t *testing.T) { s := newTestState(0, 4) s.currHeight = 1 - srv := New(s.getOptions()...) + srv := New[util.Uint256, util.Uint160](s.getOptions()...) srv.Start(0) require.Nil(t, s.tryRecv()) @@ -258,7 +258,7 @@ func TestDBFT_CommitOnTransaction(t *testing.T) { privs: s.privs, } s1.pool.Add(tx) - srv1 := New(s1.getOptions()...) + srv1 := New[util.Uint256, util.Uint160](s1.getOptions()...) srv1.Start(0) srv1.OnReceive(req) srv1.OnReceive(s1.getPrepareResponse(1, req.Hash())) @@ -280,7 +280,7 @@ func TestDBFT_OnReceiveCommit(t *testing.T) { s := newTestState(2, 4) t.Run("send commit after enough responses", func(t *testing.T) { s.currHeight = 1 - service := New(s.getOptions()...) + service := New[util.Uint256, util.Uint160](s.getOptions()...) service.Start(0) req := s.tryRecv() @@ -340,7 +340,7 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) { s := newTestState(2, 4) t.Run("send recovery message", func(t *testing.T) { s.currHeight = 1 - service := New(s.getOptions()...) + service := New[util.Uint256, util.Uint160](s.getOptions()...) service.Start(0) req := s.tryRecv() @@ -362,7 +362,7 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) { require.Equal(t, payload.RecoveryMessageType, rm.Type()) other := s.copyWithIndex(3) - srv2 := New(other.getOptions()...) + srv2 := New[util.Uint256, util.Uint160](other.getOptions()...) srv2.Start(0) srv2.OnReceive(rm) @@ -385,7 +385,7 @@ func TestDBFT_OnReceiveChangeView(t *testing.T) { s := newTestState(2, 4) t.Run("change view correctly", func(t *testing.T) { s.currHeight = 6 - service := New(s.getOptions()...) + service := New[util.Uint256, util.Uint160](s.getOptions()...) service.Start(0) resp := s.getChangeView(1, 1) @@ -412,31 +412,94 @@ func TestDBFT_OnReceiveChangeView(t *testing.T) { func TestDBFT_Invalid(t *testing.T) { t.Run("without keys", func(t *testing.T) { - require.Nil(t, New()) + require.Nil(t, New[util.Uint256, util.Uint160]()) }) priv, pub := crypto.Generate(rand.Reader) require.NotNil(t, priv) require.NotNil(t, pub) - opts := []Option{WithKeyPair(priv, pub)} + opts := []func(*Config[util.Uint256, util.Uint160]){WithKeyPair[util.Uint256, util.Uint160](priv, pub)} t.Run("without CurrentHeight", func(t *testing.T) { require.Nil(t, New(opts...)) }) - opts = append(opts, WithCurrentHeight(func() uint32 { return 0 })) + opts = append(opts, WithCurrentHeight[util.Uint256, util.Uint160](func() uint32 { return 0 })) t.Run("without CurrentBlockHash", func(t *testing.T) { require.Nil(t, New(opts...)) }) - opts = append(opts, WithCurrentBlockHash(func() util.Uint256 { return util.Uint256{} })) + opts = append(opts, WithCurrentBlockHash[util.Uint256, util.Uint160](func() util.Uint256 { return util.Uint256{} })) t.Run("without GetValidators", func(t *testing.T) { require.Nil(t, New(opts...)) }) - opts = append(opts, WithGetValidators(func(...block.Transaction) []crypto.PublicKey { + opts = append(opts, WithGetValidators[util.Uint256, util.Uint160](func(...block.Transaction[util.Uint256]) []crypto.PublicKey { return []crypto.PublicKey{pub} })) + t.Run("without NewBlockFromContext", func(t *testing.T) { + require.Nil(t, New(opts...)) + }) + + opts = append(opts, WithNewBlockFromContext[util.Uint256, util.Uint160](func(_ *Context[util.Uint256, util.Uint160]) block.Block[util.Uint256, util.Uint160] { + return nil + })) + t.Run("without GetConsensusAddress", func(t *testing.T) { + require.Nil(t, New(opts...)) + }) + + opts = append(opts, WithGetConsensusAddress[util.Uint256, util.Uint160](func(_ ...crypto.PublicKey) util.Uint160 { + return util.Uint160{} + })) + t.Run("without NewConsensusPayload", func(t *testing.T) { + require.Nil(t, New(opts...)) + }) + + opts = append(opts, WithNewConsensusPayload[util.Uint256, util.Uint160](func(_ *Context[util.Uint256, util.Uint160], _ payload.MessageType, _ any) payload.ConsensusPayload[util.Uint256, util.Uint160] { + return nil + })) + t.Run("without NewPrepareRequest", func(t *testing.T) { + require.Nil(t, New(opts...)) + }) + + opts = append(opts, WithNewPrepareRequest[util.Uint256, util.Uint160](func() payload.PrepareRequest[util.Uint256, util.Uint160] { + return nil + })) + t.Run("without NewPrepareResponse", func(t *testing.T) { + require.Nil(t, New(opts...)) + }) + + opts = append(opts, WithNewPrepareResponse[util.Uint256, util.Uint160](func() payload.PrepareResponse[util.Uint256] { + return nil + })) + t.Run("without NewChangeView", func(t *testing.T) { + require.Nil(t, New(opts...)) + }) + + opts = append(opts, WithNewChangeView[util.Uint256, util.Uint160](func() payload.ChangeView { + return nil + })) + t.Run("without NewCommit", func(t *testing.T) { + require.Nil(t, New(opts...)) + }) + + opts = append(opts, WithNewCommit[util.Uint256, util.Uint160](func() payload.Commit { + return nil + })) + t.Run("without NewRecoveryRequest", func(t *testing.T) { + require.Nil(t, New(opts...)) + }) + + opts = append(opts, WithNewRecoveryRequest[util.Uint256, util.Uint160](func() payload.RecoveryRequest { + return nil + })) + t.Run("without NewRecoveryMessage", func(t *testing.T) { + require.Nil(t, New(opts...)) + }) + + opts = append(opts, WithNewRecoveryMessage[util.Uint256, util.Uint160](func() payload.RecoveryMessage[util.Uint256, util.Uint160] { + return nil + })) t.Run("with all defaults", func(t *testing.T) { d := New(opts...) require.NotNil(t, d) @@ -467,19 +530,19 @@ func TestDBFT_Invalid(t *testing.T) { func TestDBFT_FourGoodNodesDeadlock(t *testing.T) { r0 := newTestState(0, 4) r0.currHeight = 4 - s0 := New(r0.getOptions()...) + s0 := New[util.Uint256, util.Uint160](r0.getOptions()...) s0.Start(0) r1 := r0.copyWithIndex(1) - s1 := New(r1.getOptions()...) + s1 := New[util.Uint256, util.Uint160](r1.getOptions()...) s1.Start(0) r2 := r0.copyWithIndex(2) - s2 := New(r2.getOptions()...) + s2 := New[util.Uint256, util.Uint160](r2.getOptions()...) s2.Start(0) r3 := r0.copyWithIndex(3) - s3 := New(r3.getOptions()...) + s3 := New[util.Uint256, util.Uint160](r3.getOptions()...) s3.Start(0) // Step 1. The primary (at view 0) replica 1 sends the PrepareRequest message. @@ -752,7 +815,7 @@ func (s *testState) tryRecv() Payload { return p } -func (s *testState) nextBlock() block.Block { +func (s *testState) nextBlock() block.Block[util.Uint256, util.Uint160] { if len(s.blocks) == 0 { return nil } @@ -779,37 +842,37 @@ func (s testState) nextConsensus(...crypto.PublicKey) util.Uint160 { return util.Uint160{1} } -func (s *testState) getOptions() []Option { - opts := []Option{ - WithCurrentHeight(func() uint32 { return s.currHeight }), - WithCurrentBlockHash(func() util.Uint256 { return s.currHash }), - WithGetValidators(func(...block.Transaction) []crypto.PublicKey { return s.pubs }), - WithKeyPair(s.privs[s.myIndex], s.pubs[s.myIndex]), - WithBroadcast(func(p Payload) { s.ch = append(s.ch, p) }), - WithGetTx(s.pool.Get), - WithProcessBlock(func(b block.Block) { s.blocks = append(s.blocks, b) }), - WithGetConsensusAddress(s.nextConsensus), - WithWatchOnly(func() bool { return false }), - WithGetBlock(func(util.Uint256) block.Block { return nil }), - WithTimer(timer.New()), - WithLogger(zap.NewNop()), - WithNewBlockFromContext(NewBlockFromContext), - WithSecondsPerBlock(time.Second * 10), - WithRequestTx(func(...util.Uint256) {}), - WithGetVerified(func() []block.Transaction { return []block.Transaction{} }), - - WithNewConsensusPayload(defaultNewConsensusPayload), - WithNewPrepareRequest(payload.NewPrepareRequest), - WithNewPrepareResponse(payload.NewPrepareResponse), - WithNewChangeView(payload.NewChangeView), - WithNewCommit(payload.NewCommit), - WithNewRecoveryRequest(payload.NewRecoveryRequest), - WithNewRecoveryMessage(payload.NewRecoveryMessage), +func (s *testState) getOptions() []func(*Config[util.Uint256, util.Uint160]) { + opts := []func(*Config[util.Uint256, util.Uint160]){ + WithCurrentHeight[util.Uint256, util.Uint160](func() uint32 { return s.currHeight }), + WithCurrentBlockHash[util.Uint256, util.Uint160](func() util.Uint256 { return s.currHash }), + WithGetValidators[util.Uint256, util.Uint160](func(...block.Transaction[util.Uint256]) []crypto.PublicKey { return s.pubs }), + WithKeyPair[util.Uint256, util.Uint160](s.privs[s.myIndex], s.pubs[s.myIndex]), + WithBroadcast[util.Uint256, util.Uint160](func(p Payload) { s.ch = append(s.ch, p) }), + WithGetTx[util.Uint256, util.Uint160](s.pool.Get), + WithProcessBlock[util.Uint256, util.Uint160](func(b block.Block[util.Uint256, util.Uint160]) { s.blocks = append(s.blocks, b) }), + WithGetConsensusAddress[util.Uint256, util.Uint160](s.nextConsensus), + WithWatchOnly[util.Uint256, util.Uint160](func() bool { return false }), + WithGetBlock[util.Uint256, util.Uint160](func(util.Uint256) block.Block[util.Uint256, util.Uint160] { return nil }), + WithTimer[util.Uint256, util.Uint160](timer.New()), + WithLogger[util.Uint256, util.Uint160](zap.NewNop()), + WithNewBlockFromContext[util.Uint256, util.Uint160](NewBlockFromContext), + WithSecondsPerBlock[util.Uint256, util.Uint160](time.Second * 10), + WithRequestTx[util.Uint256, util.Uint160](func(...util.Uint256) {}), + WithGetVerified[util.Uint256, util.Uint160](func() []block.Transaction[util.Uint256] { return []block.Transaction[util.Uint256]{} }), + + WithNewConsensusPayload[util.Uint256, util.Uint160](newConsensusPayload), + WithNewPrepareRequest[util.Uint256, util.Uint160](payload.NewPrepareRequest), + WithNewPrepareResponse[util.Uint256, util.Uint160](payload.NewPrepareResponse), + WithNewChangeView[util.Uint256, util.Uint160](payload.NewChangeView), + WithNewCommit[util.Uint256, util.Uint160](payload.NewCommit), + WithNewRecoveryRequest[util.Uint256, util.Uint160](payload.NewRecoveryRequest), + WithNewRecoveryMessage[util.Uint256, util.Uint160](payload.NewRecoveryMessage), } verify := s.verify if verify == nil { - verify = func(block.Block) bool { return true } + verify = func(block.Block[util.Uint256, util.Uint160]) bool { return true } } opts = append(opts, WithVerifyBlock(verify)) @@ -818,12 +881,25 @@ func (s *testState) getOptions() []Option { cfg := zap.NewDevelopmentConfig() cfg.DisableStacktrace = true logger, _ := cfg.Build() - opts = append(opts, WithLogger(logger)) + opts = append(opts, WithLogger[util.Uint256, util.Uint160](logger)) } return opts } +// newConsensusPayload is a function for creating consensus payload of specific +// type. +func newConsensusPayload(c *Context[util.Uint256, util.Uint160], t payload.MessageType, msg any) payload.ConsensusPayload[util.Uint256, util.Uint160] { + cp := payload.NewConsensusPayload() + cp.SetHeight(c.BlockIndex) + cp.SetValidatorIndex(uint16(c.MyIndex)) + cp.SetViewNumber(c.ViewNumber) + cp.SetType(t) + cp.SetPayload(msg) + + return cp +} + func getTestValidators(n int) (privs []crypto.PrivateKey, pubs []crypto.PublicKey) { for i := 0; i < n; i++ { priv, pub := crypto.Generate(rand.Reader) @@ -849,7 +925,7 @@ func (p *testPool) Add(tx testTx) { p.storage[tx.Hash()] = tx } -func (p *testPool) Get(h util.Uint256) block.Transaction { +func (p *testPool) Get(h util.Uint256) block.Transaction[util.Uint256] { if tx, ok := p.storage[h]; ok { return tx } diff --git a/helpers.go b/helpers.go index 510fa132..c0b0fcc3 100644 --- a/helpers.go +++ b/helpers.go @@ -1,39 +1,40 @@ package dbft import ( + "github.com/nspcc-dev/dbft/crypto" "github.com/nspcc-dev/dbft/payload" ) type ( // inbox is a structure storing messages from a single epoch. - inbox struct { - prepare map[uint16]payload.ConsensusPayload - chViews map[uint16]payload.ConsensusPayload - commit map[uint16]payload.ConsensusPayload + inbox[H crypto.Hash, A crypto.Address] struct { + prepare map[uint16]payload.ConsensusPayload[H, A] + chViews map[uint16]payload.ConsensusPayload[H, A] + commit map[uint16]payload.ConsensusPayload[H, A] } // cache is an auxiliary structure storing messages // from future epochs. - cache struct { - mail map[uint32]*inbox + cache[H crypto.Hash, A crypto.Address] struct { + mail map[uint32]*inbox[H, A] } ) -func newInbox() *inbox { - return &inbox{ - prepare: make(map[uint16]payload.ConsensusPayload), - chViews: make(map[uint16]payload.ConsensusPayload), - commit: make(map[uint16]payload.ConsensusPayload), +func newInbox[H crypto.Hash, A crypto.Address]() *inbox[H, A] { + return &inbox[H, A]{ + prepare: make(map[uint16]payload.ConsensusPayload[H, A]), + chViews: make(map[uint16]payload.ConsensusPayload[H, A]), + commit: make(map[uint16]payload.ConsensusPayload[H, A]), } } -func newCache() cache { - return cache{ - mail: make(map[uint32]*inbox), +func newCache[H crypto.Hash, A crypto.Address]() cache[H, A] { + return cache[H, A]{ + mail: make(map[uint32]*inbox[H, A]), } } -func (c *cache) getHeight(h uint32) *inbox { +func (c *cache[H, A]) getHeight(h uint32) *inbox[H, A] { if m, ok := c.mail[h]; ok { delete(c.mail, h) return m @@ -42,10 +43,10 @@ func (c *cache) getHeight(h uint32) *inbox { return nil } -func (c *cache) addMessage(m payload.ConsensusPayload) { +func (c *cache[H, A]) addMessage(m payload.ConsensusPayload[H, A]) { msgs, ok := c.mail[m.Height()] if !ok { - msgs = newInbox() + msgs = newInbox[H, A]() c.mail[m.Height()] = msgs } diff --git a/helpers_test.go b/helpers_test.go index add09d8f..7cac51d5 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -3,13 +3,14 @@ package dbft import ( "testing" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" "github.com/nspcc-dev/dbft/payload" ) func TestMessageCache(t *testing.T) { - c := newCache() + c := newCache[util.Uint256, util.Uint160]() p1 := payload.NewConsensusPayload() p1.SetHeight(3) diff --git a/payload/consensus_message.go b/payload/consensus_message.go index 0ebe8de7..0b9cce49 100644 --- a/payload/consensus_message.go +++ b/payload/consensus_message.go @@ -5,6 +5,8 @@ import ( "encoding/gob" "fmt" + "github.com/nspcc-dev/dbft/crypto" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/pkg/errors" ) @@ -18,7 +20,7 @@ type ( DecodeBinary(decoder *gob.Decoder) error } - consensusMessage interface { + consensusMessage[H crypto.Hash, A crypto.Address] interface { // ViewNumber returns view number when this message was originated. ViewNumber() byte // SetViewNumber sets view number. @@ -37,15 +39,15 @@ type ( // GetChangeView returns payload as if it was ChangeView. GetChangeView() ChangeView // GetPrepareRequest returns payload as if it was PrepareRequest. - GetPrepareRequest() PrepareRequest + GetPrepareRequest() PrepareRequest[H, A] // GetPrepareResponse returns payload as if it was PrepareResponse. - GetPrepareResponse() PrepareResponse + GetPrepareResponse() PrepareResponse[H] // GetCommit returns payload as if it was Commit. GetCommit() Commit // GetRecoveryRequest returns payload as if it was RecoveryRequest. GetRecoveryRequest() RecoveryRequest // GetRecoveryMessage returns payload as if it was RecoveryMessage. - GetRecoveryMessage() RecoveryMessage + GetRecoveryMessage() RecoveryMessage[H, A] } message struct { @@ -73,7 +75,7 @@ const ( RecoveryMessageType MessageType = 0x41 ) -var _ consensusMessage = (*message)(nil) +var _ consensusMessage[util.Uint256, util.Uint160] = (*message)(nil) // String implements fmt.Stringer interface. func (m MessageType) String() string { @@ -142,12 +144,18 @@ func (m *message) DecodeBinary(r *gob.Decoder) error { return m.payload.(Serializable).DecodeBinary(dec) } -func (m message) GetChangeView() ChangeView { return m.payload.(ChangeView) } -func (m message) GetPrepareRequest() PrepareRequest { return m.payload.(PrepareRequest) } -func (m message) GetPrepareResponse() PrepareResponse { return m.payload.(PrepareResponse) } +func (m message) GetChangeView() ChangeView { return m.payload.(ChangeView) } +func (m message) GetPrepareRequest() PrepareRequest[util.Uint256, util.Uint160] { + return m.payload.(PrepareRequest[util.Uint256, util.Uint160]) +} +func (m message) GetPrepareResponse() PrepareResponse[util.Uint256] { + return m.payload.(PrepareResponse[util.Uint256]) +} func (m message) GetCommit() Commit { return m.payload.(Commit) } func (m message) GetRecoveryRequest() RecoveryRequest { return m.payload.(RecoveryRequest) } -func (m message) GetRecoveryMessage() RecoveryMessage { return m.payload.(RecoveryMessage) } +func (m message) GetRecoveryMessage() RecoveryMessage[util.Uint256, util.Uint160] { + return m.payload.(RecoveryMessage[util.Uint256, util.Uint160]) +} // ViewNumber implements ConsensusMessage interface. func (m message) ViewNumber() byte { diff --git a/payload/constructors.go b/payload/constructors.go index 7156c135..dade1898 100644 --- a/payload/constructors.go +++ b/payload/constructors.go @@ -1,17 +1,21 @@ package payload +import ( + "github.com/nspcc-dev/neo-go/pkg/util" +) + // NewConsensusPayload returns minimal ConsensusPayload implementation. -func NewConsensusPayload() ConsensusPayload { +func NewConsensusPayload() ConsensusPayload[util.Uint256, util.Uint160] { return &Payload{} } // NewPrepareRequest returns minimal prepareRequest implementation. -func NewPrepareRequest() PrepareRequest { +func NewPrepareRequest() PrepareRequest[util.Uint256, util.Uint160] { return new(prepareRequest) } // NewPrepareResponse returns minimal PrepareResponse implementation. -func NewPrepareResponse() PrepareResponse { +func NewPrepareResponse() PrepareResponse[util.Uint256] { return new(prepareResponse) } @@ -31,7 +35,7 @@ func NewRecoveryRequest() RecoveryRequest { } // NewRecoveryMessage returns minimal RecoveryMessage implementation. -func NewRecoveryMessage() RecoveryMessage { +func NewRecoveryMessage() RecoveryMessage[util.Uint256, util.Uint160] { return &recoveryMessage{ preparationPayloads: make([]preparationCompact, 0), commitPayloads: make([]commitCompact, 0), diff --git a/payload/message.go b/payload/message.go index d8ed52db..98ff90dd 100644 --- a/payload/message.go +++ b/payload/message.go @@ -11,8 +11,8 @@ import ( type ( // ConsensusPayload is a generic payload type which is exchanged // between the nodes. - ConsensusPayload interface { - consensusMessage + ConsensusPayload[H crypto.Hash, A crypto.Address] interface { + consensusMessage[H, A] // ValidatorIndex returns index of validator from which // payload was originated from. @@ -25,7 +25,7 @@ type ( SetHeight(h uint32) // Hash returns 32-byte checksum of the payload. - Hash() util.Uint256 + Hash() H } // Payload represents minimal payload containing all necessary fields. @@ -51,7 +51,7 @@ type ( } ) -var _ ConsensusPayload = (*Payload)(nil) +var _ ConsensusPayload[util.Uint256, util.Uint160] = (*Payload)(nil) // EncodeBinary implements Serializable interface. func (p Payload) EncodeBinary(w *gob.Encoder) error { diff --git a/payload/message_test.go b/payload/message_test.go index 82d27746..8d1777ac 100644 --- a/payload/message_test.go +++ b/payload/message_test.go @@ -122,11 +122,11 @@ func TestRecoveryMessage_NoPayloads(t *testing.T) { rec := m.GetRecoveryMessage() require.NotNil(t, rec) - var p ConsensusPayload + var p ConsensusPayload[util.Uint256, util.Uint160] require.NotPanics(t, func() { p = rec.GetPrepareRequest(p, validators, 0) }) require.Nil(t, p) - var ps []ConsensusPayload + var ps []ConsensusPayload[util.Uint256, util.Uint160] require.NotPanics(t, func() { ps = rec.GetPrepareResponses(p, validators) }) require.Len(t, ps, 0) diff --git a/payload/prepare_request.go b/payload/prepare_request.go index 8d6b31eb..8e9612d6 100644 --- a/payload/prepare_request.go +++ b/payload/prepare_request.go @@ -3,11 +3,12 @@ package payload import ( "encoding/gob" + "github.com/nspcc-dev/dbft/crypto" "github.com/nspcc-dev/neo-go/pkg/util" ) // PrepareRequest represents dBFT PrepareRequest message. -type PrepareRequest interface { +type PrepareRequest[H crypto.Hash, A crypto.Address] interface { // Timestamp returns this message's timestamp. Timestamp() uint64 // SetTimestamp sets timestamp of this message. @@ -19,15 +20,15 @@ type PrepareRequest interface { SetNonce(nonce uint64) // TransactionHashes returns hashes of all transaction in a proposed block. - TransactionHashes() []util.Uint256 + TransactionHashes() []H // SetTransactionHashes sets transaction's hashes. - SetTransactionHashes(hs []util.Uint256) + SetTransactionHashes(hs []H) // NextConsensus returns hash which is based on which validators will // try to agree on a block in the current epoch. - NextConsensus() util.Uint160 + NextConsensus() A // SetNextConsensus sets next consensus field. - SetNextConsensus(nc util.Uint160) + SetNextConsensus(nc A) } type ( @@ -46,7 +47,7 @@ type ( } ) -var _ PrepareRequest = (*prepareRequest)(nil) +var _ PrepareRequest[util.Uint256, util.Uint160] = (*prepareRequest)(nil) // EncodeBinary implements Serializable interface. func (p prepareRequest) EncodeBinary(w *gob.Encoder) error { diff --git a/payload/prepare_response.go b/payload/prepare_response.go index 0268d7df..7d0f6ba3 100644 --- a/payload/prepare_response.go +++ b/payload/prepare_response.go @@ -3,16 +3,17 @@ package payload import ( "encoding/gob" + "github.com/nspcc-dev/dbft/crypto" "github.com/nspcc-dev/neo-go/pkg/util" ) // PrepareResponse represents dBFT PrepareResponse message. -type PrepareResponse interface { +type PrepareResponse[H crypto.Hash] interface { // PreparationHash returns the hash of PrepareRequest payload // for this epoch. - PreparationHash() util.Uint256 + PreparationHash() H // SetPreparationHash sets preparations hash. - SetPreparationHash(h util.Uint256) + SetPreparationHash(h H) } type ( @@ -25,7 +26,7 @@ type ( } ) -var _ PrepareResponse = (*prepareResponse)(nil) +var _ PrepareResponse[util.Uint256] = (*prepareResponse)(nil) // EncodeBinary implements Serializable interface. func (p prepareResponse) EncodeBinary(w *gob.Encoder) error { diff --git a/payload/recovery_message.go b/payload/recovery_message.go index 117473e8..d64bbd8b 100644 --- a/payload/recovery_message.go +++ b/payload/recovery_message.go @@ -10,23 +10,23 @@ import ( type ( // RecoveryMessage represents dBFT Recovery message. - RecoveryMessage interface { + RecoveryMessage[H crypto.Hash, A crypto.Address] interface { // AddPayload adds payload from this epoch to be recovered. - AddPayload(p ConsensusPayload) + AddPayload(p ConsensusPayload[H, A]) // GetPrepareRequest returns PrepareRequest to be processed. - GetPrepareRequest(p ConsensusPayload, validators []crypto.PublicKey, primary uint16) ConsensusPayload + GetPrepareRequest(p ConsensusPayload[H, A], validators []crypto.PublicKey, primary uint16) ConsensusPayload[H, A] // GetPrepareResponses returns a slice of PrepareResponse in any order. - GetPrepareResponses(p ConsensusPayload, validators []crypto.PublicKey) []ConsensusPayload + GetPrepareResponses(p ConsensusPayload[H, A], validators []crypto.PublicKey) []ConsensusPayload[H, A] // GetChangeViews returns a slice of ChangeView in any order. - GetChangeViews(p ConsensusPayload, validators []crypto.PublicKey) []ConsensusPayload + GetChangeViews(p ConsensusPayload[H, A], validators []crypto.PublicKey) []ConsensusPayload[H, A] // GetCommits returns a slice of Commit in any order. - GetCommits(p ConsensusPayload, validators []crypto.PublicKey) []ConsensusPayload + GetCommits(p ConsensusPayload[H, A], validators []crypto.PublicKey) []ConsensusPayload[H, A] // PreparationHash returns has of PrepareRequest payload for this epoch. // It can be useful in case only PrepareResponse payloads were received. - PreparationHash() *util.Uint256 + PreparationHash() *H // SetPreparationHash sets preparation hash. - SetPreparationHash(h *util.Uint256) + SetPreparationHash(h *H) } recoveryMessage struct { @@ -34,7 +34,7 @@ type ( preparationPayloads []preparationCompact commitPayloads []commitCompact changeViewPayloads []changeViewCompact - prepareRequest PrepareRequest + prepareRequest PrepareRequest[util.Uint256, util.Uint160] } // recoveryMessageAux is an auxiliary structure for recoveryMessage encoding. recoveryMessageAux struct { @@ -44,7 +44,7 @@ type ( } ) -var _ RecoveryMessage = (*recoveryMessage)(nil) +var _ RecoveryMessage[util.Uint256, util.Uint160] = (*recoveryMessage)(nil) // PreparationHash implements RecoveryMessage interface. func (m *recoveryMessage) PreparationHash() *util.Uint256 { @@ -57,7 +57,7 @@ func (m *recoveryMessage) SetPreparationHash(h *util.Uint256) { } // AddPayload implements RecoveryMessage interface. -func (m *recoveryMessage) AddPayload(p ConsensusPayload) { +func (m *recoveryMessage) AddPayload(p ConsensusPayload[util.Uint256, util.Uint160]) { switch p.Type() { case PrepareRequestType: m.prepareRequest = p.GetPrepareRequest() @@ -83,7 +83,7 @@ func (m *recoveryMessage) AddPayload(p ConsensusPayload) { } } -func fromPayload(t MessageType, recovery ConsensusPayload, p Serializable) *Payload { +func fromPayload(t MessageType, recovery ConsensusPayload[util.Uint256, util.Uint160], p Serializable) *Payload { return &Payload{ message: message{ cmType: t, @@ -95,7 +95,7 @@ func fromPayload(t MessageType, recovery ConsensusPayload, p Serializable) *Payl } // GetPrepareRequest implements RecoveryMessage interface. -func (m *recoveryMessage) GetPrepareRequest(p ConsensusPayload, _ []crypto.PublicKey, ind uint16) ConsensusPayload { +func (m *recoveryMessage) GetPrepareRequest(p ConsensusPayload[util.Uint256, util.Uint160], _ []crypto.PublicKey, ind uint16) ConsensusPayload[util.Uint256, util.Uint160] { if m.prepareRequest == nil { return nil } @@ -113,12 +113,12 @@ func (m *recoveryMessage) GetPrepareRequest(p ConsensusPayload, _ []crypto.Publi } // GetPrepareResponses implements RecoveryMessage interface. -func (m *recoveryMessage) GetPrepareResponses(p ConsensusPayload, _ []crypto.PublicKey) []ConsensusPayload { +func (m *recoveryMessage) GetPrepareResponses(p ConsensusPayload[util.Uint256, util.Uint160], _ []crypto.PublicKey) []ConsensusPayload[util.Uint256, util.Uint160] { if m.preparationHash == nil { return nil } - payloads := make([]ConsensusPayload, len(m.preparationPayloads)) + payloads := make([]ConsensusPayload[util.Uint256, util.Uint160], len(m.preparationPayloads)) for i, resp := range m.preparationPayloads { payloads[i] = fromPayload(PrepareResponseType, p, &prepareResponse{ @@ -131,8 +131,8 @@ func (m *recoveryMessage) GetPrepareResponses(p ConsensusPayload, _ []crypto.Pub } // GetChangeViews implements RecoveryMessage interface. -func (m *recoveryMessage) GetChangeViews(p ConsensusPayload, _ []crypto.PublicKey) []ConsensusPayload { - payloads := make([]ConsensusPayload, len(m.changeViewPayloads)) +func (m *recoveryMessage) GetChangeViews(p ConsensusPayload[util.Uint256, util.Uint160], _ []crypto.PublicKey) []ConsensusPayload[util.Uint256, util.Uint160] { + payloads := make([]ConsensusPayload[util.Uint256, util.Uint160], len(m.changeViewPayloads)) for i, cv := range m.changeViewPayloads { payloads[i] = fromPayload(ChangeViewType, p, &changeView{ @@ -146,8 +146,8 @@ func (m *recoveryMessage) GetChangeViews(p ConsensusPayload, _ []crypto.PublicKe } // GetCommits implements RecoveryMessage interface. -func (m *recoveryMessage) GetCommits(p ConsensusPayload, _ []crypto.PublicKey) []ConsensusPayload { - payloads := make([]ConsensusPayload, len(m.commitPayloads)) +func (m *recoveryMessage) GetCommits(p ConsensusPayload[util.Uint256, util.Uint160], _ []crypto.PublicKey) []ConsensusPayload[util.Uint256, util.Uint160] { + payloads := make([]ConsensusPayload[util.Uint256, util.Uint160], len(m.commitPayloads)) for i, c := range m.commitPayloads { payloads[i] = fromPayload(CommitType, p, &commit{signature: c.Signature}) diff --git a/send.go b/send.go index 2c86fd3c..64f71bca 100644 --- a/send.go +++ b/send.go @@ -5,7 +5,7 @@ import ( "go.uber.org/zap" ) -func (d *DBFT) broadcast(msg payload.ConsensusPayload) { +func (d *DBFT[H, A]) broadcast(msg payload.ConsensusPayload[H, A]) { d.Logger.Debug("broadcasting message", zap.Stringer("type", msg.Type()), zap.Uint32("height", d.BlockIndex), @@ -15,7 +15,7 @@ func (d *DBFT) broadcast(msg payload.ConsensusPayload) { d.Broadcast(msg) } -func (c *Context) makePrepareRequest() payload.ConsensusPayload { +func (c *Context[H, A]) makePrepareRequest() payload.ConsensusPayload[H, A] { c.Fill() req := c.Config.NewPrepareRequest() @@ -27,7 +27,7 @@ func (c *Context) makePrepareRequest() payload.ConsensusPayload { return c.Config.NewConsensusPayload(c, payload.PrepareRequestType, req) } -func (d *DBFT) sendPrepareRequest() { +func (d *DBFT[H, A]) sendPrepareRequest() { msg := d.makePrepareRequest() d.PreparationPayloads[d.MyIndex] = msg d.broadcast(msg) @@ -42,7 +42,7 @@ func (d *DBFT) sendPrepareRequest() { d.checkPrepare() } -func (c *Context) makeChangeView(ts uint64, reason payload.ChangeViewReason) payload.ConsensusPayload { +func (c *Context[H, A]) makeChangeView(ts uint64, reason payload.ChangeViewReason) payload.ConsensusPayload[H, A] { cv := c.Config.NewChangeView() cv.SetNewViewNumber(c.ViewNumber + 1) cv.SetTimestamp(ts) @@ -54,7 +54,7 @@ func (c *Context) makeChangeView(ts uint64, reason payload.ChangeViewReason) pay return msg } -func (d *DBFT) sendChangeView(reason payload.ChangeViewReason) { +func (d *DBFT[H, A]) sendChangeView(reason payload.ChangeViewReason) { if d.Context.WatchOnly() { return } @@ -91,7 +91,7 @@ func (d *DBFT) sendChangeView(reason payload.ChangeViewReason) { d.checkChangeView(newView) } -func (c *Context) makePrepareResponse() payload.ConsensusPayload { +func (c *Context[H, A]) makePrepareResponse() payload.ConsensusPayload[H, A] { resp := c.Config.NewPrepareResponse() resp.SetPreparationHash(c.PreparationPayloads[c.PrimaryIndex].Hash()) @@ -101,14 +101,14 @@ func (c *Context) makePrepareResponse() payload.ConsensusPayload { return msg } -func (d *DBFT) sendPrepareResponse() { +func (d *DBFT[H, A]) sendPrepareResponse() { msg := d.makePrepareResponse() d.Logger.Info("sending PrepareResponse", zap.Uint32("height", d.BlockIndex), zap.Uint("view", uint(d.ViewNumber))) d.StopTxFlow() d.broadcast(msg) } -func (c *Context) makeCommit() payload.ConsensusPayload { +func (c *Context[H, A]) makeCommit() payload.ConsensusPayload[H, A] { if msg := c.CommitPayloads[c.MyIndex]; msg != nil { return msg } @@ -128,14 +128,14 @@ func (c *Context) makeCommit() payload.ConsensusPayload { return nil } -func (d *DBFT) sendCommit() { +func (d *DBFT[H, A]) sendCommit() { msg := d.makeCommit() d.CommitPayloads[d.MyIndex] = msg d.Logger.Info("sending Commit", zap.Uint32("height", d.BlockIndex), zap.Uint("view", uint(d.ViewNumber))) d.broadcast(msg) } -func (d *DBFT) sendRecoveryRequest() { +func (d *DBFT[H, A]) sendRecoveryRequest() { // If we're here, something is wrong, we either missing some messages or // transactions or both, so re-request missing transactions here too. if d.RequestSentOrReceived() && !d.hasAllTransactions() { @@ -146,7 +146,7 @@ func (d *DBFT) sendRecoveryRequest() { d.broadcast(d.Config.NewConsensusPayload(&d.Context, payload.RecoveryRequestType, req)) } -func (c *Context) makeRecoveryMessage() payload.ConsensusPayload { +func (c *Context[H, A]) makeRecoveryMessage() payload.ConsensusPayload[H, A] { recovery := c.Config.NewRecoveryMessage() for _, p := range c.PreparationPayloads { @@ -176,19 +176,6 @@ func (c *Context) makeRecoveryMessage() payload.ConsensusPayload { return c.Config.NewConsensusPayload(c, payload.RecoveryMessageType, recovery) } -func (d *DBFT) sendRecoveryMessage() { +func (d *DBFT[H, A]) sendRecoveryMessage() { d.broadcast(d.makeRecoveryMessage()) } - -// defaultNewConsensusPayload is default function for creating -// consensus payload of specific type. -func defaultNewConsensusPayload(c *Context, t payload.MessageType, msg any) payload.ConsensusPayload { - cp := payload.NewConsensusPayload() - cp.SetHeight(c.BlockIndex) - cp.SetValidatorIndex(uint16(c.MyIndex)) - cp.SetViewNumber(c.ViewNumber) - cp.SetType(t) - cp.SetPayload(msg) - - return cp -} diff --git a/simulation/main.go b/simulation/main.go index faae3a06..18b21a58 100644 --- a/simulation/main.go +++ b/simulation/main.go @@ -28,8 +28,8 @@ import ( type ( simNode struct { id int - d *dbft.DBFT - messages chan payload.ConsensusPayload + d *dbft.DBFT[util.Uint256, util.Uint160] + messages chan payload.ConsensusPayload[util.Uint256, util.Uint160] key crypto.PrivateKey pub crypto.PublicKey pool *memPool @@ -111,11 +111,24 @@ func initNodes(nodes []*simNode, log *zap.Logger) { } } +// defaultNewConsensusPayload is default function for creating +// consensus payload of specific type. +func defaultNewConsensusPayload(c *dbft.Context[util.Uint256, util.Uint160], t payload.MessageType, msg any) payload.ConsensusPayload[util.Uint256, util.Uint160] { + cp := payload.NewConsensusPayload() + cp.SetHeight(c.BlockIndex) + cp.SetValidatorIndex(uint16(c.MyIndex)) + cp.SetViewNumber(c.ViewNumber) + cp.SetType(t) + cp.SetPayload(msg) + + return cp +} + func initSimNode(nodes []*simNode, i int, log *zap.Logger) error { key, pub := crypto.Generate(rand.Reader) nodes[i] = &simNode{ id: i, - messages: make(chan payload.ConsensusPayload, defaultChanSize), + messages: make(chan payload.ConsensusPayload[util.Uint256, util.Uint160], defaultChanSize), key: key, pub: pub, pool: newMemoryPool(), @@ -123,19 +136,29 @@ func initSimNode(nodes []*simNode, i int, log *zap.Logger) error { cluster: nodes, } - nodes[i].d = dbft.New( - dbft.WithLogger(nodes[i].log), - dbft.WithSecondsPerBlock(time.Second*5), - dbft.WithKeyPair(key, pub), - dbft.WithGetTx(nodes[i].pool.Get), - dbft.WithGetVerified(nodes[i].pool.GetVerified), - dbft.WithBroadcast(nodes[i].Broadcast), - dbft.WithProcessBlock(nodes[i].ProcessBlock), - dbft.WithCurrentHeight(nodes[i].CurrentHeight), - dbft.WithCurrentBlockHash(nodes[i].CurrentBlockHash), - dbft.WithGetValidators(nodes[i].GetValidators), - dbft.WithVerifyPrepareRequest(nodes[i].VerifyPayload), - dbft.WithVerifyPrepareResponse(nodes[i].VerifyPayload), + nodes[i].d = dbft.New[util.Uint256, util.Uint160]( + dbft.WithLogger[util.Uint256, util.Uint160](nodes[i].log), + dbft.WithSecondsPerBlock[util.Uint256, util.Uint160](time.Second*5), + dbft.WithKeyPair[util.Uint256, util.Uint160](key, pub), + dbft.WithGetTx[util.Uint256, util.Uint160](nodes[i].pool.Get), + dbft.WithGetVerified[util.Uint256, util.Uint160](nodes[i].pool.GetVerified), + dbft.WithBroadcast[util.Uint256, util.Uint160](nodes[i].Broadcast), + dbft.WithProcessBlock[util.Uint256, util.Uint160](nodes[i].ProcessBlock), + dbft.WithCurrentHeight[util.Uint256, util.Uint160](nodes[i].CurrentHeight), + dbft.WithCurrentBlockHash[util.Uint256, util.Uint160](nodes[i].CurrentBlockHash), + dbft.WithGetValidators[util.Uint256, util.Uint160](nodes[i].GetValidators), + dbft.WithVerifyPrepareRequest[util.Uint256, util.Uint160](nodes[i].VerifyPayload), + dbft.WithVerifyPrepareResponse[util.Uint256, util.Uint160](nodes[i].VerifyPayload), + + dbft.WithNewBlockFromContext[util.Uint256, util.Uint160](dbft.NewBlockFromContext), + dbft.WithGetConsensusAddress[util.Uint256, util.Uint160](func(...crypto.PublicKey) util.Uint160 { return util.Uint160{} }), + dbft.WithNewConsensusPayload[util.Uint256, util.Uint160](defaultNewConsensusPayload), + dbft.WithNewPrepareRequest[util.Uint256, util.Uint160](payload.NewPrepareRequest), + dbft.WithNewPrepareResponse[util.Uint256, util.Uint160](payload.NewPrepareResponse), + dbft.WithNewChangeView[util.Uint256, util.Uint160](payload.NewChangeView), + dbft.WithNewCommit[util.Uint256, util.Uint160](payload.NewCommit), + dbft.WithNewRecoveryMessage[util.Uint256, util.Uint160](payload.NewRecoveryMessage), + dbft.WithNewRecoveryRequest[util.Uint256, util.Uint160](payload.NewRecoveryRequest), ) if nodes[i].d == nil { @@ -168,7 +191,7 @@ func sortValidators(pubs []crypto.PublicKey) { }) } -func (n *simNode) Broadcast(m payload.ConsensusPayload) { +func (n *simNode) Broadcast(m payload.ConsensusPayload[util.Uint256, util.Uint160]) { for i, node := range n.cluster { if i != n.id { select { @@ -184,11 +207,11 @@ func (n *simNode) CurrentHeight() uint32 { return n.height } func (n *simNode) CurrentBlockHash() util.Uint256 { return n.lastHash } // GetValidators always returns the same list of validators. -func (n *simNode) GetValidators(...block.Transaction) []crypto.PublicKey { +func (n *simNode) GetValidators(...block.Transaction[util.Uint256]) []crypto.PublicKey { return n.validators } -func (n *simNode) ProcessBlock(b block.Block) { +func (n *simNode) ProcessBlock(b block.Block[util.Uint256, util.Uint160]) { n.d.Logger.Debug("received block", zap.Uint32("height", b.Index())) for _, tx := range b.Transactions() { @@ -200,7 +223,7 @@ func (n *simNode) ProcessBlock(b block.Block) { } // VerifyPrepareRequest verifies that payload was received from a good validator. -func (n *simNode) VerifyPayload(p payload.ConsensusPayload) error { +func (n *simNode) VerifyPayload(p payload.ConsensusPayload[util.Uint256, util.Uint160]) error { if *blocked != -1 && p.ValidatorIndex() == uint16(*blocked) { return fmt.Errorf("message from blocked validator: %d", *blocked) } @@ -220,7 +243,7 @@ func (n *simNode) addTx(count int) { type tx64 uint64 -var _ block.Transaction = (*tx64)(nil) +var _ block.Transaction[util.Uint256] = (*tx64)(nil) func (t *tx64) Hash() (h util.Uint256) { binary.LittleEndian.PutUint64(h[:], uint64(*t)) @@ -252,17 +275,17 @@ func (t *tx64) UnmarshalBinary(data []byte) error { type memPool struct { mtx *sync.RWMutex - store map[util.Uint256]block.Transaction + store map[util.Uint256]block.Transaction[util.Uint256] } func newMemoryPool() *memPool { return &memPool{ mtx: new(sync.RWMutex), - store: make(map[util.Uint256]block.Transaction), + store: make(map[util.Uint256]block.Transaction[util.Uint256]), } } -func (p *memPool) Add(tx block.Transaction) { +func (p *memPool) Add(tx block.Transaction[util.Uint256]) { p.mtx.Lock() h := tx.Hash() @@ -273,7 +296,7 @@ func (p *memPool) Add(tx block.Transaction) { p.mtx.Unlock() } -func (p *memPool) Get(h util.Uint256) (tx block.Transaction) { +func (p *memPool) Get(h util.Uint256) (tx block.Transaction[util.Uint256]) { p.mtx.RLock() tx = p.store[h] p.mtx.RUnlock() @@ -287,13 +310,13 @@ func (p *memPool) Delete(h util.Uint256) { p.mtx.Unlock() } -func (p *memPool) GetVerified() (txx []block.Transaction) { +func (p *memPool) GetVerified() (txx []block.Transaction[util.Uint256]) { n := *txPerBlock if n == 0 { return } - txx = make([]block.Transaction, 0, n) + txx = make([]block.Transaction[util.Uint256], 0, n) for _, tx := range p.store { txx = append(txx, tx)