From dd1530e91bfcbd6f3e66f347a7b044ab70c28383 Mon Sep 17 00:00:00 2001 From: Axay Sagathiya Date: Fri, 20 Dec 2024 20:21:16 +0530 Subject: [PATCH] Feat(dot/parachain): Add `CompactStatement` type (#4424) Introduced the `CompactStatement` and `compactStatementInner` types. We will only deal with `CompactStatement` in other codebases. `compactStatementInner` is only used for encoding/decoding of `CompactStatement`, which is why it's not exported. Implemented `MarshalSCALE` and `UnmarshalSCALE` methods to have custom encoding/decoding logic for `CompactStatement`. Also, the sign logic of statementVDT has been updated to encode the data properly before signing. --- dot/parachain/types/statement.go | 171 +++++++++++++++++++++++--- dot/parachain/types/statement_test.go | 62 ++++++++++ 2 files changed, 213 insertions(+), 20 deletions(-) diff --git a/dot/parachain/types/statement.go b/dot/parachain/types/statement.go index e73d1a4549..636096d7ba 100644 --- a/dot/parachain/types/statement.go +++ b/dot/parachain/types/statement.go @@ -4,7 +4,9 @@ package parachaintypes import ( + "bytes" "fmt" + "io" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/crypto/sr25519" @@ -12,6 +14,8 @@ import ( "github.com/ChainSafe/gossamer/pkg/scale" ) +var backingStatementMagic = [4]byte{'B', 'K', 'N', 'G'} + // Statement is a result of candidate validation. It could be either `Valid` or `Seconded`. type StatementVDTValues interface { Valid | Seconded @@ -80,10 +84,42 @@ type Seconded CommittedCandidateReceipt // Valid represents a statement that a validator has deemed a candidate valid. type Valid CandidateHash -// statementVDTAndSigningContext is just a wrapper struct to hold both the statement and the signing context. -type statementVDTAndSigningContext struct { - Statement StatementVDT - Context SigningContext +// encodeSignData encodes the statement and signing context into a byte slice. +func encodeSignData(statement StatementVDT, signingContext SigningContext) ([]byte, error) { + buffer := bytes.NewBuffer(nil) + encoder := scale.NewEncoder(buffer) + + compact, err := statement.CompactStatement() + if err != nil { + return nil, fmt.Errorf("getting compact statement: %w", err) + } + + err = encoder.Encode(compact) + if err != nil { + return nil, fmt.Errorf("encoding compact statement: %w", err) + } + + err = encoder.Encode(signingContext) + if err != nil { + return nil, fmt.Errorf("encoding signing context: %w", err) + } + + return buffer.Bytes(), nil +} + +// CompactStatement returns a compact representation of the statement. +func (s StatementVDT) CompactStatement() (any, error) { + switch s := s.inner.(type) { + case Valid: + return CompactStatement[Valid]{Value: s}, nil + case Seconded: + hash, err := GetCandidateHash(CommittedCandidateReceipt(s)) + if err != nil { + return nil, fmt.Errorf("getting candidate hash: %w", err) + } + return CompactStatement[SecondedCandidateHash]{Value: SecondedCandidateHash(hash)}, nil + } + return nil, fmt.Errorf("unsupported type") } func (s *StatementVDT) Sign( @@ -91,14 +127,9 @@ func (s *StatementVDT) Sign( signingContext SigningContext, key ValidatorID, ) (*ValidatorSignature, error) { - statementAndSigningCtx := statementVDTAndSigningContext{ - Statement: *s, - Context: signingContext, - } - - encodedData, err := scale.Marshal(statementAndSigningCtx) + data, err := encodeSignData(*s, signingContext) if err != nil { - return nil, fmt.Errorf("marshalling statement and signing-context: %w", err) + return nil, fmt.Errorf("encoding data to sign: %w", err) } validatorPublicKey, err := sr25519.NewPublicKey(key[:]) @@ -106,7 +137,7 @@ func (s *StatementVDT) Sign( return nil, fmt.Errorf("getting public key: %w", err) } - signatureBytes, err := keystore.GetKeypair(validatorPublicKey).Sign(encodedData) + signatureBytes, err := keystore.GetKeypair(validatorPublicKey).Sign(data) if err != nil { return nil, fmt.Errorf("signing data: %w", err) } @@ -123,14 +154,9 @@ func (s *StatementVDT) VerifySignature( signingContext SigningContext, validatorSignature ValidatorSignature, ) (bool, error) { - statementAndSigningCtx := statementVDTAndSigningContext{ - Statement: *s, - Context: signingContext, - } - - encodedMsg, err := scale.Marshal(statementAndSigningCtx) + data, err := encodeSignData(*s, signingContext) if err != nil { - return false, fmt.Errorf("marshalling statement and signing-context: %w", err) + return false, fmt.Errorf("encoding signed data: %w", err) } publicKey, err := sr25519.NewPublicKey(validator[:]) @@ -138,7 +164,7 @@ func (s *StatementVDT) VerifySignature( return false, fmt.Errorf("getting public key: %w", err) } - return publicKey.Verify(encodedMsg, validatorSignature[:]) + return publicKey.Verify(data, validatorSignature[:]) } // UncheckedSignedFullStatement is a Variant of `SignedFullStatement` where the signature has not yet been verified. @@ -177,3 +203,108 @@ type SignedFullStatementWithPVD struct { // otherwise, it should be nil. PersistedValidationData *PersistedValidationData } + +type SecondedCandidateHash CandidateHash + +type CompactStatementValues interface { + Valid | SecondedCandidateHash +} + +// compactStatementInner is a helper struct that is used to encode/decode CompactStatement. +type compactStatementInner struct { + inner any +} + +func setCompactStatement[Value CompactStatementValues](mvdt *compactStatementInner, value Value) { + mvdt.inner = value +} + +func (mvdt *compactStatementInner) SetValue(value any) (err error) { + switch value := value.(type) { + case Valid: + setCompactStatement(mvdt, value) + return + case SecondedCandidateHash: + setCompactStatement(mvdt, value) + return + default: + return fmt.Errorf("unsupported type") + } +} + +func (mvdt compactStatementInner) IndexValue() (index uint, value any, err error) { + switch mvdt.inner.(type) { + case Valid: + return 2, mvdt.inner, nil + case SecondedCandidateHash: + return 1, mvdt.inner, nil + } + return 0, nil, scale.ErrUnsupportedVaryingDataTypeValue +} + +func (mvdt compactStatementInner) Value() (value any, err error) { + _, value, err = mvdt.IndexValue() + return +} + +func (mvdt compactStatementInner) ValueAt(index uint) (value any, err error) { + switch index { + case 2: + return Valid{}, nil + case 1: + return SecondedCandidateHash{}, nil + } + return nil, scale.ErrUnknownVaryingDataTypeValue +} + +// CompactStatement is a compact representation of a statement that can be made about parachain candidates. +// this is the actual value that is signed. +type CompactStatement[T CompactStatementValues] struct { + Value T +} + +func (c CompactStatement[CompactStatementValues]) MarshalSCALE() ([]byte, error) { + inner := compactStatementInner{} + err := inner.SetValue(c.Value) + if err != nil { + return nil, fmt.Errorf("setting value: %w", err) + } + + buffer := bytes.NewBuffer(backingStatementMagic[:]) + encoder := scale.NewEncoder(buffer) + + err = encoder.Encode(inner) + if err != nil { + return nil, err + } + + return buffer.Bytes(), nil +} + +func (c *CompactStatement[CompactStatementValues]) UnmarshalSCALE(reader io.Reader) error { + decoder := scale.NewDecoder(reader) + + var magicBytes [4]byte + err := decoder.Decode(&magicBytes) + if err != nil { + return err + } + + if !bytes.Equal(magicBytes[:], backingStatementMagic[:]) { + return fmt.Errorf("invalid magic bytes") + } + + var inner compactStatementInner + err = decoder.Decode(&inner) + if err != nil { + return fmt.Errorf("decoding compactStatementInner: %w", err) + } + + value, err := inner.Value() + if err != nil { + return fmt.Errorf("getting value: %w", err) + } + + c.Value = value.(CompactStatementValues) + return nil +} diff --git a/dot/parachain/types/statement_test.go b/dot/parachain/types/statement_test.go index 4e8a80024f..68821d0ee8 100644 --- a/dot/parachain/types/statement_test.go +++ b/dot/parachain/types/statement_test.go @@ -169,3 +169,65 @@ func TestStatementVDT_Sign(t *testing.T) { require.NoError(t, err) require.True(t, ok) } + +func TestCompactStatement(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + compactStatement any + encodingValue []byte + expectedErr error + }{ + { + name: "SecondedCandidateHash", + compactStatement: CompactStatement[SecondedCandidateHash]{ + Value: SecondedCandidateHash{Value: getDummyHash(6)}, + }, + encodingValue: []byte{66, 75, 78, 71, 1, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6}, + }, + { + name: "Valid", + compactStatement: CompactStatement[Valid]{ + Value: Valid{Value: getDummyHash(7)}, + }, + encodingValue: []byte{ + 66, 75, 78, 71, 2, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7}, + }, + } + + for _, c := range testCases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + t.Run("marshal", func(t *testing.T) { + t.Parallel() + + compactStatementBytes, err := scale.Marshal(c.compactStatement) + require.NoError(t, err) + require.Equal(t, c.encodingValue, compactStatementBytes) + }) + + t.Run("unmarshal", func(t *testing.T) { + t.Parallel() + + switch expectedSatetement := c.compactStatement.(type) { + case CompactStatement[Valid]: + var actualStatement CompactStatement[Valid] + err := scale.Unmarshal(c.encodingValue, &actualStatement) + require.NoError(t, err) + require.EqualValues(t, expectedSatetement, actualStatement) + case CompactStatement[SecondedCandidateHash]: + var actualStatement CompactStatement[SecondedCandidateHash] + err := scale.Unmarshal(c.encodingValue, &actualStatement) + require.NoError(t, err) + require.EqualValues(t, expectedSatetement, actualStatement) + } + }) + + }) + } +}