Skip to content

Commit

Permalink
dbft: make future message cache work properly, add a test
Browse files Browse the repository at this point in the history
start() was designed to be called at every view change, but looks like it
_never_ worked this way. Which means two things:
 * on every view change Primary doesn't send PrepareRequest during
   initialization (which is mostly OK, OnTimeout() will be triggered
   immediately with 0 timeout)
 * our future message caching system has never really worked since start() is
   the only place where messages can be picked up from it

Just drop start(), make caches work and add a test for them.

Signed-off-by: Roman Khimov <[email protected]>
  • Loading branch information
roman-khimov committed Jul 31, 2024
1 parent 9f0ad22 commit d5baa08
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 51 deletions.
51 changes: 22 additions & 29 deletions dbft.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ func (d *DBFT[H]) addTransaction(tx Transaction[H]) {
func (d *DBFT[H]) Start(ts uint64) {
d.cache = newCache[H]()
d.initializeConsensus(0, ts)
d.start()
if d.IsPrimary() {
d.sendPrepareRequest()
}
}

// Reset reinitializes dBFT instance with the given timestamp of the previous
Expand Down Expand Up @@ -110,6 +112,25 @@ func (d *DBFT[H]) initializeConsensus(view byte, ts uint64) {
zap.Int("index", d.MyIndex),
zap.String("role", role))

// Process cached messages if any.
if msgs := d.cache.getHeight(d.BlockIndex); msgs != nil {
for _, m := range msgs.prepare {
d.OnReceive(m)
}

for _, m := range msgs.chViews {
d.OnReceive(m)

Check warning on line 122 in dbft.go

View check run for this annotation

Codecov / codecov/patch

dbft.go#L122

Added line #L122 was not covered by tests
}

for _, m := range msgs.preCommit {
d.OnReceive(m)
}

for _, m := range msgs.commit {
d.OnReceive(m)
}
}

if d.Context.WatchOnly() {
return
}
Expand Down Expand Up @@ -270,34 +291,6 @@ func (d *DBFT[H]) OnReceive(msg ConsensusPayload[H]) {
}
}

// start performs initial operations and returns messages to be sent.
// It must be called after every height or view increment.
func (d *DBFT[H]) start() {
if !d.IsPrimary() {
if msgs := d.cache.getHeight(d.BlockIndex); msgs != nil {
for _, m := range msgs.prepare {
d.OnReceive(m)
}

for _, m := range msgs.chViews {
d.OnReceive(m)
}

for _, m := range msgs.preCommit {
d.OnReceive(m)
}

for _, m := range msgs.commit {
d.OnReceive(m)
}
}

return
}

d.sendPrepareRequest()
}

