From 766ab41d4903450348b695b4a7837c655977b597 Mon Sep 17 00:00:00 2001 From: Matthias <5011972+fasmat@users.noreply.github.com> Date: Mon, 13 May 2024 00:35:35 +0000 Subject: [PATCH] Fix ATX handling of double initial ATXs --- activation/handler.go | 59 +++++++++++------ activation/handler_test.go | 127 +++++++++++++++++++++++++++++++++++-- sql/atxs/atxs.go | 2 +- 3 files changed, 161 insertions(+), 27 deletions(-) diff --git a/activation/handler.go b/activation/handler.go index a650b46b1f6..011b063b68c 100644 --- a/activation/handler.go +++ b/activation/handler.go @@ -483,32 +483,45 @@ func (h *Handler) checkWrongPrevAtx( return nil, fmt.Errorf("%s referenced incorrect previous ATX", atx.SmesherID.ShortString()) } - // check if atx.PrevATXID is actually the last published ATX by the same node - prev, err := atxs.Get(tx, prevID) - if err != nil { - return nil, fmt.Errorf("get prev atx: %w", err) - } - - // if atx references a previous ATX that is not the last ATX by the same node, there must be at least one - // atx published between prevATX and the current epoch var atx2 *types.VerifiedActivationTx - pubEpoch := h.clock.CurrentLayer().GetEpoch() - for pubEpoch > prev.PublishEpoch { - id, err := atxs.PrevIDByNodeID(tx, atx.SmesherID, pubEpoch) + if atx.PrevATXID == types.EmptyATXID { + // if the ATX references an empty previous ATX, we can just take the initial ATX and create a proof + // that the node referenced the wrong previous ATX + id, err := atxs.GetFirstIDByNodeID(tx, atx.SmesherID) if err != nil { - return nil, fmt.Errorf("get prev atx id by node id: %w", err) + return nil, fmt.Errorf("get initial atx: %w", err) } atx2, err = atxs.Get(tx, id) + if err != nil { + return nil, fmt.Errorf("get initial atx: %w", err) + } + } else { + prev, err := atxs.Get(tx, atx.PrevATXID) if err != nil { return nil, fmt.Errorf("get prev atx: %w", err) } - if atx.ID() != atx2.ID() && atx.PrevATXID == atx2.PrevATXID { - // found an ATX that points to the same previous ATX - break + // if atx references a previous ATX that is not the last ATX by the same node, there must be at least one + // atx published between prevATX and the current epoch + pubEpoch := h.clock.CurrentLayer().GetEpoch() + for pubEpoch > prev.PublishEpoch { + id, err := atxs.PrevIDByNodeID(tx, atx.SmesherID, pubEpoch) + if err != nil { + return nil, fmt.Errorf("get prev atx id by node id: %w", err) + } + + atx2, err = atxs.Get(tx, id) + if err != nil { + return nil, fmt.Errorf("get prev atx: %w", err) + } + + if atx.ID() != atx2.ID() && atx.PrevATXID == atx2.PrevATXID { + // found an ATX that points to the same previous ATX + break + } + pubEpoch = atx2.PublishEpoch } - pubEpoch = atx2.PublishEpoch } if atx2 == nil || atx2.PrevATXID != atx.PrevATXID { @@ -580,7 +593,15 @@ func (h *Handler) storeAtx(ctx context.Context, atx *types.VerifiedActivationTx) return nil, fmt.Errorf("store atx: %w", err) } if nonce == nil { - return nil, errors.New("no nonce") + if proof == nil { + return nil, errors.New("no nonce") + } + // special handling for a fake initial ATX (not first, but referencing empty as previous) not containing a nonce + vrf, err := atxs.VRFNonce(h.cdb, atx.SmesherID, atx.PublishEpoch) + if err != nil { + return nil, fmt.Errorf("get vrf nonce: %w", err) + } + nonce = &vrf } atxs.AtxAdded(h.cdb, atx) if proof != nil { @@ -713,7 +734,7 @@ func (h *Handler) processATX( poetRef, atxIDs := collectAtxDeps(h.goldenATXID, &atx) h.registerHashes(peer, poetRef, atxIDs) if err := h.fetchReferences(ctx, poetRef, atxIDs); err != nil { - return nil, fmt.Errorf("fetching references for atx %x: %w", atx.ID(), err) + return nil, fmt.Errorf("fetching references for atx %v: %w", atx.ID(), err) } vAtx, proof, err := h.SyntacticallyValidateDeps(ctx, &atx) @@ -768,7 +789,7 @@ func (h *Handler) fetchReferences(ctx context.Context, poetRef types.Hash32, atx } if err := h.fetcher.GetAtxs(ctx, atxIDs, system.WithoutLimiting()); err != nil { - return fmt.Errorf("missing atxs %x: %w", atxIDs, err) + return fmt.Errorf("missing atxs %v: %w", atxIDs, err) } h.log.WithContext(ctx).With().Debug("done fetching references", log.Int("fetched", len(atxIDs))) diff --git a/activation/handler_test.go b/activation/handler_test.go index 3c6eef52e90..96573237d5c 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -1130,7 +1130,7 @@ func TestHandler_ProcessAtx_SamePrevATX(t *testing.T) { sig, 0, types.EmptyATXID, - types.EmptyATXID, + goldenATXID, nil, types.EpochID(2), 0, @@ -1167,14 +1167,124 @@ func TestHandler_ProcessAtx_SamePrevATX(t *testing.T) { require.NoError(t, err) require.Nil(t, proof) - // second non-initial ATX references prevATX as prevATX + // valid first non-initial ATX atx2 := newActivationTx( + t, + sig, + 1, + atx1.ID(), + atx1.ID(), + nil, + types.EpochID(4), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + proof, err = atxHdlr.processVerifiedATX(context.Background(), atx2) + require.NoError(t, err) + require.Nil(t, proof) + + // second non-initial ATX references prevATX as prevATX + atx3 := newActivationTx( t, sig, 2, prevATX.ID(), atx1.ID(), nil, + types.EpochID(5), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnMalfeasance(gomock.Any()) + atxHdlr.mclock.EXPECT().CurrentLayer().Return(types.EpochID(5).FirstLayer()) + proof, err = atxHdlr.processVerifiedATX(context.Background(), atx3) + require.NoError(t, err) + proof.SetReceived(time.Time{}) + nodeID, err := malfeasance.Validate( + context.Background(), + atxHdlr.log, + atxHdlr.cdb, + atxHdlr.edVerifier, + nil, + &mwire.MalfeasanceGossip{ + MalfeasanceProof: *proof, + }, + ) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), nodeID) +} + +func TestHandler_ProcessAtx_SamePrevATX_NewInitial(t *testing.T) { + // Arrange + goldenATXID := types.ATXID{2, 3, 4} + atxHdlr := newTestHandler(t, goldenATXID) + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + coinbase := types.GenerateAddress([]byte("aaaa")) + + // Act & Assert + prevATX := newActivationTx( + t, + sig, + 0, + types.EmptyATXID, + goldenATXID, + nil, + types.EpochID(2), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + withVrfNonce(7), + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + proof, err := atxHdlr.processVerifiedATX(context.Background(), prevATX) + require.NoError(t, err) + require.Nil(t, proof) + + // valid first non-initial ATX + atx1 := newActivationTx( + t, + sig, + 1, + prevATX.ID(), + prevATX.ID(), + nil, + types.EpochID(3), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + proof, err = atxHdlr.processVerifiedATX(context.Background(), atx1) + require.NoError(t, err) + require.Nil(t, proof) + + // second non-initial ATX references empty as prevATX + atx2 := newActivationTx( + t, + sig, + 2, + types.EmptyATXID, + atx1.ID(), + nil, types.EpochID(4), 0, 100, @@ -1185,7 +1295,6 @@ func TestHandler_ProcessAtx_SamePrevATX(t *testing.T) { atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) atxHdlr.mtortoise.EXPECT().OnMalfeasance(gomock.Any()) - atxHdlr.mclock.EXPECT().CurrentLayer().Return(types.EpochID(4).FirstLayer()) proof, err = atxHdlr.processVerifiedATX(context.Background(), atx2) require.NoError(t, err) proof.SetReceived(time.Time{}) @@ -1450,8 +1559,8 @@ func BenchmarkNewActivationDb(b *testing.B) { goldenATXID := types.ATXID{2, 3, 4} atxHdlr := newTestHandler(b, goldenATXID) - atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) - atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()).AnyTimes() + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() const ( numOfMiners = 300 @@ -1486,6 +1595,8 @@ func BenchmarkNewActivationDb(b *testing.B) { } npst := newNIPostWithPoet(b, poetBytes) atx = newAtx(challenge, npst.NIPost, npst.NumUnits, coinbase) + atx.VRFNonce = new(types.VRFPostIndex) + *atx.VRFNonce = types.VRFPostIndex(7) SignAndFinalizeAtx(sigs[i], atx) vAtx, err := atx.Verify(0, 1) r.NoError(err) @@ -1957,8 +2068,8 @@ func TestHandler_HandleSyncedAtx(t *testing.T) { func BenchmarkGetAtxHeaderWithConcurrentProcessAtx(b *testing.B) { goldenATXID := types.ATXID{2, 3, 4} atxHdlr := newTestHandler(b, goldenATXID) - atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) - atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()).AnyTimes() + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() var ( stop uint64 @@ -1980,6 +2091,8 @@ func BenchmarkGetAtxHeaderWithConcurrentProcessAtx(b *testing.B) { sig, err := signing.NewEdSigner() require.NoError(b, err) atx := newAtx(challenge, nil, 1, types.Address{}) + atx.VRFNonce = new(types.VRFPostIndex) + *atx.VRFNonce = types.VRFPostIndex(7) require.NoError(b, SignAndFinalizeAtx(sig, atx)) vAtx, err := atx.Verify(0, 1) if !assert.NoError(b, err) { diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 5f747afa49c..2d2f6840581 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -372,7 +372,7 @@ func Add(db sql.Executor, atx *types.VerifiedActivationTx) error { // AddGettingNonce adds an ATX for a given ATX ID and returns the nonce for the newly added ATX. func AddGettingNonce(db sql.Executor, atx *types.VerifiedActivationTx) (*types.VRFPostIndex, error) { - if atx.ActivationTx.VRFNonce == nil && atx.PrevATXID != types.EmptyATXID { + if atx.VRFNonce == nil && atx.PrevATXID != types.EmptyATXID { nonce, err := NonceByID(db, atx.PrevATXID) if err != nil && !errors.Is(err, sql.ErrNotFound) { return nil, fmt.Errorf("error getting nonce: %w", err)