From 3c7071c891dc85133c32b10646138ea9b0e36d3d Mon Sep 17 00:00:00 2001 From: Matthias <5011972+fasmat@users.noreply.github.com> Date: Wed, 30 Oct 2024 17:43:39 +0000 Subject: [PATCH] Continue implementing invalid post proof --- activation/handler_v2_test.go | 2 +- activation/wire/interface.go | 2 +- .../wire/malfeasance_double_marry_test.go | 10 +- activation/wire/malfeasance_invalid_post.go | 93 +++++++++++++------ .../wire/malfeasance_invalid_post_scale.go | 58 ++++++++++-- activation/wire/mocks.go | 26 +++--- 6 files changed, 137 insertions(+), 54 deletions(-) diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 0da9a40095..c27ab6b494 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -1657,7 +1657,7 @@ func Test_Marriages(t *testing.T) { atxHandler.expectAtxV2(atx2) verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) - verifier.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return atxHandler.edVerifier.Verify(d, nodeID, m, sig) }).MinTimes(1) diff --git a/activation/wire/interface.go b/activation/wire/interface.go index aa52f4e55e..ba5006e3cd 100644 --- a/activation/wire/interface.go +++ b/activation/wire/interface.go @@ -19,7 +19,7 @@ type MalfeasanceValidator interface { post *types.Post, challenge []byte, numUnits uint32, - idx uint64, + idx int, ) error // Signature validates the given signature against the given message and public key. diff --git a/activation/wire/malfeasance_double_marry_test.go b/activation/wire/malfeasance_double_marry_test.go index 760c00aad3..0876a60288 100644 --- a/activation/wire/malfeasance_double_marry_test.go +++ b/activation/wire/malfeasance_double_marry_test.go @@ -48,7 +48,7 @@ func Test_DoubleMarryProof(t *testing.T) { ctrl := gomock.NewController(t) verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() @@ -147,7 +147,7 @@ func Test_DoubleMarryProof(t *testing.T) { ctrl := gomock.NewController(t) verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() @@ -200,7 +200,7 @@ func Test_DoubleMarryProof(t *testing.T) { ctrl := gomock.NewController(t) verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() @@ -243,7 +243,7 @@ func Test_DoubleMarryProof(t *testing.T) { ctrl := gomock.NewController(t) verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() @@ -284,7 +284,7 @@ func Test_DoubleMarryProof(t *testing.T) { ctrl := gomock.NewController(t) verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() diff --git a/activation/wire/malfeasance_invalid_post.go b/activation/wire/malfeasance_invalid_post.go index b06c7ddb6f..6e5ec329d6 100644 --- a/activation/wire/malfeasance_invalid_post.go +++ b/activation/wire/malfeasance_invalid_post.go @@ -27,9 +27,12 @@ type ProofInvalidPost struct { // NodeID is the node ID that created the invalid proof NodeID types.NodeID + // Commitment is the proof for the commitment ATX of the smesher. It is generated from the initial ATX of `NodeID`. Commitment CommitmentProof - InvalidPost InvalidATXPostProof + // InvalidPost is the proof for the invalid PoST of the ATX. It contains the PoST and the merkle proofs to verify + // the PoST. + InvalidPost InvalidPostProof // TODO(mafa): add marriage ATX proof - the marriage index is needed to verify that NodeID created the proof } @@ -43,16 +46,14 @@ func NewInvalidPostProof(atx, initialAtx *ActivationTxV2) (*ProofInvalidPost, er // Valid returns true if the proof is valid. It verifies that the two proofs have the same publish epoch, smesher ID, // and a valid signature but different ATX IDs as well as that the provided merkle proofs are valid. -func (p ProofInvalidPost) Valid(edVerifier *signing.EdVerifier) (types.NodeID, error) { - if err := p.Commitment.Valid(edVerifier, p.NodeID); err != nil { +func (p ProofInvalidPost) Valid(malValidator MalfeasanceValidator) (types.NodeID, error) { + if err := p.Commitment.Valid(malValidator, p.NodeID); err != nil { return types.EmptyNodeID, fmt.Errorf("invalid commitment proof: %w", err) } // TODO(mafa): verify p.NodeID to match the ID in the marriage ATX via the marriage index - // TODO(mafa): nil is not a valid post verifier - replace edVerifier with interface to validation functions - // (signature and post) - if err := p.InvalidPost.Valid(edVerifier, nil, p.Commitment.CommitmentATX, p.NodeID); err != nil { + if err := p.InvalidPost.Valid(malValidator, p.NodeID, p.Commitment.CommitmentATX); err != nil { return types.EmptyNodeID, fmt.Errorf("invalid invalid post proof: %w", err) } @@ -82,8 +83,8 @@ type CommitmentProof struct { // Valid returns no error if the proof is valid. It verifies that the signature is valid and that the merkle proofs // are valid. -func (p CommitmentProof) Valid(edVerifier *signing.EdVerifier, nodeID types.NodeID) error { - if !edVerifier.Verify(signing.ATX, nodeID, p.ATXID.Bytes(), p.Signature) { +func (p CommitmentProof) Valid(malValidator MalfeasanceValidator, nodeID types.NodeID) error { + if !malValidator.Signature(signing.ATX, nodeID, p.ATXID.Bytes(), p.Signature) { return errors.New("invalid signature") } @@ -130,20 +131,33 @@ func (p CommitmentProof) Valid(edVerifier *signing.EdVerifier, nodeID types.Node return nil } -type InvalidATXPostProof struct { - // ATXID is the ID of the ATX being proven. It is the merkle root from the contents of the ATX. +type InvalidPostProof struct { + // ATXID is the ID of the ATX containing the invalid PoST. ATXID types.ATXID - // NiPostsRoot is the root of the NiPoST merkle tree. + // --- NiPost --- + + // NiPostsTreeRoot is the root of the merkle tree containing the NiPoSTs of the ATX. + NiPostsTreeRoot types.Hash32 + // NiPostsTreeProof contains the merkle path from the root of the ATX merkle tree (ATXID) to the Post field. + NiPostsTreeProof []types.Hash32 `scale:"max=32"` + + // NiPostsRoot is the root of the NiPoST containing the invalid PoST. NiPostsRoot types.Hash32 - // NiPostsProof contains the merkle path from the root of the ATX merkle tree (ATXID) to the Post field. - NiPostsProof []types.Hash32 `scale:"max=32"` + // NiPostsRootIndex is the index of the NiPoST in the NiPoSTs tree. + NiPostRootIndex uint16 + // NiPostsRootProof contains the merkle path from the NiPostsTreeRoot to the NiPostRoot field. + NiPostsRootProof []types.Hash32 `scale:"max=32"` + + // --- Challenge for PoST --- // Challenge for the NiPoST. Challenge types.Hash32 // ChallengeProof contains the merkle path from the NiPostsRoot to the Challenge field. ChallengeProof []types.Hash32 `scale:"max=32"` + // --- PoST --- + // PostsRoot is the root of the PoST merkle tree. PostsRoot types.Hash32 // PostsRootProof contains the merkle path from the NiPostsRoot to the PostsRoot field. @@ -152,7 +166,7 @@ type InvalidATXPostProof struct { // SubPostRoot is the root of the sub PoST merkle tree. SubPostRoot types.Hash32 // SubPostRootIndex is the index of the sub PoST in the NiPoST. - SubPostRootIndex uint64 + SubPostRootIndex uint16 // SubPostRootProof contains the merkle path from the PostsRoot to the SubPostRoot field. SubPostRootProof []types.Hash32 `scale:"max=32"` @@ -160,10 +174,12 @@ type InvalidATXPostProof struct { Post PostV1 // PostProof contains the merkle path from the SubPostRoot to the PoST field. PostProof []types.Hash32 `scale:"max=32"` + // NumUnits is the number of units in the PoST. NumUnits uint32 // NumUnitsProof contains the merkle path from the PoST to the NumUnits field. NumUnitsProof []types.Hash32 `scale:"max=32"` + // InvalidPostIndex is the index of the leaf that was identified to be invalid. InvalidPostIndex uint32 @@ -175,24 +191,25 @@ type InvalidATXPostProof struct { // Valid returns no error if the proof is valid. It verifies that the signature is valid, that the merkle proofs are // and that the provided post is invalid. -func (p InvalidATXPostProof) Valid( - edVerifier *signing.EdVerifier, - validator postVerifier, - commitmentATX types.ATXID, +func (p InvalidPostProof) Valid( + malValidator MalfeasanceValidator, nodeID types.NodeID, + commitmentATX types.ATXID, ) error { - if !edVerifier.Verify(signing.ATX, p.SmesherID, p.ATXID.Bytes(), p.Signature) { + if !malValidator.Signature(signing.ATX, p.SmesherID, p.ATXID.Bytes(), p.Signature) { return errors.New("invalid signature") } - nipostsProof := make([][]byte, len(p.NiPostsProof)) - for i, h := range p.NiPostsProof { - nipostsProof[i] = h.Bytes() + // -- NiPoST -- + + nipostsTreeProof := make([][]byte, len(p.NiPostsTreeProof)) + for i, h := range p.NiPostsTreeProof { + nipostsTreeProof[i] = h.Bytes() } ok, err := merkle.ValidatePartialTree( []uint64{uint64(NIPostsRootIndex)}, - [][]byte{p.NiPostsRoot.Bytes()}, - nipostsProof, + [][]byte{p.NiPostsTreeRoot.Bytes()}, + nipostsTreeProof, p.ATXID.Bytes(), atxTreeHash, ) @@ -203,6 +220,26 @@ func (p InvalidATXPostProof) Valid( return errors.New("invalid NiPoST root proof") } + nipostsProof := make([][]byte, len(p.NiPostsRootProof)) + for i, h := range p.NiPostsRootProof { + nipostsProof[i] = h.Bytes() + } + ok, err = merkle.ValidatePartialTree( + []uint64{uint64(p.NiPostRootIndex)}, + [][]byte{p.NiPostsRoot.Bytes()}, + nipostsProof, + p.NiPostsTreeRoot.Bytes(), + atxTreeHash, + ) + if err != nil { + return fmt.Errorf("validate NiPoST proof: %w", err) + } + if !ok { + return errors.New("invalid NiPoST proof") + } + + // -- Challenge for PoST -- + challengeProof := make([][]byte, len(p.ChallengeProof)) for i, h := range p.ChallengeProof { challengeProof[i] = h.Bytes() @@ -221,6 +258,8 @@ func (p InvalidATXPostProof) Valid( return errors.New("invalid NiPoST challenge proof") } + // --- PoST --- + postsProof := make([][]byte, len(p.PostsRootProof)) for i, h := range p.PostsRootProof { postsProof[i] = h.Bytes() @@ -229,7 +268,7 @@ func (p InvalidATXPostProof) Valid( []uint64{uint64(PostsRootIndex)}, [][]byte{p.PostsRoot.Bytes()}, postsProof, - p.NiPostsRoot.Bytes(), + p.NiPostsTreeRoot.Bytes(), atxTreeHash, ) if err != nil { @@ -244,7 +283,7 @@ func (p InvalidATXPostProof) Valid( subPostProof[i] = h.Bytes() } ok, err = merkle.ValidatePartialTree( - []uint64{p.SubPostRootIndex}, + []uint64{uint64(p.SubPostRootIndex)}, [][]byte{p.SubPostRoot.Bytes()}, subPostProof, p.PostsRoot.Bytes(), @@ -296,7 +335,7 @@ func (p InvalidATXPostProof) Valid( return errors.New("invalid PoST num units proof") } - if err := validator.PostV2Idx( + if err := malValidator.PostIndex( context.Background(), nodeID, commitmentATX, diff --git a/activation/wire/malfeasance_invalid_post_scale.go b/activation/wire/malfeasance_invalid_post_scale.go index 570e6742f3..0e14de62eb 100644 --- a/activation/wire/malfeasance_invalid_post_scale.go +++ b/activation/wire/malfeasance_invalid_post_scale.go @@ -152,7 +152,7 @@ func (t *CommitmentProof) DecodeScale(dec *scale.Decoder) (total int, err error) return total, nil } -func (t *InvalidATXPostProof) EncodeScale(enc *scale.Encoder) (total int, err error) { +func (t *InvalidPostProof) EncodeScale(enc *scale.Encoder) (total int, err error) { { n, err := scale.EncodeByteArray(enc, t.ATXID[:]) if err != nil { @@ -160,6 +160,20 @@ func (t *InvalidATXPostProof) EncodeScale(enc *scale.Encoder) (total int, err er } total += n } + { + n, err := scale.EncodeByteArray(enc, t.NiPostsTreeRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.NiPostsTreeProof, 32) + if err != nil { + return total, err + } + total += n + } { n, err := scale.EncodeByteArray(enc, t.NiPostsRoot[:]) if err != nil { @@ -168,7 +182,14 @@ func (t *InvalidATXPostProof) EncodeScale(enc *scale.Encoder) (total int, err er total += n } { - n, err := scale.EncodeStructSliceWithLimit(enc, t.NiPostsProof, 32) + n, err := scale.EncodeCompact16(enc, uint16(t.NiPostRootIndex)) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.NiPostsRootProof, 32) if err != nil { return total, err } @@ -210,7 +231,7 @@ func (t *InvalidATXPostProof) EncodeScale(enc *scale.Encoder) (total int, err er total += n } { - n, err := scale.EncodeCompact64(enc, uint64(t.SubPostRootIndex)) + n, err := scale.EncodeCompact16(enc, uint16(t.SubPostRootIndex)) if err != nil { return total, err } @@ -275,7 +296,7 @@ func (t *InvalidATXPostProof) EncodeScale(enc *scale.Encoder) (total int, err er return total, nil } -func (t *InvalidATXPostProof) DecodeScale(dec *scale.Decoder) (total int, err error) { +func (t *InvalidPostProof) DecodeScale(dec *scale.Decoder) (total int, err error) { { n, err := scale.DecodeByteArray(dec, t.ATXID[:]) if err != nil { @@ -283,6 +304,21 @@ func (t *InvalidATXPostProof) DecodeScale(dec *scale.Decoder) (total int, err er } total += n } + { + n, err := scale.DecodeByteArray(dec, t.NiPostsTreeRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.NiPostsTreeProof = field + } { n, err := scale.DecodeByteArray(dec, t.NiPostsRoot[:]) if err != nil { @@ -290,13 +326,21 @@ func (t *InvalidATXPostProof) DecodeScale(dec *scale.Decoder) (total int, err er } total += n } + { + field, n, err := scale.DecodeCompact16(dec) + if err != nil { + return total, err + } + total += n + t.NiPostRootIndex = uint16(field) + } { field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) if err != nil { return total, err } total += n - t.NiPostsProof = field + t.NiPostsRootProof = field } { n, err := scale.DecodeByteArray(dec, t.Challenge[:]) @@ -336,12 +380,12 @@ func (t *InvalidATXPostProof) DecodeScale(dec *scale.Decoder) (total int, err er total += n } { - field, n, err := scale.DecodeCompact64(dec) + field, n, err := scale.DecodeCompact16(dec) if err != nil { return total, err } total += n - t.SubPostRootIndex = uint64(field) + t.SubPostRootIndex = uint16(field) } { field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) diff --git a/activation/wire/mocks.go b/activation/wire/mocks.go index 0cff39cb16..ae0fd1be61 100644 --- a/activation/wire/mocks.go +++ b/activation/wire/mocks.go @@ -43,7 +43,7 @@ func (m *MockMalfeasanceValidator) EXPECT() *MockMalfeasanceValidatorMockRecorde } // PostIndex mocks base method. -func (m *MockMalfeasanceValidator) PostIndex(ctx context.Context, smesherID types.NodeID, commitment types.ATXID, post *types.Post, challenge []byte, numUnits uint32, idx uint64) error { +func (m *MockMalfeasanceValidator) PostIndex(ctx context.Context, smesherID types.NodeID, commitment types.ATXID, post *types.Post, challenge []byte, numUnits uint32, idx int) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PostIndex", ctx, smesherID, commitment, post, challenge, numUnits, idx) ret0, _ := ret[0].(error) @@ -69,13 +69,13 @@ func (c *MockMalfeasanceValidatorPostIndexCall) Return(arg0 error) *MockMalfeasa } // Do rewrite *gomock.Call.Do -func (c *MockMalfeasanceValidatorPostIndexCall) Do(f func(context.Context, types.NodeID, types.ATXID, *types.Post, []byte, uint32, uint64) error) *MockMalfeasanceValidatorPostIndexCall { +func (c *MockMalfeasanceValidatorPostIndexCall) Do(f func(context.Context, types.NodeID, types.ATXID, *types.Post, []byte, uint32, int) error) *MockMalfeasanceValidatorPostIndexCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockMalfeasanceValidatorPostIndexCall) DoAndReturn(f func(context.Context, types.NodeID, types.ATXID, *types.Post, []byte, uint32, uint64) error) *MockMalfeasanceValidatorPostIndexCall { +func (c *MockMalfeasanceValidatorPostIndexCall) DoAndReturn(f func(context.Context, types.NodeID, types.ATXID, *types.Post, []byte, uint32, int) error) *MockMalfeasanceValidatorPostIndexCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -83,37 +83,37 @@ func (c *MockMalfeasanceValidatorPostIndexCall) DoAndReturn(f func(context.Conte // Signature mocks base method. func (m_2 *MockMalfeasanceValidator) Signature(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { m_2.ctrl.T.Helper() - ret := m_2.ctrl.Call(m_2, "Verify", d, nodeID, m, sig) + ret := m_2.ctrl.Call(m_2, "Signature", d, nodeID, m, sig) ret0, _ := ret[0].(bool) return ret0 } -// Verify indicates an expected call of Verify. -func (mr *MockMalfeasanceValidatorMockRecorder) Verify(d, nodeID, m, sig any) *MockMalfeasanceValidatorVerifyCall { +// Signature indicates an expected call of Signature. +func (mr *MockMalfeasanceValidatorMockRecorder) Signature(d, nodeID, m, sig any) *MockMalfeasanceValidatorSignatureCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockMalfeasanceValidator)(nil).Signature), d, nodeID, m, sig) - return &MockMalfeasanceValidatorVerifyCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signature", reflect.TypeOf((*MockMalfeasanceValidator)(nil).Signature), d, nodeID, m, sig) + return &MockMalfeasanceValidatorSignatureCall{Call: call} } -// MockMalfeasanceValidatorVerifyCall wrap *gomock.Call -type MockMalfeasanceValidatorVerifyCall struct { +// MockMalfeasanceValidatorSignatureCall wrap *gomock.Call +type MockMalfeasanceValidatorSignatureCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockMalfeasanceValidatorVerifyCall) Return(arg0 bool) *MockMalfeasanceValidatorVerifyCall { +func (c *MockMalfeasanceValidatorSignatureCall) Return(arg0 bool) *MockMalfeasanceValidatorSignatureCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockMalfeasanceValidatorVerifyCall) Do(f func(signing.Domain, types.NodeID, []byte, types.EdSignature) bool) *MockMalfeasanceValidatorVerifyCall { +func (c *MockMalfeasanceValidatorSignatureCall) Do(f func(signing.Domain, types.NodeID, []byte, types.EdSignature) bool) *MockMalfeasanceValidatorSignatureCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockMalfeasanceValidatorVerifyCall) DoAndReturn(f func(signing.Domain, types.NodeID, []byte, types.EdSignature) bool) *MockMalfeasanceValidatorVerifyCall { +func (c *MockMalfeasanceValidatorSignatureCall) DoAndReturn(f func(signing.Domain, types.NodeID, []byte, types.EdSignature) bool) *MockMalfeasanceValidatorSignatureCall { c.Call = c.Call.DoAndReturn(f) return c }