func (d *DBFT[H]) onPrepareRequest(msg ConsensusPayload[H]) {
// ignore prepareRequest if we had already received it or
// are in process of changing view
Expand Down
170 changes: 148 additions & 22 deletions dbft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) {
require.Nil(t, s.tryRecv())

t.Run("receive response from primary", func(t *testing.T) {
resp := s.getPrepareResponse(5, p.Hash())
resp := s.getPrepareResponse(5, p.Hash(), 0)

service.OnReceive(resp)
require.Nil(t, s.tryRecv())
Expand Down Expand Up @@ -263,8 +263,8 @@ func TestDBFT_CommitOnTransaction(t *testing.T) {
tx := testTx(42)
req := s.getPrepareRequest(2, tx.Hash())
srv.OnReceive(req)
srv.OnReceive(s.getPrepareResponse(1, req.Hash()))
srv.OnReceive(s.getPrepareResponse(3, req.Hash()))
srv.OnReceive(s.getPrepareResponse(1, req.Hash(), 0))
srv.OnReceive(s.getPrepareResponse(3, req.Hash(), 0))
require.Nil(t, srv.Header()) // missing transaction.

// Test state for forming header.
Expand All @@ -279,13 +279,13 @@ func TestDBFT_CommitOnTransaction(t *testing.T) {
srv1, _ := dbft.New[crypto.Uint256](s1.getOptions()...)
srv1.Start(0)
srv1.OnReceive(req)
srv1.OnReceive(s1.getPrepareResponse(1, req.Hash()))
srv1.OnReceive(s1.getPrepareResponse(3, req.Hash()))
srv1.OnReceive(s1.getPrepareResponse(1, req.Hash(), 0))
srv1.OnReceive(s1.getPrepareResponse(3, req.Hash(), 0))
require.NotNil(t, srv1.Header())

for _, i := range []uint16{1, 2, 3} {
require.NoError(t, srv1.Header().Sign(s1.privs[i]))
c := s1.getCommit(i, srv1.Header().Signature())
c := s1.getCommit(i, srv1.Header().Signature(), 0)
srv.OnReceive(c)
}

Expand All @@ -304,11 +304,11 @@ func TestDBFT_OnReceiveCommit(t *testing.T) {
req := s.tryRecv()
require.NotNil(t, req)

resp := s.getPrepareResponse(1, req.Hash())
resp := s.getPrepareResponse(1, req.Hash(), 0)
service.OnReceive(resp)
require.Nil(t, s.tryRecv())

resp = s.getPrepareResponse(0, req.Hash())
resp = s.getPrepareResponse(0, req.Hash(), 0)
service.OnReceive(resp)

cm := s.tryRecv()
Expand Down Expand Up @@ -336,14 +336,14 @@ func TestDBFT_OnReceiveCommit(t *testing.T) {
t.Run("process block after enough commits", func(t *testing.T) {
s0 := s.copyWithIndex(0)
require.NoError(t, service.Header().Sign(s0.privs[0]))
c0 := s0.getCommit(0, service.Header().Signature())
c0 := s0.getCommit(0, service.Header().Signature(), 0)
service.OnReceive(c0)
require.Nil(t, s.tryRecv())
require.Nil(t, s.nextBlock())

s1 := s.copyWithIndex(1)
require.NoError(t, service.Header().Sign(s1.privs[1]))
c1 := s1.getCommit(1, service.Header().Signature())
c1 := s1.getCommit(1, service.Header().Signature(), 0)
service.OnReceive(c1)
require.Nil(t, s.tryRecv())

Expand All @@ -364,11 +364,11 @@ func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) {
req := s.tryRecv()
require.NotNil(t, req)

resp := s.getPrepareResponse(1, req.Hash())
resp := s.getPrepareResponse(1, req.Hash(), 0)
service.OnReceive(resp)
require.Nil(t, s.tryRecv())

resp = s.getPrepareResponse(0, req.Hash())
resp = s.getPrepareResponse(0, req.Hash(), 0)
service.OnReceive(resp)
cm := s.tryRecv()
require.NotNil(t, cm)
Expand Down Expand Up @@ -773,11 +773,11 @@ func TestDBFT_OnReceiveCommitAMEV(t *testing.T) {
req := s.tryRecv()
require.NotNil(t, req)

resp := s.getPrepareResponse(1, req.Hash())
resp := s.getPrepareResponse(1, req.Hash(), 0)
service.OnReceive(resp)
require.Nil(t, s.tryRecv())

resp = s.getPrepareResponse(0, req.Hash())
resp = s.getPrepareResponse(0, req.Hash(), 0)
service.OnReceive(resp)

cm := s.tryRecv()
Expand All @@ -794,15 +794,15 @@ func TestDBFT_OnReceiveCommitAMEV(t *testing.T) {
t.Run("send commit after enough preCommits", func(t *testing.T) {
s0 := s.copyWithIndex(0)
require.NoError(t, service.PreHeader().SetData(s0.privs[0]))
preC0 := s0.getPreCommit(0, service.PreHeader().Data())
preC0 := s0.getPreCommit(0, service.PreHeader().Data(), 0)
service.OnReceive(preC0)
require.Nil(t, s.tryRecv())
require.Nil(t, s.nextPreBlock())
require.Nil(t, s.nextBlock())

s1 := s.copyWithIndex(1)
require.NoError(t, service.PreHeader().SetData(s1.privs[1]))
preC1 := s1.getPreCommit(1, service.PreHeader().Data())
preC1 := s1.getPreCommit(1, service.PreHeader().Data(), 0)
service.OnReceive(preC1)

b := s.nextPreBlock()
Expand Down Expand Up @@ -842,6 +842,132 @@ func TestDBFT_OnReceiveCommitAMEV(t *testing.T) {
})
}

func TestDBFT_CachedMessages(t *testing.T) {
for _, amev := range []bool{false, true} {
t.Run(fmt.Sprintf("AMEV %t", amev), func(t *testing.T) {
s2 := newTestState(2, 4)
s2.currHeight = 1
s1 := newTestState(1, 4)
s1.currHeight = 1

opts := s2.getOptions()
if amev {
opts = s2.getAMEVOptions()
}
service2, _ := dbft.New[crypto.Uint256](opts...)
service2.Start(0)

opts = s1.getOptions()
if amev {
opts = s1.getAMEVOptions()
}
service1, _ := dbft.New[crypto.Uint256](opts...)
service1.Start(0)

req := s2.tryRecv()
require.NotNil(t, req) // Primary sends a request.
require.Equal(t, dbft.PrepareRequestType, req.Type())

require.Nil(t, s1.tryRecv()) // Backup waits.

cv0 := s1.getChangeView(0, 1)
cv3 := s1.getChangeView(3, 1)
service1.OnReceive(cv0)
service1.OnReceive(cv3)
service1.OnTimeout(s1.currHeight+1, 0)

cv := s1.tryRecv()
require.NotNil(t, cv)
require.Equal(t, dbft.ChangeViewType, cv.Type())

service1.OnTimeout(s1.currHeight+1, 1)
req = s1.tryRecv()
require.NotNil(t, req)
require.Equal(t, dbft.PrepareRequestType, req.Type())

resp := s1.getPrepareResponse(3, req.Hash(), 1)
service1.OnReceive(resp)
require.Nil(t, s1.tryRecv())
service2.OnReceive(resp) // From the future.
require.Nil(t, s2.tryRecv())

resp = s1.getPrepareResponse(0, req.Hash(), 1)
service2.OnReceive(resp) // From the future.
require.Nil(t, s2.tryRecv())

service1.OnReceive(resp)
cm := s1.tryRecv()
require.NotNil(t, cm)

service2.OnReceive(cm)
require.Nil(t, s2.tryRecv())

if amev {
require.Equal(t, dbft.PreCommitType, cm.Type())
require.EqualValues(t, s1.currHeight+1, cm.Height())
require.EqualValues(t, 1, cm.ViewNumber())
require.EqualValues(t, s1.myIndex, cm.ValidatorIndex())
require.NotNil(t, cm.Payload())
pub := s1.pubs[s1.myIndex]
require.NoError(t, service1.PreHeader().Verify(pub, cm.GetPreCommit().Data()))
} else {
require.Equal(t, dbft.CommitType, cm.Type())
require.EqualValues(t, s1.currHeight+1, cm.Height())
require.EqualValues(t, 1, cm.ViewNumber())
require.EqualValues(t, s1.myIndex, cm.ValidatorIndex())
require.NotNil(t, cm.Payload())
}

service2.OnReceive(cv0)
service2.OnReceive(cv3)
service2.OnTimeout(s2.currHeight+1, 0)
cv = s2.tryRecv()
require.NotNil(t, cv)
require.Equal(t, dbft.ChangeViewType, cv.Type())

require.Equal(t, 1, int(service2.ViewNumber))

// s2 has some PrepareResponses, but doesn't have a request.
service2.OnReceive(req)

resp = s2.tryRecv()
require.NotNil(t, resp)
require.Equal(t, dbft.PrepareResponseType, resp.Type())

cm = s2.tryRecv()
require.NotNil(t, cm)

if amev {
require.Equal(t, dbft.PreCommitType, cm.Type())
require.EqualValues(t, s2.currHeight+1, cm.Height())
require.EqualValues(t, 1, cm.ViewNumber())
require.EqualValues(t, s2.myIndex, cm.ValidatorIndex())
require.NotNil(t, cm.Payload())
pub := s1.pubs[s1.myIndex]
require.NoError(t, service1.PreHeader().Verify(pub, cm.GetPreCommit().Data()))

service2.OnReceive(s2.getPreCommit(0, service2.PreHeader().Data(), 1))
cm = s2.tryRecv()
require.NotNil(t, cm)
require.Equal(t, dbft.CommitType, cm.Type())
} else {
require.Equal(t, dbft.CommitType, cm.Type())
require.EqualValues(t, s2.currHeight+1, cm.Height())
require.EqualValues(t, 1, cm.ViewNumber())
require.EqualValues(t, s2.myIndex, cm.ValidatorIndex())
require.NotNil(t, cm.Payload())

require.NoError(t, service2.Header().Sign(s2.privs[0]))
service2.OnReceive(s2.getCommit(0, service2.Header().Signature(), 1))
require.Nil(t, s2.tryRecv())
b := s2.nextBlock()
require.NotNil(t, b)
require.Equal(t, s2.currHeight+1, b.Index())
}
})
}
}

func (s testState) getChangeView(from uint16, view byte) Payload {
cv := consensus.NewChangeView(view, 0, 0)

Expand All @@ -854,9 +980,9 @@ func (s testState) getRecoveryRequest(from uint16) Payload {
return p
}

func (s testState) getCommit(from uint16, sign []byte) Payload {
func (s testState) getCommit(from uint16, sign []byte, view byte) Payload {
c := consensus.NewCommit(sign)
p := consensus.NewConsensusPayload(dbft.CommitType, s.currHeight+1, from, 0, c)
p := consensus.NewConsensusPayload(dbft.CommitType, s.currHeight+1, from, view, c)
return p
}

Expand All @@ -866,16 +992,16 @@ func (s testState) getAMEVCommit(from uint16, sign []byte) Payload {
return p
}

func (s testState) getPreCommit(from uint16, data []byte) Payload {
func (s testState) getPreCommit(from uint16, data []byte, view byte) Payload {
c := consensus.NewPreCommit(data)
p := consensus.NewConsensusPayload(dbft.PreCommitType, s.currHeight+1, from, 0, c)
p := consensus.NewConsensusPayload(dbft.PreCommitType, s.currHeight+1, from, view, c)
return p
}

func (s testState) getPrepareResponse(from uint16, phash crypto.Uint256) Payload {
func (s testState) getPrepareResponse(from uint16, phash crypto.Uint256, view byte) Payload {
resp := consensus.NewPrepareResponse(phash)

p := consensus.NewConsensusPayload(dbft.PrepareResponseType, s.currHeight+1, from, 0, resp)
p := consensus.NewConsensusPayload(dbft.PrepareResponseType, s.currHeight+1, from, view, resp)
return p
}

Expand Down

0 comments on commit d5baa08

Please sign in to comment.