diff --git a/block.go b/block.go index 3c51e846..4705673f 100644 --- a/block.go +++ b/block.go @@ -20,6 +20,11 @@ type Block[H Hash] interface { // Transactions returns block's transaction list. Transactions() []Transaction[H] - // SetTransactions sets block's transaction list. + // SetTransactions sets block's transaction list. For anti-MEV extension + // transactions provided via this call are taken directly from PreBlock level + // and thus, may be out-of-date. Thus, with anti-MEV extension enabled it's + // suggested to use this method as a Block finalizer since it will be called + // right before the block approval. Do not rely on this with anti-MEV extension + // disabled. SetTransactions([]Transaction[H]) } diff --git a/check.go b/check.go index 4b4b9e15..6e43404c 100644 --- a/check.go +++ b/check.go @@ -30,9 +30,57 @@ func (d *DBFT[H]) checkPrepare() { zap.Int("M", d.M())) if hasRequest && count >= d.M() { + if d.isAntiMEVExtensionEnabled() { + d.sendPreCommit() + d.changeTimer(d.SecondsPerBlock) + d.checkPreCommit() + } else { + d.sendCommit() + d.changeTimer(d.SecondsPerBlock) + d.checkCommit() + } + } +} + +func (d *DBFT[H]) checkPreCommit() { + if !d.hasAllTransactions() { + d.Logger.Debug("check preCommit: some transactions are missing", zap.Any("hashes", d.MissingTransactions)) + return + } + + count := 0 + for _, msg := range d.PreCommitPayloads { + if msg != nil && msg.ViewNumber() == d.ViewNumber { + count++ + } + } + + if count < d.M() { + d.Logger.Debug("not enough PreCommits to create PreBlock", zap.Int("count", count)) + return + } + + d.preBlock = d.CreatePreBlock() + + d.Logger.Info("processing PreBlock", + zap.Uint32("height", d.BlockIndex), + zap.Uint("view", uint(d.ViewNumber)), + zap.Int("tx_count", len(d.preBlock.Transactions()))) + + if !d.preBlockProcessed { + d.preBlockProcessed = true + d.ProcessPreBlock(d.preBlock) + } + + // Require PreCommit sent by self for reliability. This condition may be removed + // in the future. + if d.PreCommitSent() { + d.verifyCommitPayloadsAgainstHeader() d.sendCommit() d.changeTimer(d.SecondsPerBlock) d.checkCommit() + } else { + d.Logger.Debug("can't send commit since self preCommit not yet sent") } } diff --git a/commit.go b/commit.go index 40a44ee7..1b140048 100644 --- a/commit.go +++ b/commit.go @@ -3,6 +3,7 @@ package dbft // Commit is an interface for dBFT Commit message. type Commit interface { // Signature returns commit's signature field - // which is a block signature for the current epoch. + // which is a final block signature for the current epoch for both dBFT 2.0 and + // for anti-MEV extension. Signature() []byte } diff --git a/config.go b/config.go index c34ac35b..973b636c 100644 --- a/config.go +++ b/config.go @@ -20,9 +20,14 @@ type Config[H Hash] struct { // if current time is less than that of previous context. // By default use millisecond precision. TimestampIncrement uint64 + // AntiMEVExtensionEnablingHeight denotes the height starting from which dBFT + // Anti-MEV extensions should be enabled. -1 means no extension is enabled. + AntiMEVExtensionEnablingHeight int64 // GetKeyPair returns an index of the node in the list of validators // together with it's key pair. GetKeyPair func([]PublicKey) (int, PrivateKey, PublicKey) + // NewPreBlockFromContext should allocate, fill from Context and return new block.PreBlock. + NewPreBlockFromContext func(ctx *Context[H]) PreBlock[H] // NewBlockFromContext should allocate, fill from Context and return new block.Block. NewBlockFromContext func(ctx *Context[H]) Block[H] // RequestTx is a callback which is called when transaction contained @@ -36,10 +41,14 @@ type Config[H Hash] struct { // GetVerified returns a slice of verified transactions // to be proposed in a new block. GetVerified func() []Transaction[H] + // VerifyPreBlock verifies if preBlock is valid. + VerifyPreBlock func(b PreBlock[H]) bool // VerifyBlock verifies if block is valid. VerifyBlock func(b Block[H]) bool // Broadcast should broadcast payload m to the consensus nodes. Broadcast func(m ConsensusPayload[H]) + // ProcessBlock is called every time new preBlock is accepted. + ProcessPreBlock func(b PreBlock[H]) // ProcessBlock is called every time new block is accepted. ProcessBlock func(b Block[H]) // GetBlock should return block with hash. @@ -63,6 +72,8 @@ type Config[H Hash] struct { NewPrepareResponse func(preparationHash H) PrepareResponse[H] // NewChangeView is a constructor for payload.ChangeView. NewChangeView func(newViewNumber byte, reason ChangeViewReason, timestamp uint64) ChangeView + // NewPreCommit is a constructor for payload.PreCommit. + NewPreCommit func(data []byte) PreCommit // NewCommit is a constructor for payload.Commit. NewCommit func(signature []byte) Commit // NewRecoveryRequest is a constructor for payload.RecoveryRequest. @@ -73,6 +84,10 @@ type Config[H Hash] struct { VerifyPrepareRequest func(p ConsensusPayload[H]) error // VerifyPrepareResponse performs external PrepareResponse verification and returns nil if it's successful. VerifyPrepareResponse func(p ConsensusPayload[H]) error + // VerifyPreCommit performs external PreCommit verification and returns nil if it's successful. + // Note that PreBlock-dependent PreCommit verification should be performed inside PreBlock.Verify + // callback. + VerifyPreCommit func(p ConsensusPayload[H]) error } const defaultSecondsPerBlock = time.Second * 15 @@ -101,6 +116,10 @@ func defaultConfig[H Hash]() *Config[H] { VerifyPrepareRequest: func(ConsensusPayload[H]) error { return nil }, VerifyPrepareResponse: func(ConsensusPayload[H]) error { return nil }, + + AntiMEVExtensionEnablingHeight: -1, + VerifyPreBlock: func(PreBlock[H]) bool { return true }, + VerifyPreCommit: func(ConsensusPayload[H]) error { return nil }, } } @@ -131,6 +150,20 @@ func checkConfig[H Hash](cfg *Config[H]) error { return errors.New("NewRecoveryRequest is nil") } else if cfg.NewRecoveryMessage == nil { return errors.New("NewRecoveryMessage is nil") + } else if cfg.AntiMEVExtensionEnablingHeight >= 0 { + if cfg.NewPreBlockFromContext == nil { + return errors.New("NewPreBlockFromContext is nil") + } else if cfg.ProcessPreBlock == nil { + return errors.New("ProcessPreBlock is nil") + } else if cfg.NewPreCommit == nil { + return errors.New("NewPreCommit is nil") + } + } else if cfg.NewPreBlockFromContext != nil { + return errors.New("NewPreBlockFromContext is set, but AntiMEVExtensionEnablingHeight is not specified") + } else if cfg.ProcessPreBlock != nil { + return errors.New("ProcessPreBlock is set, but AntiMEVExtensionEnablingHeight is not specified") + } else if cfg.NewPreCommit != nil { + return errors.New("NewPreCommit is set, but AntiMEVExtensionEnablingHeight is not specified") } return nil @@ -164,6 +197,13 @@ func WithSecondsPerBlock[H Hash](d time.Duration) func(config *Config[H]) { } } +// WithAntiMEVExtensionEnablingHeight sets AntiMEVExtensionEnablingHeight. +func WithAntiMEVExtensionEnablingHeight[H Hash](h int64) func(config *Config[H]) { + return func(cfg *Config[H]) { + cfg.AntiMEVExtensionEnablingHeight = h + } +} + // WithTimestampIncrement sets TimestampIncrement. func WithTimestampIncrement[H Hash](u uint64) func(config *Config[H]) { return func(cfg *Config[H]) { @@ -171,6 +211,13 @@ func WithTimestampIncrement[H Hash](u uint64) func(config *Config[H]) { } } +// WithNewPreBlockFromContext sets NewPreBlockFromContext. +func WithNewPreBlockFromContext[H Hash](f func(ctx *Context[H]) PreBlock[H]) func(config *Config[H]) { + return func(cfg *Config[H]) { + cfg.NewPreBlockFromContext = f + } +} + // WithNewBlockFromContext sets NewBlockFromContext. func WithNewBlockFromContext[H Hash](f func(ctx *Context[H]) Block[H]) func(config *Config[H]) { return func(cfg *Config[H]) { @@ -206,6 +253,13 @@ func WithGetVerified[H Hash](f func() []Transaction[H]) func(config *Config[H]) } } +// WithVerifyPreBlock sets VerifyPreBlock. +func WithVerifyPreBlock[H Hash](f func(b PreBlock[H]) bool) func(config *Config[H]) { + return func(cfg *Config[H]) { + cfg.VerifyPreBlock = f + } +} + // WithVerifyBlock sets VerifyBlock. func WithVerifyBlock[H Hash](f func(b Block[H]) bool) func(config *Config[H]) { return func(cfg *Config[H]) { @@ -227,6 +281,13 @@ func WithProcessBlock[H Hash](f func(b Block[H])) func(config *Config[H]) { } } +// WithProcessPreBlock sets ProcessPreBlock. +func WithProcessPreBlock[H Hash](f func(b PreBlock[H])) func(config *Config[H]) { + return func(cfg *Config[H]) { + cfg.ProcessPreBlock = f + } +} + // WithGetBlock sets GetBlock. func WithGetBlock[H Hash](f func(h H) Block[H]) func(config *Config[H]) { return func(cfg *Config[H]) { @@ -297,6 +358,13 @@ func WithNewCommit[H Hash](f func(signature []byte) Commit) func(config *Config[ } } +// WithNewPreCommit sets NewPreCommit. +func WithNewPreCommit[H Hash](f func(signature []byte) PreCommit) func(config *Config[H]) { + return func(cfg *Config[H]) { + cfg.NewPreCommit = f + } +} + // WithNewRecoveryRequest sets NewRecoveryRequest. func WithNewRecoveryRequest[H Hash](f func(ts uint64) RecoveryRequest) func(config *Config[H]) { return func(cfg *Config[H]) { @@ -324,3 +392,10 @@ func WithVerifyPrepareResponse[H Hash](f func(prepareResp ConsensusPayload[H]) e cfg.VerifyPrepareResponse = f } } + +// WithVerifyPreCommit sets VerifyPreCommit. +func WithVerifyPreCommit[H Hash](f func(preCommit ConsensusPayload[H]) error) func(config *Config[H]) { + return func(cfg *Config[H]) { + cfg.VerifyPreCommit = f + } +} diff --git a/consensus_message.go b/consensus_message.go index d8859dea..01c9d2b7 100644 --- a/consensus_message.go +++ b/consensus_message.go @@ -15,6 +15,8 @@ type ConsensusMessage[H Hash] interface { GetPrepareRequest() PrepareRequest[H] // GetPrepareResponse returns payload as if it was PrepareResponse. GetPrepareResponse() PrepareResponse[H] + // GetPreCommit returns payload as if it was PreCommit. + GetPreCommit() PreCommit // GetCommit returns payload as if it was Commit. GetCommit() Commit // GetRecoveryRequest returns payload as if it was RecoveryRequest. diff --git a/consensus_message_type.go b/consensus_message_type.go index faed09b7..f7a4f534 100644 --- a/consensus_message_type.go +++ b/consensus_message_type.go @@ -5,11 +5,12 @@ import "fmt" // MessageType is a type for dBFT consensus messages. type MessageType byte -// 6 following constants enumerate all possible type of consensus message. +// 7 following constants enumerate all possible type of consensus message. const ( ChangeViewType MessageType = 0x00 PrepareRequestType MessageType = 0x20 PrepareResponseType MessageType = 0x21 + PreCommitType MessageType = 0x31 CommitType MessageType = 0x30 RecoveryRequestType MessageType = 0x40 RecoveryMessageType MessageType = 0x41 @@ -26,6 +27,8 @@ func (m MessageType) String() string { return "PrepareResponse" case CommitType: return "Commit" + case PreCommitType: + return "PreCommit" case RecoveryRequestType: return "RecoveryRequest" case RecoveryMessageType: diff --git a/context.go b/context.go index 4a7ce73d..27ee670f 100644 --- a/context.go +++ b/context.go @@ -23,13 +23,19 @@ type Context[H Hash] struct { // Pub is node's public key. Pub PublicKey - block Block[H] - header Block[H] + preBlock PreBlock[H] + preHeader PreBlock[H] + block Block[H] + header Block[H] // blockProcessed denotes whether Config.ProcessBlock callback was called for the current // height. If so, then no second call must happen. After new block is received by the user, // dBFT stops any new transaction or messages processing as far as timeouts handling till // the next call to Reset. blockProcessed bool + // preBlockProcessed is true when Config.ProcessPreBlock callback was + // invoked for the current height. This happens once and dbft continues + // to march towards proper commit after that. + preBlockProcessed bool // BlockIndex is current block index. BlockIndex uint32 @@ -58,6 +64,14 @@ type Context[H Hash] struct { // PreparationPayloads stores consensus Prepare* payloads for the current epoch. PreparationPayloads []ConsensusPayload[H] + // PreCommitPayloads stores consensus PreCommit payloads sent through all epochs + // as a part of anti-MEV dBFT extension. It is assumed that valid PreCommit + // payloads can only be sent once by a single node per the whole set of consensus + // epochs for particular block. Invalid PreCommit payloads are kicked off this + // list immediately (if PrepareRequest was received for the current round, so + // it's possible to verify PreCommit against PreBlock built on PrepareRequest) + // or stored till the corresponding PrepareRequest receiving. + PreCommitPayloads []ConsensusPayload[H] // CommitPayloads stores consensus Commit payloads sent throughout all epochs. It // is assumed that valid Commit payload can only be sent once by a single node per // the whole set of consensus epochs for particular block. Invalid commit payloads @@ -108,11 +122,13 @@ func (c *Context[H]) IsBackup() bool { // WatchOnly returns true iff node takes no active part in consensus. func (c *Context[H]) WatchOnly() bool { return c.MyIndex < 0 || c.Config.WatchOnly() } -// CountCommitted returns number of received Commit messages not only for the current -// epoch but also for any other epoch. +// CountCommitted returns number of received Commit (or PreCommit for anti-MEV +// extension) messages not only for the current epoch but also for any other epoch. func (c *Context[H]) CountCommitted() (count int) { for i := range c.CommitPayloads { - if c.CommitPayloads[i] != nil { + // Consider both Commit and PreCommit payloads since node both Commit and PreCommit + // phases are one-directional (do not impose view change). + if c.CommitPayloads[i] != nil || c.PreCommitPayloads[i] != nil { count++ } } @@ -124,7 +140,8 @@ func (c *Context[H]) CountCommitted() (count int) { // for this view and that hasn't sent the Commit message at the previous views. func (c *Context[H]) CountFailed() (count int) { for i, hv := range c.LastSeenMessage { - if c.CommitPayloads[i] == nil && (hv == nil || hv.Height < c.BlockIndex || hv.View < c.ViewNumber) { + if (c.CommitPayloads[i] == nil && c.PreCommitPayloads[i] == nil) && + (hv == nil || hv.Height < c.BlockIndex || hv.View < c.ViewNumber) { count++ } } @@ -143,6 +160,12 @@ func (c *Context[H]) ResponseSent() bool { return !c.WatchOnly() && c.PreparationPayloads[c.MyIndex] != nil } +// PreCommitSent returns true iff PreCommit message was sent for the current epoch +// assuming that the node can't go further than current epoch after PreCommit was sent. +func (c *Context[H]) PreCommitSent() bool { + return !c.WatchOnly() && c.PreCommitPayloads[c.MyIndex] != nil +} + // CommitSent returns true iff Commit message was sent for the current epoch // assuming that the node can't go further than current epoch after commit was sent. func (c *Context[H]) CommitSent() bool { @@ -191,6 +214,10 @@ func (c *Context[H]) MoreThanFNodesCommittedOrLost() bool { return c.CountCommitted()+c.CountFailed() > c.F() } +func (c *Context[H]) PreBlock() PreBlock[H] { + return c.preHeader // without transactions +} + func (c *Context[H]) reset(view byte, ts uint64) { c.MyIndex = -1 c.lastBlockTimestamp = ts @@ -207,6 +234,7 @@ func (c *Context[H]) reset(view byte, ts uint64) { c.LastSeenMessage = make([]*HeightView, n) } c.blockProcessed = false + c.preBlockProcessed = false } else { for i := range c.Validators { m := c.ChangeViewPayloads[i] @@ -226,6 +254,7 @@ func (c *Context[H]) reset(view byte, ts uint64) { n := len(c.Validators) c.ChangeViewPayloads = make([]ConsensusPayload[H], n) if view == 0 { + c.PreCommitPayloads = make([]ConsensusPayload[H], n) c.CommitPayloads = make([]ConsensusPayload[H], n) } c.PreparationPayloads = make([]ConsensusPayload[H], n) @@ -284,12 +313,42 @@ func (c *Context[H]) CreateBlock() Block[H] { txx[i] = c.Transactions[h] } + // Anti-MEV extension properly sets PreBlock transactions once during PreBlock + // construction and then never updates these transactions in the dBFT context. + // Thus, user must not reuse txx if anti-MEV extension is enabled. However, + // we don't skip a call to Block.SetTransactions since it may be used as a + // signal to the user's code to finalize the block. c.block.SetTransactions(txx) } return c.block } +// CreatePreBlock returns PreBlock for the current epoch. +func (c *Context[H]) CreatePreBlock() PreBlock[H] { + if c.preBlock == nil { + if c.preBlock = c.MakePreHeader(); c.preBlock == nil { + return nil + } + + txx := make([]Transaction[H], len(c.TransactionHashes)) + + for i, h := range c.TransactionHashes { + txx[i] = c.Transactions[h] + } + + c.preBlock.SetTransactions(txx) + } + + return c.preBlock +} + +// isAntiMEVExtensionEnabled returns whether Anti-MEV dBFT extension is enabled +// at the currently processing block height. +func (c *Context[H]) isAntiMEVExtensionEnabled() bool { + return c.Config.AntiMEVExtensionEnablingHeight >= 0 && uint32(c.Config.AntiMEVExtensionEnablingHeight) < c.BlockIndex +} + // MakeHeader returns half-filled block for the current epoch. // All hashable fields will be filled. func (c *Context[H]) MakeHeader() Block[H] { @@ -297,12 +356,38 @@ func (c *Context[H]) MakeHeader() Block[H] { if !c.RequestSentOrReceived() { return nil } + // For anti-MEV dBFT extension it's important to have at least M PreCommits received + // because PrepareRequest is not enough to construct proper block. + if c.isAntiMEVExtensionEnabled() { + var count int + for _, preCommit := range c.PreCommitPayloads { + if preCommit != nil && preCommit.ViewNumber() == c.ViewNumber { + count++ + } + } + if count < c.M() { + return nil + } + } c.header = c.Config.NewBlockFromContext(c) } return c.header } +// MakePreHeader returns half-filled block for the current epoch. +// All hashable fields will be filled. +func (c *Context[H]) MakePreHeader() PreBlock[H] { + if c.preHeader == nil { + if !c.RequestSentOrReceived() { + return nil + } + c.preHeader = c.Config.NewPreBlockFromContext(c) + } + + return c.preHeader +} + // hasAllTransactions returns true iff all transactions were received // for the proposed block. func (c *Context[H]) hasAllTransactions() bool { diff --git a/dbft.go b/dbft.go index db6cf851..2483ecb9 100644 --- a/dbft.go +++ b/dbft.go @@ -72,7 +72,9 @@ func (d *DBFT[H]) addTransaction(tx Transaction[H]) { func (d *DBFT[H]) Start(ts uint64) { d.cache = newCache[H]() d.initializeConsensus(0, ts) - d.start() + if d.IsPrimary() { + d.sendPrepareRequest() + } } // Reset reinitializes dBFT instance with the given timestamp of the previous @@ -110,6 +112,25 @@ func (d *DBFT[H]) initializeConsensus(view byte, ts uint64) { zap.Int("index", d.MyIndex), zap.String("role", role)) + // Process cached messages if any. + if msgs := d.cache.getHeight(d.BlockIndex); msgs != nil { + for _, m := range msgs.prepare { + d.OnReceive(m) + } + + for _, m := range msgs.chViews { + d.OnReceive(m) + } + + for _, m := range msgs.preCommit { + d.OnReceive(m) + } + + for _, m := range msgs.commit { + d.OnReceive(m) + } + } + if d.Context.WatchOnly() { return } @@ -145,8 +166,8 @@ func (d *DBFT[H]) OnTransaction(tx Transaction[H]) { // zap.Bool("response_sent", d.ResponseSent()), // zap.Bool("block_sent", d.BlockSent())) if !d.IsBackup() || d.NotAcceptingPayloadsDueToViewChanging() || - !d.RequestSentOrReceived() || d.ResponseSent() || d.BlockSent() || - len(d.MissingTransactions) == 0 { + !d.RequestSentOrReceived() || d.ResponseSent() || d.PreCommitSent() || + d.CommitSent() || d.BlockSent() || len(d.MissingTransactions) == 0 { return } @@ -189,7 +210,7 @@ func (d *DBFT[H]) OnTimeout(height uint32, view byte) { if d.IsPrimary() && !d.RequestSentOrReceived() { d.sendPrepareRequest() } else if (d.IsPrimary() && d.RequestSentOrReceived()) || d.IsBackup() { - if d.CommitSent() { + if d.CommitSent() || d.PreCommitSent() { d.Logger.Debug("send recovery to resend commit") d.sendRecoveryMessage() d.changeTimer(d.SecondsPerBlock << 1) @@ -232,8 +253,6 @@ func (d *DBFT[H]) OnReceive(msg ConsensusPayload[H]) { zap.Any("cache", d.cache.mail[msg.Height()])) d.cache.addMessage(msg) return - } else if msg.ValidatorIndex() > uint16(d.N()) { - return } hv := d.LastSeenMessage[msg.ValidatorIndex()] @@ -255,6 +274,14 @@ func (d *DBFT[H]) OnReceive(msg ConsensusPayload[H]) { d.onPrepareResponse(msg) case CommitType: d.onCommit(msg) + case PreCommitType: + if !d.isAntiMEVExtensionEnabled() { + d.Logger.Error(fmt.Sprintf("%s message received but AntiMEVExtension is disabled", PreCommitType), + zap.Uint16("from", msg.ValidatorIndex()), + ) + return + } + d.onPreCommit(msg) case RecoveryRequestType: d.onRecoveryRequest(msg) case RecoveryMessageType: @@ -264,30 +291,6 @@ func (d *DBFT[H]) OnReceive(msg ConsensusPayload[H]) { } } -// start performs initial operations and returns messages to be sent. -// It must be called after every height or view increment. -func (d *DBFT[H]) start() { - if !d.IsPrimary() { - if msgs := d.cache.getHeight(d.BlockIndex); msgs != nil { - for _, m := range msgs.prepare { - d.OnReceive(m) - } - - for _, m := range msgs.chViews { - d.OnReceive(m) - } - - for _, m := range msgs.commit { - d.OnReceive(m) - } - } - - return - } - - d.sendPrepareRequest() -} - func (d *DBFT[H]) onPrepareRequest(msg ConsensusPayload[H]) { // ignore prepareRequest if we had already received it or // are in process of changing view @@ -304,7 +307,7 @@ func (d *DBFT[H]) onPrepareRequest(msg ConsensusPayload[H]) { d.Logger.Debug("ignoring wrong view number", zap.Uint("view", uint(msg.ViewNumber()))) return } else if uint(msg.ValidatorIndex()) != d.GetPrimaryIndex(d.ViewNumber) { - d.Logger.Debug("ignoring PrepareRequest from wrong node", zap.Uint16("from", msg.ValidatorIndex())) + d.Logger.Info("ignoring PrepareRequest from wrong node", zap.Uint16("from", msg.ValidatorIndex())) return } @@ -318,9 +321,6 @@ func (d *DBFT[H]) onPrepareRequest(msg ConsensusPayload[H]) { d.extendTimer(2) p := msg.GetPrepareRequest() - if len(p.TransactionHashes()) == 0 { - d.Logger.Debug("received empty PrepareRequest") - } d.Timestamp = p.Timestamp() d.Nonce = p.Nonce() @@ -366,14 +366,29 @@ func (d *DBFT[H]) processMissingTx() { // with it, it sends a changeView request and returns false. It's only valid to // call it when all transactions for this block are already collected. func (d *DBFT[H]) createAndCheckBlock() bool { - if b := d.Context.CreateBlock(); !d.VerifyBlock(b) { - d.Logger.Warn("proposed block fails verification") + var blockOK bool + if d.isAntiMEVExtensionEnabled() { + b := d.Context.CreatePreBlock() + blockOK = d.VerifyPreBlock(b) + if !blockOK { + d.Logger.Warn("proposed preBlock fails verification") + } + } else { + b := d.Context.CreateBlock() + blockOK = d.VerifyBlock(b) + if !blockOK { + d.Logger.Warn("proposed block fails verification") + } + } + if !blockOK { d.sendChangeView(CVTxInvalid) return false } return true } +// updateExistingPayloads is called _only_ from onPrepareRequest, it validates +// payloads we may have received before PrepareRequest. func (d *DBFT[H]) updateExistingPayloads(msg ConsensusPayload[H]) { for i, m := range d.PreparationPayloads { if m != nil && m.Type() == PrepareResponseType { @@ -384,6 +399,28 @@ func (d *DBFT[H]) updateExistingPayloads(msg ConsensusPayload[H]) { } } + if d.isAntiMEVExtensionEnabled() { + for i, m := range d.PreCommitPayloads { + if m != nil && m.ViewNumber() == d.ViewNumber { + if preHeader := d.MakePreHeader(); preHeader != nil { + pub := d.Validators[m.ValidatorIndex()] + if err := preHeader.Verify(pub, m.GetPreCommit().Data()); err != nil { + d.PreCommitPayloads[i] = nil + d.Logger.Warn("can't validate preCommit data", + zap.Error(err)) + } + } + } + } + // Commits can't be verified, we have no idea what's the header. + } else { + d.verifyCommitPayloadsAgainstHeader() + } +} + +// verifyCommitPayloadsAgainstHeader performs verification of commit payloads +// against generated header. +func (d *DBFT[H]) verifyCommitPayloadsAgainstHeader() { for i, m := range d.CommitPayloads { if m != nil && m.ViewNumber() == d.ViewNumber { if header := d.MakeHeader(); header != nil { @@ -444,7 +481,7 @@ func (d *DBFT[H]) onPrepareResponse(msg ConsensusPayload[H]) { d.extendTimer(2) - if !d.Context.WatchOnly() && !d.CommitSent() && d.RequestSentOrReceived() { + if !d.Context.WatchOnly() && !d.CommitSent() && (!d.isAntiMEVExtensionEnabled() || !d.PreCommitSent()) && d.RequestSentOrReceived() { d.checkPrepare() } } @@ -459,8 +496,8 @@ func (d *DBFT[H]) onChangeView(msg ConsensusPayload[H]) { return } - if d.CommitSent() { - d.Logger.Debug("ignoring ChangeView: commit sent") + if d.CommitSent() || d.PreCommitSent() { + d.Logger.Debug("ignoring ChangeView: preCommit or commit sent") d.sendRecoveryMessage() return } @@ -480,6 +517,52 @@ func (d *DBFT[H]) onChangeView(msg ConsensusPayload[H]) { d.checkChangeView(p.NewViewNumber()) } +func (d *DBFT[H]) onPreCommit(msg ConsensusPayload[H]) { + existing := d.PreCommitPayloads[msg.ValidatorIndex()] + if existing != nil { + if existing.Hash() != msg.Hash() { + d.Logger.Warn("rejecting preCommit due to existing", + zap.Uint("validator", uint(msg.ValidatorIndex())), + zap.Uint("existing view", uint(existing.ViewNumber())), + zap.Uint("view", uint(msg.ViewNumber())), + zap.Stringer("existing hash", existing.Hash()), + zap.Stringer("hash", msg.Hash()), + ) + } + return + } + d.PreCommitPayloads[msg.ValidatorIndex()] = msg + if d.ViewNumber == msg.ViewNumber() { + if err := d.VerifyPreCommit(msg); err != nil { + d.Logger.Warn("invalid PreCommit", zap.Uint16("from", msg.ValidatorIndex()), zap.String("error", err.Error())) + return + } + + d.Logger.Info("received PreCommit", zap.Uint("validator", uint(msg.ValidatorIndex()))) + d.extendTimer(4) + + preHeader := d.MakePreHeader() + if preHeader != nil { + pub := d.Validators[msg.ValidatorIndex()] + if err := preHeader.Verify(pub, msg.GetPreCommit().Data()); err == nil { + d.checkPreCommit() + } else { + d.PreCommitPayloads[msg.ValidatorIndex()] = nil + d.Logger.Warn("invalid preCommit data", + zap.Uint("validator", uint(msg.ValidatorIndex())), + zap.Error(err), + ) + } + } + return + } + + d.Logger.Info("received preCommit for different view", + zap.Uint("validator", uint(msg.ValidatorIndex())), + zap.Uint("view", uint(msg.ViewNumber())), + ) +} + func (d *DBFT[H]) onCommit(msg ConsensusPayload[H]) { existing := d.CommitPayloads[msg.ValidatorIndex()] if existing != nil { @@ -494,18 +577,17 @@ func (d *DBFT[H]) onCommit(msg ConsensusPayload[H]) { } return } + d.CommitPayloads[msg.ValidatorIndex()] = msg if d.ViewNumber == msg.ViewNumber() { d.Logger.Info("received Commit", zap.Uint("validator", uint(msg.ValidatorIndex()))) d.extendTimer(4) header := d.MakeHeader() - if header == nil { - d.CommitPayloads[msg.ValidatorIndex()] = msg - } else { + if header != nil { pub := d.Validators[msg.ValidatorIndex()] if header.Verify(pub, msg.GetCommit().Signature()) == nil { - d.CommitPayloads[msg.ValidatorIndex()] = msg d.checkCommit() } else { + d.CommitPayloads[msg.ValidatorIndex()] = nil d.Logger.Warn("invalid commit signature", zap.Uint("validator", uint(msg.ValidatorIndex())), ) @@ -519,11 +601,10 @@ func (d *DBFT[H]) onCommit(msg ConsensusPayload[H]) { zap.Uint("validator", uint(msg.ValidatorIndex())), zap.Uint("view", uint(msg.ViewNumber())), ) - d.CommitPayloads[msg.ValidatorIndex()] = msg } func (d *DBFT[H]) onRecoveryRequest(msg ConsensusPayload[H]) { - if !d.CommitSent() { + if !d.CommitSent() && (!d.isAntiMEVExtensionEnabled() || !d.PreCommitSent()) { // Limit recoveries to be sent from no more than F nodes // TODO replace loop with a single if shouldSend := false @@ -548,27 +629,28 @@ func (d *DBFT[H]) onRecoveryMessage(msg ConsensusPayload[H]) { d.Logger.Debug("recovery message received", zap.Any("dump", msg)) var ( - validPrepResp, validChViews, validCommits int - validPrepReq, totalPrepReq int + validPrepResp, validChViews int + validPreCommits, validCommits int + validPrepReq, totalPrepReq int + recovery = msg.GetRecoveryMessage() + total = len(d.Validators) ) - recovery := msg.GetRecoveryMessage() - total := len(d.Validators) - // isRecovering is always set to false again after OnRecoveryMessageReceived d.recovering = true defer func() { - d.Logger.Sugar().Debugf("recovering finished cv=%d/%d preq=%d/%d presp=%d/%d co=%d/%d", + d.Logger.Sugar().Debugf("recovering finished cv=%d/%d preq=%d/%d presp=%d/%d pco=%d/%d co=%d/%d", validChViews, total, validPrepReq, totalPrepReq, validPrepResp, total, + validPreCommits, total, validCommits, total) d.recovering = false }() if msg.ViewNumber() > d.ViewNumber { - if d.CommitSent() { + if d.CommitSent() || d.PreCommitSent() { return } @@ -578,7 +660,7 @@ func (d *DBFT[H]) onRecoveryMessage(msg ConsensusPayload[H]) { } } - if msg.ViewNumber() == d.ViewNumber && !(d.ViewChanging() && !d.MoreThanFNodesCommittedOrLost()) && !d.CommitSent() { + if msg.ViewNumber() == d.ViewNumber && !(d.ViewChanging() && !d.MoreThanFNodesCommittedOrLost()) && !d.CommitSent() && (!d.isAntiMEVExtensionEnabled() || !d.PreCommitSent()) { if !d.RequestSentOrReceived() { prepReq := recovery.GetPrepareRequest(msg, d.Validators, uint16(d.PrimaryIndex)) if prepReq != nil { @@ -595,8 +677,13 @@ func (d *DBFT[H]) onRecoveryMessage(msg ConsensusPayload[H]) { } } + // Ensure we know about all (pre) commits from lower view numbers. if msg.ViewNumber() <= d.ViewNumber { - // Ensure we know about all commits from lower view numbers. + for _, m := range recovery.GetPreCommits(msg, d.Validators) { + validPreCommits++ + d.OnReceive(m) + } + for _, m := range recovery.GetCommits(msg, d.Validators) { validCommits++ d.OnReceive(m) @@ -613,7 +700,7 @@ func (d *DBFT[H]) changeTimer(delay time.Duration) { } func (d *DBFT[H]) extendTimer(count time.Duration) { - if !d.CommitSent() && !d.ViewChanging() { + if !d.CommitSent() && (!d.isAntiMEVExtensionEnabled() || !d.PreCommitSent()) && !d.ViewChanging() { d.Timer.Extend(count * d.SecondsPerBlock / time.Duration(d.M())) } } @@ -623,3 +710,9 @@ func (d *DBFT[H]) extendTimer(count time.Duration) { func (d *DBFT[H]) Header() Block[H] { return d.header } + +// PreHeader returns current preHeader from context. May be nil in case if no +// preHeader is constructed yet. Do not change the resulting preHeader. +func (d *DBFT[H]) PreHeader() PreBlock[H] { + return d.preHeader +} diff --git a/dbft_test.go b/dbft_test.go index 86fe3c89..a08195fe 100644 --- a/dbft_test.go +++ b/dbft_test.go @@ -3,6 +3,7 @@ package dbft_test import ( "crypto/rand" "encoding/binary" + "fmt" "testing" "time" @@ -25,6 +26,7 @@ type testState struct { currHeight uint32 currHash crypto.Uint256 pool *testPool + preBlocks []dbft.PreBlock[crypto.Uint256] blocks []dbft.Block[crypto.Uint256] verify func(b dbft.Block[crypto.Uint256]) bool } @@ -43,7 +45,8 @@ func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) { t.Run("backup sends nothing on start", func(t *testing.T) { s.currHeight = 0 - service, _ := dbft.New[crypto.Uint256](s.getOptions()...) + service, err := dbft.New[crypto.Uint256](s.getOptions()...) + require.NoError(t, err) service.Start(0) require.Nil(t, s.tryRecv()) @@ -87,31 +90,48 @@ func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) { } func TestDBFT_SingleNode(t *testing.T) { - s := newTestState(0, 1) - - s.currHeight = 2 - service, _ := dbft.New[crypto.Uint256](s.getOptions()...) - - service.Start(0) - p := s.tryRecv() - require.NotNil(t, p) - require.Equal(t, dbft.PrepareRequestType, p.Type()) - require.EqualValues(t, 3, p.Height()) - require.EqualValues(t, 0, p.ViewNumber()) - require.NotNil(t, p.Payload()) - require.EqualValues(t, 0, p.ValidatorIndex()) - - cm := s.tryRecv() - require.NotNil(t, cm) - require.Equal(t, dbft.CommitType, cm.Type()) - require.EqualValues(t, s.currHeight+1, cm.Height()) - require.EqualValues(t, 0, cm.ViewNumber()) - require.NotNil(t, cm.Payload()) - require.EqualValues(t, 0, cm.ValidatorIndex()) + for _, amev := range []bool{false, true} { + t.Run(fmt.Sprintf("AMEV %t", amev), func(t *testing.T) { + s := newTestState(0, 1) + + s.currHeight = 2 + opts := s.getOptions() + if amev { + opts = s.getAMEVOptions() + } + service, _ := dbft.New[crypto.Uint256](opts...) + + service.Start(0) + p := s.tryRecv() + require.NotNil(t, p) + require.Equal(t, dbft.PrepareRequestType, p.Type()) + require.EqualValues(t, 3, p.Height()) + require.EqualValues(t, 0, p.ViewNumber()) + require.NotNil(t, p.Payload()) + require.EqualValues(t, 0, p.ValidatorIndex()) + + if amev { + cm := s.tryRecv() + require.NotNil(t, cm) + require.Equal(t, dbft.PreCommitType, cm.Type()) + require.EqualValues(t, s.currHeight+1, cm.Height()) + require.EqualValues(t, 0, cm.ViewNumber()) + require.NotNil(t, cm.Payload()) + require.EqualValues(t, 0, cm.ValidatorIndex()) + } + cm := s.tryRecv() + require.NotNil(t, cm) + require.Equal(t, dbft.CommitType, cm.Type()) + require.EqualValues(t, s.currHeight+1, cm.Height()) + require.EqualValues(t, 0, cm.ViewNumber()) + require.NotNil(t, cm.Payload()) + require.EqualValues(t, 0, cm.ValidatorIndex()) - b := s.nextBlock() - require.NotNil(t, b) - require.Equal(t, s.currHeight+1, b.Index()) + b := s.nextBlock() + require.NotNil(t, b) + require.Equal(t, s.currHeight+1, b.Index()) + }) + } } func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) { @@ -151,7 +171,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) { require.Nil(t, s.tryRecv()) t.Run("receive response from primary", func(t *testing.T) { - resp := s.getPrepareResponse(5, p.Hash()) + resp := s.getPrepareResponse(5, p.Hash(), 0) service.OnReceive(resp) require.Nil(t, s.tryRecv()) @@ -243,8 +263,8 @@ func TestDBFT_CommitOnTransaction(t *testing.T) { tx := testTx(42) req := s.getPrepareRequest(2, tx.Hash()) srv.OnReceive(req) - srv.OnReceive(s.getPrepareResponse(1, req.Hash())) - srv.OnReceive(s.getPrepareResponse(3, req.Hash())) + srv.OnReceive(s.getPrepareResponse(1, req.Hash(), 0)) + srv.OnReceive(s.getPrepareResponse(3, req.Hash(), 0)) require.Nil(t, srv.Header()) // missing transaction. // Test state for forming header. @@ -259,13 +279,13 @@ func TestDBFT_CommitOnTransaction(t *testing.T) { srv1, _ := dbft.New[crypto.Uint256](s1.getOptions()...) srv1.Start(0) srv1.OnReceive(req) - srv1.OnReceive(s1.getPrepareResponse(1, req.Hash())) - srv1.OnReceive(s1.getPrepareResponse(3, req.Hash())) + srv1.OnReceive(s1.getPrepareResponse(1, req.Hash(), 0)) + srv1.OnReceive(s1.getPrepareResponse(3, req.Hash(), 0)) require.NotNil(t, srv1.Header()) for _, i := range []uint16{1, 2, 3} { require.NoError(t, srv1.Header().Sign(s1.privs[i])) - c := s1.getCommit(i, srv1.Header().Signature()) + c := s1.getCommit(i, srv1.Header().Signature(), 0) srv.OnReceive(c) } @@ -284,11 +304,11 @@ func TestDBFT_OnReceiveCommit(t *testing.T) { req := s.tryRecv() require.NotNil(t, req) - resp := s.getPrepareResponse(1, req.Hash()) + resp := s.getPrepareResponse(1, req.Hash(), 0) service.OnReceive(resp) require.Nil(t, s.tryRecv()) - resp = s.getPrepareResponse(0, req.Hash()) + resp = s.getPrepareResponse(0, req.Hash(), 0) service.OnReceive(resp) cm := s.tryRecv() @@ -316,14 +336,14 @@ func TestDBFT_OnReceiveCommit(t *testing.T) { t.Run("process block after enough commits", func(t *testing.T) { s0 := s.copyWithIndex(0) require.NoError(t, service.Header().Sign(s0.privs[0])) - c0 := s0.getCommit(0, service.Header().Signature()) + c0 := s0.getCommit(0, service.Header().Signature(), 0) service.OnReceive(c0) require.Nil(t, s.tryRecv()) require.Nil(t, s.nextBlock()) s1 := s.copyWithIndex(1) require.NoError(t, service.Header().Sign(s1.privs[1])) - c1 := s1.getCommit(1, service.Header().Signature()) + c1 := s1.getCommit(1, service.Header().Signature(), 0) service.OnReceive(c1) require.Nil(t, s.tryRecv()) @@ -344,11 +364,11 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) { req := s.tryRecv() require.NotNil(t, req) - resp := s.getPrepareResponse(1, req.Hash()) + resp := s.getPrepareResponse(1, req.Hash(), 0) service.OnReceive(resp) require.Nil(t, s.tryRecv()) - resp = s.getPrepareResponse(0, req.Hash()) + resp = s.getPrepareResponse(0, req.Hash(), 0) service.OnReceive(resp) cm := s.tryRecv() require.NotNil(t, cm) @@ -743,6 +763,211 @@ func TestDBFT_FourGoodNodesDeadlock(t *testing.T) { require.NotNil(t, r1.nextBlock()) } +func TestDBFT_OnReceiveCommitAMEV(t *testing.T) { + s := newTestState(2, 4) + t.Run("send preCommit after enough responses", func(t *testing.T) { + s.currHeight = 1 + service, _ := dbft.New[crypto.Uint256](s.getAMEVOptions()...) + service.Start(0) + + req := s.tryRecv() + require.NotNil(t, req) + + resp := s.getPrepareResponse(1, req.Hash(), 0) + service.OnReceive(resp) + require.Nil(t, s.tryRecv()) + + resp = s.getPrepareResponse(0, req.Hash(), 0) + service.OnReceive(resp) + + cm := s.tryRecv() + require.NotNil(t, cm) + require.Equal(t, dbft.PreCommitType, cm.Type()) + require.EqualValues(t, s.currHeight+1, cm.Height()) + require.EqualValues(t, 0, cm.ViewNumber()) + require.EqualValues(t, s.myIndex, cm.ValidatorIndex()) + require.NotNil(t, cm.Payload()) + + pub := s.pubs[s.myIndex] + require.NoError(t, service.PreHeader().Verify(pub, cm.GetPreCommit().Data())) + + t.Run("send commit after enough preCommits", func(t *testing.T) { + s0 := s.copyWithIndex(0) + require.NoError(t, service.PreHeader().SetData(s0.privs[0])) + preC0 := s0.getPreCommit(0, service.PreHeader().Data(), 0) + service.OnReceive(preC0) + require.Nil(t, s.tryRecv()) + require.Nil(t, s.nextPreBlock()) + require.Nil(t, s.nextBlock()) + + s1 := s.copyWithIndex(1) + require.NoError(t, service.PreHeader().SetData(s1.privs[1])) + preC1 := s1.getPreCommit(1, service.PreHeader().Data(), 0) + service.OnReceive(preC1) + + b := s.nextPreBlock() + require.NotNil(t, b) + require.Equal(t, []byte{0, 0, 0, 2}, b.Data()) // After SetData it's equal to node index. + require.Nil(t, s.nextBlock()) + + c := s.tryRecv() + require.NotNil(t, c) + require.Equal(t, dbft.CommitType, c.Type()) + require.EqualValues(t, s.currHeight+1, c.Height()) + require.EqualValues(t, 0, c.ViewNumber()) + require.EqualValues(t, s.myIndex, c.ValidatorIndex()) + require.NotNil(t, c.Payload()) + + t.Run("process block a after enough commitAcks", func(t *testing.T) { + s0 := s.copyWithIndex(0) + require.NoError(t, service.Header().Sign(s0.privs[0])) + c0 := s0.getAMEVCommit(0, service.Header().Signature()) + service.OnReceive(c0) + require.Nil(t, s.tryRecv()) + require.Nil(t, s.nextPreBlock()) + require.Nil(t, s.nextBlock()) + + s1 := s.copyWithIndex(1) + require.NoError(t, service.Header().Sign(s1.privs[1])) + c1 := s1.getAMEVCommit(1, service.Header().Signature()) + service.OnReceive(c1) + require.Nil(t, s.tryRecv()) + require.Nil(t, s.nextPreBlock()) + + b := s.nextBlock() + require.NotNil(t, b) + require.Equal(t, s.currHeight+1, b.Index()) + }) + }) + }) +} + +func TestDBFT_CachedMessages(t *testing.T) { + for _, amev := range []bool{false, true} { + t.Run(fmt.Sprintf("AMEV %t", amev), func(t *testing.T) { + s2 := newTestState(2, 4) + s2.currHeight = 1 + s1 := newTestState(1, 4) + s1.currHeight = 1 + + opts := s2.getOptions() + if amev { + opts = s2.getAMEVOptions() + } + service2, _ := dbft.New[crypto.Uint256](opts...) + service2.Start(0) + + opts = s1.getOptions() + if amev { + opts = s1.getAMEVOptions() + } + service1, _ := dbft.New[crypto.Uint256](opts...) + service1.Start(0) + + req := s2.tryRecv() + require.NotNil(t, req) // Primary sends a request. + require.Equal(t, dbft.PrepareRequestType, req.Type()) + + require.Nil(t, s1.tryRecv()) // Backup waits. + + cv0 := s1.getChangeView(0, 1) + cv3 := s1.getChangeView(3, 1) + service1.OnReceive(cv0) + service1.OnReceive(cv3) + service1.OnTimeout(s1.currHeight+1, 0) + + cv := s1.tryRecv() + require.NotNil(t, cv) + require.Equal(t, dbft.ChangeViewType, cv.Type()) + + service1.OnTimeout(s1.currHeight+1, 1) + req = s1.tryRecv() + require.NotNil(t, req) + require.Equal(t, dbft.PrepareRequestType, req.Type()) + + resp := s1.getPrepareResponse(3, req.Hash(), 1) + service1.OnReceive(resp) + require.Nil(t, s1.tryRecv()) + service2.OnReceive(resp) // From the future. + require.Nil(t, s2.tryRecv()) + + resp = s1.getPrepareResponse(0, req.Hash(), 1) + service2.OnReceive(resp) // From the future. + require.Nil(t, s2.tryRecv()) + + service1.OnReceive(resp) + cm := s1.tryRecv() + require.NotNil(t, cm) + + service2.OnReceive(cm) + require.Nil(t, s2.tryRecv()) + + if amev { + require.Equal(t, dbft.PreCommitType, cm.Type()) + require.EqualValues(t, s1.currHeight+1, cm.Height()) + require.EqualValues(t, 1, cm.ViewNumber()) + require.EqualValues(t, s1.myIndex, cm.ValidatorIndex()) + require.NotNil(t, cm.Payload()) + pub := s1.pubs[s1.myIndex] + require.NoError(t, service1.PreHeader().Verify(pub, cm.GetPreCommit().Data())) + } else { + require.Equal(t, dbft.CommitType, cm.Type()) + require.EqualValues(t, s1.currHeight+1, cm.Height()) + require.EqualValues(t, 1, cm.ViewNumber()) + require.EqualValues(t, s1.myIndex, cm.ValidatorIndex()) + require.NotNil(t, cm.Payload()) + } + + service2.OnReceive(cv0) + service2.OnReceive(cv3) + service2.OnTimeout(s2.currHeight+1, 0) + cv = s2.tryRecv() + require.NotNil(t, cv) + require.Equal(t, dbft.ChangeViewType, cv.Type()) + + require.Equal(t, 1, int(service2.ViewNumber)) + + // s2 has some PrepareResponses, but doesn't have a request. + service2.OnReceive(req) + + resp = s2.tryRecv() + require.NotNil(t, resp) + require.Equal(t, dbft.PrepareResponseType, resp.Type()) + + cm = s2.tryRecv() + require.NotNil(t, cm) + + if amev { + require.Equal(t, dbft.PreCommitType, cm.Type()) + require.EqualValues(t, s2.currHeight+1, cm.Height()) + require.EqualValues(t, 1, cm.ViewNumber()) + require.EqualValues(t, s2.myIndex, cm.ValidatorIndex()) + require.NotNil(t, cm.Payload()) + pub := s1.pubs[s1.myIndex] + require.NoError(t, service1.PreHeader().Verify(pub, cm.GetPreCommit().Data())) + + service2.OnReceive(s2.getPreCommit(0, service2.PreHeader().Data(), 1)) + cm = s2.tryRecv() + require.NotNil(t, cm) + require.Equal(t, dbft.CommitType, cm.Type()) + } else { + require.Equal(t, dbft.CommitType, cm.Type()) + require.EqualValues(t, s2.currHeight+1, cm.Height()) + require.EqualValues(t, 1, cm.ViewNumber()) + require.EqualValues(t, s2.myIndex, cm.ValidatorIndex()) + require.NotNil(t, cm.Payload()) + + require.NoError(t, service2.Header().Sign(s2.privs[0])) + service2.OnReceive(s2.getCommit(0, service2.Header().Signature(), 1)) + require.Nil(t, s2.tryRecv()) + b := s2.nextBlock() + require.NotNil(t, b) + require.Equal(t, s2.currHeight+1, b.Index()) + } + }) + } +} + func (s testState) getChangeView(from uint16, view byte) Payload { cv := consensus.NewChangeView(view, 0, 0) @@ -755,16 +980,28 @@ func (s testState) getRecoveryRequest(from uint16) Payload { return p } -func (s testState) getCommit(from uint16, sign []byte) Payload { +func (s testState) getCommit(from uint16, sign []byte, view byte) Payload { c := consensus.NewCommit(sign) + p := consensus.NewConsensusPayload(dbft.CommitType, s.currHeight+1, from, view, c) + return p +} + +func (s testState) getAMEVCommit(from uint16, sign []byte) Payload { + c := consensus.NewAMEVCommit(sign) p := consensus.NewConsensusPayload(dbft.CommitType, s.currHeight+1, from, 0, c) return p } -func (s testState) getPrepareResponse(from uint16, phash crypto.Uint256) Payload { +func (s testState) getPreCommit(from uint16, data []byte, view byte) Payload { + c := consensus.NewPreCommit(data) + p := consensus.NewConsensusPayload(dbft.PreCommitType, s.currHeight+1, from, view, c) + return p +} + +func (s testState) getPrepareResponse(from uint16, phash crypto.Uint256, view byte) Payload { resp := consensus.NewPrepareResponse(phash) - p := consensus.NewConsensusPayload(dbft.PrepareResponseType, s.currHeight+1, from, 0, resp) + p := consensus.NewConsensusPayload(dbft.PrepareResponseType, s.currHeight+1, from, view, resp) return p } @@ -813,6 +1050,17 @@ func (s *testState) nextBlock() dbft.Block[crypto.Uint256] { return b } +func (s *testState) nextPreBlock() dbft.PreBlock[crypto.Uint256] { + if len(s.preBlocks) == 0 { + return nil + } + + b := s.preBlocks[0] + s.preBlocks = s.preBlocks[1:] + + return b +} + func (s testState) copyWithIndex(myIndex int) *testState { return &testState{ myIndex: myIndex, @@ -874,6 +1122,20 @@ func (s *testState) getOptions() []func(*dbft.Config[crypto.Uint256]) { return opts } +func (s *testState) getAMEVOptions() []func(*dbft.Config[crypto.Uint256]) { + opts := s.getOptions() + opts = append(opts, + dbft.WithAntiMEVExtensionEnablingHeight[crypto.Uint256](0), + dbft.WithNewPreCommit[crypto.Uint256](consensus.NewPreCommit), + dbft.WithNewCommit[crypto.Uint256](consensus.NewAMEVCommit), + dbft.WithNewPreBlockFromContext[crypto.Uint256](newPreBlockFromContext), + dbft.WithNewBlockFromContext[crypto.Uint256](newAMEVBlockFromContext), + dbft.WithProcessPreBlock(func(b dbft.PreBlock[crypto.Uint256]) { s.preBlocks = append(s.preBlocks, b) }), + ) + + return opts +} + func newBlockFromContext(ctx *dbft.Context[crypto.Uint256]) dbft.Block[crypto.Uint256] { if ctx.TransactionHashes == nil { return nil @@ -882,6 +1144,28 @@ func newBlockFromContext(ctx *dbft.Context[crypto.Uint256]) dbft.Block[crypto.Ui return block } +func newPreBlockFromContext(ctx *dbft.Context[crypto.Uint256]) dbft.PreBlock[crypto.Uint256] { + if ctx.TransactionHashes == nil { + return nil + } + pre := consensus.NewPreBlock(ctx.Timestamp, ctx.BlockIndex, ctx.PrevHash, ctx.Nonce, ctx.TransactionHashes) + return pre +} + +func newAMEVBlockFromContext(ctx *dbft.Context[crypto.Uint256]) dbft.Block[crypto.Uint256] { + if ctx.TransactionHashes == nil { + return nil + } + var data [][]byte + for _, c := range ctx.PreCommitPayloads { + if c != nil && c.ViewNumber() == ctx.ViewNumber { + data = append(data, c.GetPreCommit().Data()) + } + } + pre := consensus.NewAMEVBlock(ctx.PreBlock(), data, ctx.M()) + return pre +} + // newConsensusPayload is a function for creating consensus payload of specific // type. func newConsensusPayload(c *dbft.Context[crypto.Uint256], t dbft.MessageType, msg any) dbft.ConsensusPayload[crypto.Uint256] { diff --git a/helpers.go b/helpers.go index 21f34e32..975b8453 100644 --- a/helpers.go +++ b/helpers.go @@ -3,9 +3,10 @@ package dbft type ( // inbox is a structure storing messages from a single epoch. inbox[H Hash] struct { - prepare map[uint16]ConsensusPayload[H] - chViews map[uint16]ConsensusPayload[H] - commit map[uint16]ConsensusPayload[H] + prepare map[uint16]ConsensusPayload[H] + chViews map[uint16]ConsensusPayload[H] + preCommit map[uint16]ConsensusPayload[H] + commit map[uint16]ConsensusPayload[H] } // cache is an auxiliary structure storing messages @@ -17,9 +18,10 @@ type ( func newInbox[H Hash]() *inbox[H] { return &inbox[H]{ - prepare: make(map[uint16]ConsensusPayload[H]), - chViews: make(map[uint16]ConsensusPayload[H]), - commit: make(map[uint16]ConsensusPayload[H]), + prepare: make(map[uint16]ConsensusPayload[H]), + chViews: make(map[uint16]ConsensusPayload[H]), + preCommit: make(map[uint16]ConsensusPayload[H]), + commit: make(map[uint16]ConsensusPayload[H]), } } @@ -50,6 +52,8 @@ func (c *cache[H]) addMessage(m ConsensusPayload[H]) { msgs.prepare[m.ValidatorIndex()] = m case ChangeViewType: msgs.chViews[m.ValidatorIndex()] = m + case PreCommitType: + msgs.preCommit[m.ValidatorIndex()] = m case CommitType: msgs.commit[m.ValidatorIndex()] = m } diff --git a/helpers_test.go b/helpers_test.go index 33a60920..0c62421c 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -51,6 +51,7 @@ func (p payloadStub) GetPrepareResponse() PrepareResponse[hash] { func (p payloadStub) GetCommit() Commit { panic("TODO") } +func (p payloadStub) GetPreCommit() PreCommit { panic("TODO") } func (p payloadStub) GetRecoveryRequest() RecoveryRequest { panic("TODO") } @@ -94,13 +95,21 @@ func TestMessageCache(t *testing.T) { } c.addMessage(p3) + p4 := payloadStub{ + height: 3, + typ: PreCommitType, + } + c.addMessage(p4) + box := c.getHeight(3) require.Len(t, box.chViews, 0) require.Len(t, box.prepare, 1) + require.Len(t, box.preCommit, 1) require.Len(t, box.commit, 0) box = c.getHeight(4) require.Len(t, box.chViews, 1) require.Len(t, box.prepare, 0) + require.Len(t, box.preCommit, 0) require.Len(t, box.commit, 1) } diff --git a/internal/consensus/amev_block.go b/internal/consensus/amev_block.go new file mode 100644 index 00000000..8e51c5b6 --- /dev/null +++ b/internal/consensus/amev_block.go @@ -0,0 +1,126 @@ +package consensus + +import ( + "bytes" + "encoding/binary" + "encoding/gob" + "math" + + "github.com/nspcc-dev/dbft" + "github.com/nspcc-dev/dbft/internal/crypto" + "github.com/nspcc-dev/dbft/internal/merkle" +) + +type amevBlock struct { + base + + transactions []dbft.Transaction[crypto.Uint256] + signature []byte + hash *crypto.Uint256 +} + +// NewAMEVBlock returns new block based on PreBlock and additional Commit-level data +// collected from M consensus nodes. +func NewAMEVBlock(pre dbft.PreBlock[crypto.Uint256], cnData [][]byte, m int) dbft.Block[crypto.Uint256] { + preB := pre.(*preBlock) + res := new(amevBlock) + res.base = preB.base + + // Based on the provided cnData we'll add one more transaction to the resulting block. + // Some artificial rules of new tx creation are invented here, but in Neo X there will + // be well-defined custom rules for Envelope transactions. + var sum uint32 + for i := 0; i < m; i++ { + sum += binary.BigEndian.Uint32(cnData[i]) + } + tx := Tx64(math.MaxInt64 - int64(sum)) + res.transactions = append(preB.initialTransactions, &tx) + + // Rebuild Merkle root for the new set of transations. + txHashes := make([]crypto.Uint256, len(res.transactions)) + for i := range txHashes { + txHashes[i] = res.transactions[i].Hash() + } + mt := merkle.NewMerkleTree(txHashes...) + res.base.MerkleRoot = mt.Root().Hash + + return res +} + +// PrevHash implements Block interface. +func (b *amevBlock) PrevHash() crypto.Uint256 { + return b.base.PrevHash +} + +// Index implements Block interface. +func (b *amevBlock) Index() uint32 { + return b.base.Index +} + +// MerkleRoot implements Block interface. +func (b *amevBlock) MerkleRoot() crypto.Uint256 { + return b.base.MerkleRoot +} + +// Transactions implements Block interface. +func (b *amevBlock) Transactions() []dbft.Transaction[crypto.Uint256] { + return b.transactions +} + +// SetTransactions implements Block interface. This method is special since it's +// left for dBFT 2.0 compatibility and transactions from this method must not be +// reused to fill final Block's transactions. +func (b *amevBlock) SetTransactions(_ []dbft.Transaction[crypto.Uint256]) { +} + +// Signature implements Block interface. +func (b *amevBlock) Signature() []byte { + return b.signature +} + +// GetHashData returns data for hashing and signing. +// It must be an injection of the set of blocks to the set +// of byte slices, i.e: +// 1. It must have only one valid result for one block. +// 2. Two different blocks must have different hash data. +func (b *amevBlock) GetHashData() []byte { + buf := bytes.Buffer{} + w := gob.NewEncoder(&buf) + _ = b.EncodeBinary(w) + + return buf.Bytes() +} + +// Sign implements Block interface. +func (b *amevBlock) Sign(key dbft.PrivateKey) error { + data := b.GetHashData() + + sign, err := key.Sign(data) + if err != nil { + return err + } + + b.signature = sign + + return nil +} + +// Verify implements Block interface. +func (b *amevBlock) Verify(pub dbft.PublicKey, sign []byte) error { + data := b.GetHashData() + return pub.(*crypto.ECDSAPub).Verify(data, sign) +} + +// Hash implements Block interface. +func (b *amevBlock) Hash() (h crypto.Uint256) { + if b.hash != nil { + return *b.hash + } else if b.transactions == nil { + return + } + + hash := crypto.Hash256(b.GetHashData()) + b.hash = &hash + + return hash +} diff --git a/internal/consensus/amev_commit.go b/internal/consensus/amev_commit.go new file mode 100644 index 00000000..1b4e509e --- /dev/null +++ b/internal/consensus/amev_commit.go @@ -0,0 +1,44 @@ +package consensus + +import ( + "encoding/gob" + + "github.com/nspcc-dev/dbft" +) + +type ( + // amevCommit implements dbft.Commit. + amevCommit struct { + data [dataSize]byte + } + // amevCommitAux is an auxiliary structure for amevCommit encoding. + amevCommitAux struct { + Data [dataSize]byte + } +) + +const dataSize = 64 + +var _ dbft.Commit = (*amevCommit)(nil) + +// EncodeBinary implements Serializable interface. +func (c amevCommit) EncodeBinary(w *gob.Encoder) error { + return w.Encode(amevCommitAux{ + Data: c.data, + }) +} + +// DecodeBinary implements Serializable interface. +func (c *amevCommit) DecodeBinary(r *gob.Decoder) error { + aux := new(amevCommitAux) + if err := r.Decode(aux); err != nil { + return err + } + c.data = aux.Data + return nil +} + +// Signature implements Commit interface. +func (c amevCommit) Signature() []byte { + return c.data[:] +} diff --git a/internal/consensus/amev_preBlock.go b/internal/consensus/amev_preBlock.go new file mode 100644 index 00000000..23ee7331 --- /dev/null +++ b/internal/consensus/amev_preBlock.go @@ -0,0 +1,79 @@ +package consensus + +import ( + "encoding/binary" + "errors" + + "github.com/nspcc-dev/dbft" + "github.com/nspcc-dev/dbft/internal/crypto" + "github.com/nspcc-dev/dbft/internal/merkle" +) + +type preBlock struct { + base + + // A magic number CN nodes should exchange during Commit phase + // and used to construct the final list of transactions for amevBlock. + data uint32 + + initialTransactions []dbft.Transaction[crypto.Uint256] +} + +var _ dbft.PreBlock[crypto.Uint256] = new(preBlock) + +// NewPreBlock returns new preBlock. +func NewPreBlock(timestamp uint64, index uint32, prevHash crypto.Uint256, nonce uint64, txHashes []crypto.Uint256) dbft.PreBlock[crypto.Uint256] { + pre := new(preBlock) + pre.base.Timestamp = uint32(timestamp / 1000000000) + pre.base.Index = index + + // NextConsensus and Version information is not provided by dBFT context, + // these are implementation-specific fields, and thus, should be managed outside the + // dBFT library. For simulation simplicity, let's assume that these fields are filled + // by every CN separately and is not verified. + pre.base.NextConsensus = crypto.Uint160{1, 2, 3} + pre.base.Version = 0 + + pre.base.PrevHash = prevHash + pre.base.ConsensusData = nonce + + // Canary default value. + pre.data = 0xff + + if len(txHashes) != 0 { + mt := merkle.NewMerkleTree(txHashes...) + pre.base.MerkleRoot = mt.Root().Hash + } + return pre +} + +func (pre *preBlock) Data() []byte { + var res = make([]byte, 4) + binary.BigEndian.PutUint32(res, pre.data) + return res +} + +func (pre *preBlock) SetData(_ dbft.PrivateKey) error { + // Just an artificial rule for data construction, it can be anything, and in Neo X + // it will be decrypted transactions fragments. + pre.data = pre.base.Index + return nil +} + +func (pre *preBlock) Verify(_ dbft.PublicKey, data []byte) error { + if len(data) != 4 { + return errors.New("invalid data len") + } + if binary.BigEndian.Uint32(data) != pre.base.Index { // Just an artificial verification rule, and for NeoX it should be decrypted transactions fragments verification. + return errors.New("invalid data") + } + return nil +} + +func (pre *preBlock) Transactions() []dbft.Transaction[crypto.Uint256] { + return pre.initialTransactions +} + +func (pre *preBlock) SetTransactions(txs []dbft.Transaction[crypto.Uint256]) { + pre.initialTransactions = txs +} diff --git a/internal/consensus/amev_preCommit.go b/internal/consensus/amev_preCommit.go new file mode 100644 index 00000000..956539f8 --- /dev/null +++ b/internal/consensus/amev_preCommit.go @@ -0,0 +1,45 @@ +package consensus + +import ( + "encoding/binary" + "encoding/gob" + + "github.com/nspcc-dev/dbft" +) + +type ( + // preCommit implements dbft.PreCommit. + preCommit struct { + magic uint32 // some magic data CN have to exchange to properly construct final amevBlock. + } + // preCommitAux is an auxiliary structure for preCommit encoding. + preCommitAux struct { + Magic uint32 + } +) + +var _ dbft.PreCommit = (*preCommit)(nil) + +// EncodeBinary implements Serializable interface. +func (c preCommit) EncodeBinary(w *gob.Encoder) error { + return w.Encode(preCommitAux{ + Magic: c.magic, + }) +} + +// DecodeBinary implements Serializable interface. +func (c *preCommit) DecodeBinary(r *gob.Decoder) error { + aux := new(preCommitAux) + if err := r.Decode(aux); err != nil { + return err + } + c.magic = aux.Magic + return nil +} + +// Data implements PreCommit interface. +func (c preCommit) Data() []byte { + res := make([]byte, 4) + binary.BigEndian.PutUint32(res, c.magic) + return res +} diff --git a/internal/consensus/compact.go b/internal/consensus/compact.go index 35a98eca..533a89fb 100644 --- a/internal/consensus/compact.go +++ b/internal/consensus/compact.go @@ -11,6 +11,12 @@ type ( Timestamp uint32 } + preCommitCompact struct { + ViewNumber byte + ValidatorIndex uint16 + Data []byte + } + commitCompact struct { ViewNumber byte ValidatorIndex uint16 @@ -32,6 +38,16 @@ func (p *changeViewCompact) DecodeBinary(r *gob.Decoder) error { return r.Decode(p) } +// EncodeBinary implements Serializable interface. +func (p preCommitCompact) EncodeBinary(w *gob.Encoder) error { + return w.Encode(p) +} + +// DecodeBinary implements Serializable interface. +func (p *preCommitCompact) DecodeBinary(r *gob.Decoder) error { + return r.Decode(p) +} + // EncodeBinary implements Serializable interface. func (p commitCompact) EncodeBinary(w *gob.Encoder) error { return w.Encode(p) diff --git a/internal/consensus/consensus_message.go b/internal/consensus/consensus_message.go index cd51b967..ca0412c5 100644 --- a/internal/consensus/consensus_message.go +++ b/internal/consensus/consensus_message.go @@ -88,6 +88,7 @@ func (m message) GetPrepareResponse() dbft.PrepareResponse[crypto.Uint256] { return m.payload.(dbft.PrepareResponse[crypto.Uint256]) } func (m message) GetCommit() dbft.Commit { return m.payload.(dbft.Commit) } +func (m message) GetPreCommit() dbft.PreCommit { return m.payload.(dbft.PreCommit) } func (m message) GetRecoveryRequest() dbft.RecoveryRequest { return m.payload.(dbft.RecoveryRequest) } func (m message) GetRecoveryMessage() dbft.RecoveryMessage[crypto.Uint256] { return m.payload.(dbft.RecoveryMessage[crypto.Uint256]) diff --git a/internal/consensus/constructors.go b/internal/consensus/constructors.go index 44326b02..096fa37d 100644 --- a/internal/consensus/constructors.go +++ b/internal/consensus/constructors.go @@ -1,6 +1,8 @@ package consensus import ( + "encoding/binary" + "github.com/nspcc-dev/dbft" "github.com/nspcc-dev/dbft/internal/crypto" ) @@ -49,6 +51,20 @@ func NewCommit(signature []byte) dbft.Commit { return c } +// NewPreCommit returns minimal dbft.PreCommit implementation. +func NewPreCommit(data []byte) dbft.PreCommit { + c := new(preCommit) + c.magic = binary.BigEndian.Uint32(data) + return c +} + +// NewAMEVCommit returns minimal dbft.Commit implementation for anti-MEV extension. +func NewAMEVCommit(data []byte) dbft.Commit { + c := new(amevCommit) + copy(c.data[:], data) + return c +} + // NewRecoveryRequest returns minimal RecoveryRequest implementation. func NewRecoveryRequest(ts uint64) dbft.RecoveryRequest { return &recoveryRequest{ diff --git a/internal/consensus/recovery_message.go b/internal/consensus/recovery_message.go index 686b7cee..ae9f8a46 100644 --- a/internal/consensus/recovery_message.go +++ b/internal/consensus/recovery_message.go @@ -1,6 +1,7 @@ package consensus import ( + "encoding/binary" "encoding/gob" "errors" @@ -12,6 +13,7 @@ type ( recoveryMessage struct { preparationHash *crypto.Uint256 preparationPayloads []preparationCompact + preCommitPayloads []preCommitCompact commitPayloads []commitCompact changeViewPayloads []changeViewCompact prepareRequest dbft.PrepareRequest[crypto.Uint256] @@ -19,6 +21,7 @@ type ( // recoveryMessageAux is an auxiliary structure for recoveryMessage encoding. recoveryMessageAux struct { PreparationPayloads []preparationCompact + PreCommitPayloads []preCommitCompact CommitPayloads []commitCompact ChangeViewPayloads []changeViewCompact } @@ -48,6 +51,13 @@ func (m *recoveryMessage) AddPayload(p dbft.ConsensusPayload[crypto.Uint256]) { OriginalViewNumber: p.ViewNumber(), Timestamp: 0, }) + case dbft.PreCommitType: + pcc := preCommitCompact{ + ViewNumber: p.ViewNumber(), + ValidatorIndex: p.ValidatorIndex(), + Data: p.GetPreCommit().Data(), + } + m.preCommitPayloads = append(m.preCommitPayloads, pcc) case dbft.CommitType: cc := commitCompact{ ViewNumber: p.ViewNumber(), @@ -119,6 +129,18 @@ func (m *recoveryMessage) GetChangeViews(p dbft.ConsensusPayload[crypto.Uint256] return payloads } +// GetPreCommits implements RecoveryMessage interface. +func (m *recoveryMessage) GetPreCommits(p dbft.ConsensusPayload[crypto.Uint256], _ []dbft.PublicKey) []dbft.ConsensusPayload[crypto.Uint256] { + payloads := make([]dbft.ConsensusPayload[crypto.Uint256], len(m.preCommitPayloads)) + + for i, c := range m.preCommitPayloads { + payloads[i] = fromPayload(dbft.PreCommitType, p, &preCommit{magic: binary.BigEndian.Uint32(c.Data)}) + payloads[i].SetValidatorIndex(c.ValidatorIndex) + } + + return payloads +} + // GetCommits implements RecoveryMessage interface. func (m *recoveryMessage) GetCommits(p dbft.ConsensusPayload[crypto.Uint256], _ []dbft.PublicKey) []dbft.ConsensusPayload[crypto.Uint256] { payloads := make([]dbft.ConsensusPayload[crypto.Uint256], len(m.commitPayloads)) diff --git a/internal/consensus/transaction.go b/internal/consensus/transaction.go new file mode 100644 index 00000000..cc4f3b90 --- /dev/null +++ b/internal/consensus/transaction.go @@ -0,0 +1,41 @@ +package consensus + +import ( + "encoding/binary" + "errors" + + "github.com/nspcc-dev/dbft" + "github.com/nspcc-dev/dbft/internal/crypto" +) + +// ============================= +// Small transaction. +// ============================= + +type Tx64 uint64 + +var _ dbft.Transaction[crypto.Uint256] = (*Tx64)(nil) + +func (t *Tx64) Hash() (h crypto.Uint256) { + binary.LittleEndian.PutUint64(h[:], uint64(*t)) + return +} + +// MarshalBinary implements encoding.BinaryMarshaler interface. +func (t *Tx64) MarshalBinary() ([]byte, error) { + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(*t)) + + return b, nil +} + +// UnmarshalBinary implements encoding.BinaryUnarshaler interface. +func (t *Tx64) UnmarshalBinary(data []byte) error { + if len(data) != 8 { + return errors.New("length must equal 8 bytes") + } + + *t = Tx64(binary.LittleEndian.Uint64(data)) + + return nil +} diff --git a/internal/simulation/main.go b/internal/simulation/main.go index 5c42aa09..f1595c90 100644 --- a/internal/simulation/main.go +++ b/internal/simulation/main.go @@ -3,8 +3,6 @@ package main import ( "context" "crypto/rand" - "encoding/binary" - "errors" "flag" "fmt" "net/http" @@ -202,43 +200,11 @@ func (n *simNode) VerifyPayload(p dbft.ConsensusPayload[crypto.Uint256]) error { func (n *simNode) addTx(count int) { for i := 0; i < count; i++ { - tx := tx64(uint64(i)) + tx := consensus.Tx64(uint64(i)) n.pool.Add(&tx) } } -// ============================= -// Small transaction. -// ============================= - -type tx64 uint64 - -var _ dbft.Transaction[crypto.Uint256] = (*tx64)(nil) - -func (t *tx64) Hash() (h crypto.Uint256) { - binary.LittleEndian.PutUint64(h[:], uint64(*t)) - return -} - -// MarshalBinary implements encoding.BinaryMarshaler interface. -func (t *tx64) MarshalBinary() ([]byte, error) { - b := make([]byte, 8) - binary.LittleEndian.PutUint64(b, uint64(*t)) - - return b, nil -} - -// UnmarshalBinary implements encoding.BinaryUnarshaler interface. -func (t *tx64) UnmarshalBinary(data []byte) error { - if len(data) != 8 { - return errors.New("length must equal 8 bytes") - } - - *t = tx64(binary.LittleEndian.Uint64(data)) - - return nil -} - // ============================= // Memory pool for transactions. // ============================= diff --git a/pre_block.go b/pre_block.go new file mode 100644 index 00000000..4b8a476c --- /dev/null +++ b/pre_block.go @@ -0,0 +1,24 @@ +package dbft + +// PreBlock is a generic interface for a PreBlock used by anti-MEV dBFT extension. +// It holds a "draft" of block that should be converted to a final block with the +// help of additional data held by PreCommit messages. +type PreBlock[H Hash] interface { + // Data returns PreBlock's data CNs need to exchange during PreCommit phase. + // Data represents additional information not related to a final block signature. + Data() []byte + // SetData generates and sets PreBlock's data CNs need to exchange during + // PreCommit phase. + SetData(key PrivateKey) error + // Verify checks if data related to PreCommit phase is correct. This method is + // refined on PreBlock rather than on PreCommit message since PreBlock itself is + // required for PreCommit's data verification. + Verify(key PublicKey, data []byte) error + + // Transactions returns PreBlock's transaction list. This list may be different + // comparing to the final set of Block's transactions. + Transactions() []Transaction[H] + // SetTransactions sets PreBlock's transaction list. This list may be different + // comparing to the final set of Block's transactions. + SetTransactions([]Transaction[H]) +} diff --git a/pre_commit.go b/pre_commit.go new file mode 100644 index 00000000..24d0a507 --- /dev/null +++ b/pre_commit.go @@ -0,0 +1,10 @@ +package dbft + +// PreCommit is an interface for dBFT PreCommit message. This message is used right +// before the Commit phase to exchange additional information required for the final +// block construction in anti-MEV dBFT extension. +type PreCommit interface { + // Data returns PreCommit's data that should be used for the final + // Block construction in anti-MEV dBFT extension. + Data() []byte +} diff --git a/recovery_message.go b/recovery_message.go index 303ecad0..02e62014 100644 --- a/recovery_message.go +++ b/recovery_message.go @@ -10,6 +10,10 @@ type RecoveryMessage[H Hash] interface { GetPrepareResponses(p ConsensusPayload[H], validators []PublicKey) []ConsensusPayload[H] // GetChangeViews returns a slice of ChangeView in any order. GetChangeViews(p ConsensusPayload[H], validators []PublicKey) []ConsensusPayload[H] + // GetPreCommits returns a slice of PreCommit messages in any order. + // If implemented on networks with no AntiMEV extension it can just + // always return nil. + GetPreCommits(p ConsensusPayload[H], validators []PublicKey) []ConsensusPayload[H] // GetCommits returns a slice of Commit in any order. GetCommits(p ConsensusPayload[H], validators []PublicKey) []ConsensusPayload[H] diff --git a/send.go b/send.go index 9ab7ac76..5b25e4e1 100644 --- a/send.go +++ b/send.go @@ -99,6 +99,25 @@ func (d *DBFT[H]) sendPrepareResponse() { d.broadcast(msg) } +func (c *Context[H]) makePreCommit() ConsensusPayload[H] { + if msg := c.PreCommitPayloads[c.MyIndex]; msg != nil { + return msg + } + + if preB := c.MakePreHeader(); preB != nil { + var preData []byte + if err := preB.SetData(c.Priv); err == nil { + preData = preB.Data() + } + + preCommit := c.Config.NewPreCommit(preData) + + return c.Config.NewConsensusPayload(c, PreCommitType, preCommit) + } + + return nil +} + func (c *Context[H]) makeCommit() ConsensusPayload[H] { if msg := c.CommitPayloads[c.MyIndex]; msg != nil { return msg @@ -118,6 +137,13 @@ func (c *Context[H]) makeCommit() ConsensusPayload[H] { return nil } +func (d *DBFT[H]) sendPreCommit() { + msg := d.makePreCommit() + d.PreCommitPayloads[d.MyIndex] = msg + d.Logger.Info("sending PreCommit", zap.Uint32("height", d.BlockIndex), zap.Uint("view", uint(d.ViewNumber))) + d.broadcast(msg) +} + func (d *DBFT[H]) sendCommit() { msg := d.makeCommit() d.CommitPayloads[d.MyIndex] = msg @@ -154,6 +180,14 @@ func (c *Context[H]) makeRecoveryMessage() ConsensusPayload[H] { } } + if c.PreCommitSent() { + for _, p := range c.PreCommitPayloads { + if p != nil { + recovery.AddPayload(p) + } + } + } + if c.CommitSent() { for _, p := range c.CommitPayloads { if p != nil {