Skip to content

Commit

Permalink
Fix ATX handling of double initial ATXs
Browse files Browse the repository at this point in the history
  • Loading branch information
fasmat committed May 13, 2024
1 parent 9aff88d commit 766ab41
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 27 deletions.
59 changes: 40 additions & 19 deletions activation/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,32 +483,45 @@ func (h *Handler) checkWrongPrevAtx(
return nil, fmt.Errorf("%s referenced incorrect previous ATX", atx.SmesherID.ShortString())
}

// check if atx.PrevATXID is actually the last published ATX by the same node
prev, err := atxs.Get(tx, prevID)
if err != nil {
return nil, fmt.Errorf("get prev atx: %w", err)
}

// if atx references a previous ATX that is not the last ATX by the same node, there must be at least one
// atx published between prevATX and the current epoch
var atx2 *types.VerifiedActivationTx
pubEpoch := h.clock.CurrentLayer().GetEpoch()
for pubEpoch > prev.PublishEpoch {
id, err := atxs.PrevIDByNodeID(tx, atx.SmesherID, pubEpoch)
if atx.PrevATXID == types.EmptyATXID {
// if the ATX references an empty previous ATX, we can just take the initial ATX and create a proof
// that the node referenced the wrong previous ATX
id, err := atxs.GetFirstIDByNodeID(tx, atx.SmesherID)
if err != nil {
return nil, fmt.Errorf("get prev atx id by node id: %w", err)
return nil, fmt.Errorf("get initial atx: %w", err)
}

atx2, err = atxs.Get(tx, id)
if err != nil {
return nil, fmt.Errorf("get initial atx: %w", err)
}
} else {
prev, err := atxs.Get(tx, atx.PrevATXID)
if err != nil {
return nil, fmt.Errorf("get prev atx: %w", err)
}

if atx.ID() != atx2.ID() && atx.PrevATXID == atx2.PrevATXID {
// found an ATX that points to the same previous ATX
break
// if atx references a previous ATX that is not the last ATX by the same node, there must be at least one
// atx published between prevATX and the current epoch
pubEpoch := h.clock.CurrentLayer().GetEpoch()
for pubEpoch > prev.PublishEpoch {
id, err := atxs.PrevIDByNodeID(tx, atx.SmesherID, pubEpoch)
if err != nil {
return nil, fmt.Errorf("get prev atx id by node id: %w", err)
}

atx2, err = atxs.Get(tx, id)
if err != nil {
return nil, fmt.Errorf("get prev atx: %w", err)
}

if atx.ID() != atx2.ID() && atx.PrevATXID == atx2.PrevATXID {
// found an ATX that points to the same previous ATX
break
}
pubEpoch = atx2.PublishEpoch
}
pubEpoch = atx2.PublishEpoch
}

if atx2 == nil || atx2.PrevATXID != atx.PrevATXID {
Expand Down Expand Up @@ -580,7 +593,15 @@ func (h *Handler) storeAtx(ctx context.Context, atx *types.VerifiedActivationTx)
return nil, fmt.Errorf("store atx: %w", err)
}
if nonce == nil {
return nil, errors.New("no nonce")
if proof == nil {
return nil, errors.New("no nonce")
}
// special handling for a fake initial ATX (not first, but referencing empty as previous) not containing a nonce
vrf, err := atxs.VRFNonce(h.cdb, atx.SmesherID, atx.PublishEpoch)
if err != nil {
return nil, fmt.Errorf("get vrf nonce: %w", err)
}
nonce = &vrf
}
atxs.AtxAdded(h.cdb, atx)
if proof != nil {
Expand Down Expand Up @@ -713,7 +734,7 @@ func (h *Handler) processATX(
poetRef, atxIDs := collectAtxDeps(h.goldenATXID, &atx)
h.registerHashes(peer, poetRef, atxIDs)
if err := h.fetchReferences(ctx, poetRef, atxIDs); err != nil {
return nil, fmt.Errorf("fetching references for atx %x: %w", atx.ID(), err)
return nil, fmt.Errorf("fetching references for atx %v: %w", atx.ID(), err)
}

vAtx, proof, err := h.SyntacticallyValidateDeps(ctx, &atx)
Expand Down Expand Up @@ -768,7 +789,7 @@ func (h *Handler) fetchReferences(ctx context.Context, poetRef types.Hash32, atx
}

if err := h.fetcher.GetAtxs(ctx, atxIDs, system.WithoutLimiting()); err != nil {
return fmt.Errorf("missing atxs %x: %w", atxIDs, err)
return fmt.Errorf("missing atxs %v: %w", atxIDs, err)
}

h.log.WithContext(ctx).With().Debug("done fetching references", log.Int("fetched", len(atxIDs)))
Expand Down
127 changes: 120 additions & 7 deletions activation/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ func TestHandler_ProcessAtx_SamePrevATX(t *testing.T) {
sig,
0,
types.EmptyATXID,
types.EmptyATXID,
goldenATXID,
nil,
types.EpochID(2),
0,
Expand Down Expand Up @@ -1167,14 +1167,124 @@ func TestHandler_ProcessAtx_SamePrevATX(t *testing.T) {
require.NoError(t, err)
require.Nil(t, proof)

// second non-initial ATX references prevATX as prevATX
// valid first non-initial ATX
atx2 := newActivationTx(
t,
sig,
1,
atx1.ID(),
atx1.ID(),
nil,
types.EpochID(4),
0,
100,
coinbase,
100,
&types.NIPost{PostMetadata: &types.PostMetadata{}},
)
atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any())
atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any())
proof, err = atxHdlr.processVerifiedATX(context.Background(), atx2)
require.NoError(t, err)
require.Nil(t, proof)

// second non-initial ATX references prevATX as prevATX
atx3 := newActivationTx(
t,
sig,
2,
prevATX.ID(),
atx1.ID(),
nil,
types.EpochID(5),
0,
100,
coinbase,
100,
&types.NIPost{PostMetadata: &types.PostMetadata{}},
)
atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any())
atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any())
atxHdlr.mtortoise.EXPECT().OnMalfeasance(gomock.Any())
atxHdlr.mclock.EXPECT().CurrentLayer().Return(types.EpochID(5).FirstLayer())
proof, err = atxHdlr.processVerifiedATX(context.Background(), atx3)
require.NoError(t, err)
proof.SetReceived(time.Time{})
nodeID, err := malfeasance.Validate(
context.Background(),
atxHdlr.log,
atxHdlr.cdb,
atxHdlr.edVerifier,
nil,
&mwire.MalfeasanceGossip{
MalfeasanceProof: *proof,
},
)
require.NoError(t, err)
require.Equal(t, sig.NodeID(), nodeID)
}

func TestHandler_ProcessAtx_SamePrevATX_NewInitial(t *testing.T) {
// Arrange
goldenATXID := types.ATXID{2, 3, 4}
atxHdlr := newTestHandler(t, goldenATXID)

sig, err := signing.NewEdSigner()
require.NoError(t, err)

coinbase := types.GenerateAddress([]byte("aaaa"))

// Act & Assert
prevATX := newActivationTx(
t,
sig,
0,
types.EmptyATXID,
goldenATXID,
nil,
types.EpochID(2),
0,
100,
coinbase,
100,
&types.NIPost{PostMetadata: &types.PostMetadata{}},
withVrfNonce(7),
)
atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any())
atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any())
proof, err := atxHdlr.processVerifiedATX(context.Background(), prevATX)
require.NoError(t, err)
require.Nil(t, proof)

