From 4963924563a2208de0e328740fdba05fa97a3618 Mon Sep 17 00:00:00 2001 From: Faulty Tolly <@faulttolerance.net> Date: Thu, 19 Sep 2024 19:42:38 +0200 Subject: [PATCH] fix pr suggestions --- types/errors.go | 2 ++ types/validation.go | 16 +++++----- types/validation_test.go | 64 +++++++++++++++++++++------------------- 3 files changed, 45 insertions(+), 37 deletions(-) diff --git a/types/errors.go b/types/errors.go index 56a463d6c..de553f577 100644 --- a/types/errors.go +++ b/types/errors.go @@ -16,6 +16,8 @@ var ( ErrInvalidBlockHeight = errors.New("invalid block height") ErrInvalidHeaderDataHash = errors.New("header not matching block data") ErrMissingProposerPubKey = fmt.Errorf("missing proposer public key: %w", gerrc.ErrNotFound) + ErrVersionMismatch = errors.New("version mismatch") + ErrEmptyProposerAddress = errors.New("no proposer address") ) type ErrFraudHeightMismatch struct { diff --git a/types/validation.go b/types/validation.go index 1fe575d1e..af6e5e1a6 100644 --- a/types/validation.go +++ b/types/validation.go @@ -12,7 +12,8 @@ import ( "github.com/dymensionxyz/dymint/fraud" ) -var MaxDrift = 10 * time.Minute +// TimeFraudMaxDrift is the maximum allowed time drift between the block time and the local time. +var TimeFraudMaxDrift = 10 * time.Minute type ErrTimeFraud struct { drift time.Duration @@ -38,9 +39,10 @@ func NewErrTimeFraud(block *Block, currentTime time.Time) error { func (e ErrTimeFraud) Error() string { return fmt.Sprintf( - `Sequencer %X posted a block with header hash %X, at height %dwith a time drift of %s, -when the maximum allowed limit is %s. Sequencer reported block time was %s while the node local time is %s`, - e.proposerAddress, e.headerHash, e.headerHeight, e.drift, MaxDrift, e.headerTime, e.currentTime, + "Sequencer posted a block with invalid time. "+ + "Max allowed drift exceeded. "+ + "proposerAddress=%s headerHash=%s headerHeight=%d drift=%s MaxDrift=%s headerTime=%s currentTime=%s", + e.proposerAddress, e.headerHash, e.headerHeight, e.drift, TimeFraudMaxDrift, e.headerTime, e.currentTime, ) } @@ -62,7 +64,7 @@ func ValidateProposedTransition(state *State, block *Block, commit *Commit, prop // ValidateBasic performs basic validation of a block. func (b *Block) ValidateBasic() error { currentTime := time.Now().UTC() - if currentTime.Add(MaxDrift).Before(time.Unix(0, int64(b.Header.Time))) { + if currentTime.Add(TimeFraudMaxDrift).Before(time.Unix(0, int64(b.Header.Time))) { return NewErrTimeFraud(b, currentTime) } @@ -94,7 +96,7 @@ func (b *Block) ValidateWithState(state *State) error { } if b.Header.Version.App != state.Version.Consensus.App || b.Header.Version.Block != state.Version.Consensus.Block { - return errors.New("b version mismatch") + return ErrVersionMismatch } if b.Header.Height != state.NextHeight() { @@ -115,7 +117,7 @@ func (b *Block) ValidateWithState(state *State) error { // ValidateBasic performs basic validation of a header. func (h *Header) ValidateBasic() error { if len(h.ProposerAddress) == 0 { - return errors.New("no proposer address") + return ErrEmptyProposerAddress } return nil diff --git a/types/validation_test.go b/types/validation_test.go index 7b40803f2..e1d3fa4ab 100644 --- a/types/validation_test.go +++ b/types/validation_test.go @@ -48,12 +48,13 @@ func TestBlock_ValidateWithState(t *testing.T) { validBlock.Header.DataHash = [32]byte(GetDataHash(validBlock)) tests := []struct { - name string - block *Block - state *State - wantErr bool - errMsg string - isFraud bool + name string + block *Block + state *State + wantErr bool + theErr error + expectedErrType interface{} + isFraud bool }{ { name: "Valid block", @@ -79,8 +80,8 @@ func TestBlock_ValidateWithState(t *testing.T) { }, }, state: validState, + theErr: ErrVersionMismatch, wantErr: true, - errMsg: "b version mismatch", isFraud: false, }, { @@ -101,7 +102,7 @@ func TestBlock_ValidateWithState(t *testing.T) { }, state: validState, wantErr: true, - errMsg: "b version mismatch", + theErr: ErrVersionMismatch, isFraud: false, }, { @@ -117,10 +118,10 @@ func TestBlock_ValidateWithState(t *testing.T) { DataHash: [32]byte(GetDataHash(validBlock)), }, }, - state: validState, - wantErr: true, - errMsg: "height mismatch", - isFraud: true, + state: validState, + wantErr: true, + expectedErrType: &ErrFraudHeightMismatch{}, + isFraud: true, }, { name: "Invalid AppHash", @@ -135,10 +136,10 @@ func TestBlock_ValidateWithState(t *testing.T) { DataHash: [32]byte(GetDataHash(validBlock)), }, }, - state: validState, - wantErr: true, - errMsg: "AppHash mismatch", - isFraud: true, + state: validState, + expectedErrType: &ErrFraudAppHashMismatch{}, + wantErr: true, + isFraud: true, }, { name: "Invalid LastResultsHash", @@ -153,10 +154,10 @@ func TestBlock_ValidateWithState(t *testing.T) { DataHash: [32]byte(GetDataHash(validBlock)), }, }, - state: validState, - wantErr: true, - errMsg: "LastResultsHash mismatch", - isFraud: true, + state: validState, + wantErr: true, + expectedErrType: &ErrLastResultsHashMismatch{}, + isFraud: true, }, { name: "Future block time", @@ -164,16 +165,16 @@ func TestBlock_ValidateWithState(t *testing.T) { Header: Header{ Version: validBlock.Header.Version, Height: 10, - Time: uint64(currentTime.Add(2 * MaxDrift).UnixNano()), + Time: uint64(currentTime.Add(2 * TimeFraudMaxDrift).UnixNano()), AppHash: [32]byte{1, 2, 3}, LastResultsHash: [32]byte{4, 5, 6}, ProposerAddress: []byte("proposer"), }, }, - state: validState, - wantErr: true, - errMsg: "Sequencer", - isFraud: true, + state: validState, + wantErr: true, + expectedErrType: &ErrTimeFraud{}, + isFraud: true, }, { name: "Invalid proposer address", @@ -187,10 +188,10 @@ func TestBlock_ValidateWithState(t *testing.T) { ProposerAddress: []byte{}, }, }, - state: validState, - wantErr: true, - errMsg: "no proposer address", - isFraud: false, + state: validState, + wantErr: true, + expectedErrType: ErrEmptyProposerAddress, + isFraud: false, }, } @@ -199,9 +200,12 @@ func TestBlock_ValidateWithState(t *testing.T) { err := tt.block.ValidateWithState(tt.state) if tt.wantErr { assert.Error(t, err) - assert.Contains(t, err.Error(), tt.errMsg) if tt.isFraud { require.True(t, errors.Is(err, fraud.ErrFraud)) + if tt.expectedErrType != nil { + assert.True(t, errors.As(err, &tt.expectedErrType), + "expected error of type %T, got %T", tt.expectedErrType, err) + } } else { require.False(t, errors.Is(err, fraud.ErrFraud)) }