// valid first non-initial ATX
atx1 := newActivationTx(
t,
sig,
1,
prevATX.ID(),
prevATX.ID(),
nil,
types.EpochID(3),
0,
100,
coinbase,
100,
&types.NIPost{PostMetadata: &types.PostMetadata{}},
)
atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any())
atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any())
proof, err = atxHdlr.processVerifiedATX(context.Background(), atx1)
require.NoError(t, err)
require.Nil(t, proof)

// second non-initial ATX references empty as prevATX
atx2 := newActivationTx(
t,
sig,
2,
types.EmptyATXID,
atx1.ID(),
nil,
types.EpochID(4),
0,
100,
Expand All @@ -1185,7 +1295,6 @@ func TestHandler_ProcessAtx_SamePrevATX(t *testing.T) {
atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any())
atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any())
atxHdlr.mtortoise.EXPECT().OnMalfeasance(gomock.Any())
atxHdlr.mclock.EXPECT().CurrentLayer().Return(types.EpochID(4).FirstLayer())
proof, err = atxHdlr.processVerifiedATX(context.Background(), atx2)
require.NoError(t, err)
proof.SetReceived(time.Time{})
Expand Down Expand Up @@ -1450,8 +1559,8 @@ func BenchmarkNewActivationDb(b *testing.B) {

goldenATXID := types.ATXID{2, 3, 4}
atxHdlr := newTestHandler(b, goldenATXID)
atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any())
atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any())
atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()).AnyTimes()
atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()

const (
numOfMiners = 300
Expand Down Expand Up @@ -1486,6 +1595,8 @@ func BenchmarkNewActivationDb(b *testing.B) {
}
npst := newNIPostWithPoet(b, poetBytes)
atx = newAtx(challenge, npst.NIPost, npst.NumUnits, coinbase)
atx.VRFNonce = new(types.VRFPostIndex)
*atx.VRFNonce = types.VRFPostIndex(7)
SignAndFinalizeAtx(sigs[i], atx)
vAtx, err := atx.Verify(0, 1)
r.NoError(err)
Expand Down Expand Up @@ -1957,8 +2068,8 @@ func TestHandler_HandleSyncedAtx(t *testing.T) {
func BenchmarkGetAtxHeaderWithConcurrentProcessAtx(b *testing.B) {
goldenATXID := types.ATXID{2, 3, 4}
atxHdlr := newTestHandler(b, goldenATXID)
atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any())
atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any())
atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()).AnyTimes()
atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()

var (
stop uint64
Expand All @@ -1980,6 +2091,8 @@ func BenchmarkGetAtxHeaderWithConcurrentProcessAtx(b *testing.B) {
sig, err := signing.NewEdSigner()
require.NoError(b, err)
atx := newAtx(challenge, nil, 1, types.Address{})
atx.VRFNonce = new(types.VRFPostIndex)
*atx.VRFNonce = types.VRFPostIndex(7)
require.NoError(b, SignAndFinalizeAtx(sig, atx))
vAtx, err := atx.Verify(0, 1)
if !assert.NoError(b, err) {
Expand Down
2 changes: 1 addition & 1 deletion sql/atxs/atxs.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ func Add(db sql.Executor, atx *types.VerifiedActivationTx) error {

// AddGettingNonce adds an ATX for a given ATX ID and returns the nonce for the newly added ATX.
func AddGettingNonce(db sql.Executor, atx *types.VerifiedActivationTx) (*types.VRFPostIndex, error) {
if atx.ActivationTx.VRFNonce == nil && atx.PrevATXID != types.EmptyATXID {
if atx.VRFNonce == nil && atx.PrevATXID != types.EmptyATXID {
nonce, err := NonceByID(db, atx.PrevATXID)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
return nil, fmt.Errorf("error getting nonce: %w", err)
Expand Down

0 comments on commit 766ab41

Please sign in to comment.