From cb85d2ce439a112ff206c31f4c03fdb9d2e65538 Mon Sep 17 00:00:00 2001 From: Matthias Fasching <5011972+fasmat@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:40:16 +0000 Subject: [PATCH 1/8] Improve performance of proposal builder by not DB inserting active set if not necessary (#6422) ## Motivation Closes #6418 The first proposal created by a node for any identity triggers persisting the activeset that is used by all identities of the node. This can be a slow process (especially when many ATXs are received by the node) and is usually not necessary, because the activeset is most likely already in the DB. This PR moves persisting the active set to the beginning of an epoch and avoids a DB write if not necessary (which will usually be the case). --- activation/handler_v1.go | 2 +- activation/handler_v2.go | 4 +- api/grpcserver/transaction_service_test.go | 2 +- api/grpcserver/v2alpha1/transaction_test.go | 2 +- blocks/certifier.go | 2 +- checkpoint/recovery.go | 4 +- cmd/merge-nodes/internal/merge_action.go | 2 +- common/types/poet.go | 5 ++ datastore/store.go | 6 +-- malfeasance/handler.go | 2 +- mesh/ballotwriter/ballotwriter.go | 12 ++--- mesh/ballotwriter/ballotwriter_test.go | 6 +-- mesh/mesh.go | 6 +-- miner/proposal_builder.go | 51 +++++++++++---------- miner/proposal_builder_test.go | 17 ------- sql/activesets/activesets.go | 13 ++---- sql/atxs/atxs_test.go | 4 +- sql/database.go | 9 ++-- sql/database_test.go | 2 +- sql/schema.go | 6 +-- sql/transactions/iterator_test.go | 6 +-- sql/transactions/transactions_test.go | 22 ++++----- syncer/atxsync/syncer.go | 2 +- syncer/malsync/syncer.go | 13 ++++-- syncer/malsync/syncer_test.go | 17 +++---- txs/cache.go | 8 ++-- 26 files changed, 101 insertions(+), 124 deletions(-) diff --git a/activation/handler_v1.go b/activation/handler_v1.go index 7ab8c25d54..5519573562 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -500,7 +500,7 @@ func (h *HandlerV1) storeAtx( proof *mwire.MalfeasanceProof malicious bool ) - if err := h.cdb.WithTx(ctx, func(tx sql.Transaction) error { + if err := h.cdb.WithTxImmediate(ctx, func(tx sql.Transaction) error { var err error malicious, err = identities.IsMalicious(tx, atx.SmesherID) if err != nil { diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 7136dd46a1..3f3e6606a2 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -851,7 +851,7 @@ func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx sql.Transaction, atx *a // Store an ATX in the DB. func (h *HandlerV2) storeAtx(ctx context.Context, atx *types.ActivationTx, watx *activationTx) error { - if err := h.cdb.WithTx(ctx, func(tx sql.Transaction) error { + if err := h.cdb.WithTxImmediate(ctx, func(tx sql.Transaction) error { if len(watx.marriages) != 0 { newMarriageID, err := marriage.NewID(tx) if err != nil { @@ -927,7 +927,7 @@ func (h *HandlerV2) storeAtx(ctx context.Context, atx *types.ActivationTx, watx atxs.AtxAdded(h.cdb, atx) malicious := false - err := h.cdb.WithTx(ctx, func(tx sql.Transaction) error { + err := h.cdb.WithTxImmediate(ctx, func(tx sql.Transaction) error { // malfeasance check happens after storing the ATX because storing updates the marriage set // that is needed for the malfeasance proof // TODO(mafa): don't store own ATX if it would mark the node as malicious diff --git a/api/grpcserver/transaction_service_test.go b/api/grpcserver/transaction_service_test.go index caa29e0a76..f6a6786154 100644 --- a/api/grpcserver/transaction_service_test.go +++ b/api/grpcserver/transaction_service_test.go @@ -37,7 +37,7 @@ func TestTransactionService_StreamResults(t *testing.T) { gen := fixture.NewTransactionResultGenerator(). WithAddresses(2) txs := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(ctx, func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(ctx, func(dtx sql.Transaction) error { for i := range txs { tx := gen.Next() diff --git a/api/grpcserver/v2alpha1/transaction_test.go b/api/grpcserver/v2alpha1/transaction_test.go index c743c6e842..6177c7e9b2 100644 --- a/api/grpcserver/v2alpha1/transaction_test.go +++ b/api/grpcserver/v2alpha1/transaction_test.go @@ -43,7 +43,7 @@ func TestTransactionService_List(t *testing.T) { gen := fixture.NewTransactionResultGenerator().WithAddresses(2) txsList := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(ctx, func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(ctx, func(dtx sql.Transaction) error { for i := range txsList { tx := gen.Next() diff --git a/blocks/certifier.go b/blocks/certifier.go index 0931b022ec..ab03ade84d 100644 --- a/blocks/certifier.go +++ b/blocks/certifier.go @@ -564,7 +564,7 @@ func (c *Certifier) save( if len(valid)+len(invalid) == 0 { return certificates.Add(c.db, lid, cert) } - return c.db.WithTx(ctx, func(dbtx sql.Transaction) error { + return c.db.WithTxImmediate(ctx, func(dbtx sql.Transaction) error { if err := certificates.Add(dbtx, lid, cert); err != nil { return err } diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 7fb0cb775d..f4a5088940 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -138,7 +138,7 @@ func Recover( } defer localDB.Close() logger.Info("clearing atx and malfeasance sync metadata from local database") - if err := localDB.WithTx(ctx, func(tx sql.Transaction) error { + if err := localDB.WithTxImmediate(ctx, func(tx sql.Transaction) error { if err := atxsync.Clear(tx); err != nil { return err } @@ -274,7 +274,7 @@ func RecoverFromLocalFile( zap.Int("num accounts", len(data.accounts)), zap.Int("num atxs", len(data.atxs)), ) - if err = newDB.WithTx(ctx, func(tx sql.Transaction) error { + if err = newDB.WithTxImmediate(ctx, func(tx sql.Transaction) error { for _, acct := range data.accounts { if err = accounts.Update(tx, acct); err != nil { return fmt.Errorf("restore account snapshot: %w", err) diff --git a/cmd/merge-nodes/internal/merge_action.go b/cmd/merge-nodes/internal/merge_action.go index caa78b830b..7f8b3dbfb8 100644 --- a/cmd/merge-nodes/internal/merge_action.go +++ b/cmd/merge-nodes/internal/merge_action.go @@ -159,7 +159,7 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { } dbLog.Info("merging databases", zap.String("from", from), zap.String("to", to)) - err = dstDB.WithTx(ctx, func(tx sql.Transaction) error { + err = dstDB.WithTxImmediate(ctx, func(tx sql.Transaction) error { enc := func(stmt *sql.Statement) { stmt.BindText(1, filepath.Join(from, localDbFile)) } diff --git a/common/types/poet.go b/common/types/poet.go index c8c8b95a51..27e5764243 100644 --- a/common/types/poet.go +++ b/common/types/poet.go @@ -19,6 +19,11 @@ type PoetServer struct { Pubkey Base64Enc `mapstructure:"pubkey" json:"pubkey"` } +func ByteToPoetProofRef(b []byte) (ref PoetProofRef) { + copy(ref[:], b) + return ref +} + type PoetProofRef Hash32 func (r *PoetProofRef) String() string { diff --git a/datastore/store.go b/datastore/store.go index 256301e52e..c3daf491fc 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -351,13 +351,11 @@ func (bs *BlobStore) Has(hint Hint, key []byte) (bool, error) { case TXDB: return transactions.Has(bs.DB, types.TransactionID(types.BytesToHash(key))) case POETDB: - var ref types.PoetProofRef - copy(ref[:], key) - return poets.Has(bs.DB, ref) + return poets.Has(bs.DB, types.ByteToPoetProofRef(key)) case Malfeasance: return identities.IsMalicious(bs.DB, types.BytesToNodeID(key)) case ActiveSet: - return activesets.Has(bs.DB, key) + return activesets.Has(bs.DB, types.BytesToHash(key)) } return false, fmt.Errorf("blob store not found %s", hint) } diff --git a/malfeasance/handler.go b/malfeasance/handler.go index a6e1d0f847..ded77bd507 100644 --- a/malfeasance/handler.go +++ b/malfeasance/handler.go @@ -195,7 +195,7 @@ func (h *Handler) validateAndSave(ctx context.Context, p *wire.MalfeasanceProof) return types.EmptyNodeID, errors.Join(err, pubsub.ErrValidationReject) } proofBytes := codec.MustEncode(p) - if err := h.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { + if err := h.cdb.WithTxImmediate(ctx, func(dbtx sql.Transaction) error { malicious, err := identities.IsMalicious(dbtx, nodeID) if err != nil { return fmt.Errorf("check known malicious: %w", err) diff --git a/mesh/ballotwriter/ballotwriter.go b/mesh/ballotwriter/ballotwriter.go index 35cb2ec43c..991293f13e 100644 --- a/mesh/ballotwriter/ballotwriter.go +++ b/mesh/ballotwriter/ballotwriter.go @@ -20,7 +20,7 @@ import ( var writerDelay = 100 * time.Millisecond type BallotWriter struct { - db db + db sql.StateDatabase logger *zap.Logger atxMu sync.Mutex @@ -30,7 +30,7 @@ type BallotWriter struct { ballotBatchResult *batchResult } -func New(db db, logger *zap.Logger) *BallotWriter { +func New(db sql.StateDatabase, logger *zap.Logger) *BallotWriter { // create a stopped ticker that can be started later timer := time.NewTicker(writerDelay) timer.Stop() @@ -78,7 +78,7 @@ func (w *BallotWriter) Start(ctx context.Context) { // we use a context.Background() because: on shutdown the canceling of the // context may exit the transaction halfway and leave the db in some state where it // causes crawshaw to panic on a "not all connections returned to pool". - if err := w.db.WithTx(context.Background(), func(tx sql.Transaction) error { + if err := w.db.WithTxImmediate(context.Background(), func(tx sql.Transaction) error { for _, ballot := range batch { if !ballot.IsMalicious() { layerBallotStart := time.Now() @@ -163,9 +163,3 @@ type batchResult struct { doneC chan struct{} err error } - -type db interface { - sql.Executor - - WithTx(context.Context, func(sql.Transaction) error) error -} diff --git a/mesh/ballotwriter/ballotwriter_test.go b/mesh/ballotwriter/ballotwriter_test.go index 5da0eb4942..79c85aad74 100644 --- a/mesh/ballotwriter/ballotwriter_test.go +++ b/mesh/ballotwriter/ballotwriter_test.go @@ -123,7 +123,7 @@ func BenchmarkWriteCoalescing(b *testing.B) { db := newDiskSqlite(b) b.ResetTimer() for i := 0; i < b.N; i++ { - if err := db.WithTx(context.Background(), func(tx sql.Transaction) error { + if err := db.WithTxImmediate(context.Background(), func(tx sql.Transaction) error { if err := writeFn(a[i], tx); err != nil { b.Fatal(err) } @@ -138,7 +138,7 @@ func BenchmarkWriteCoalescing(b *testing.B) { db := newDiskSqlite(b) b.ResetTimer() for j := 0; j < b.N/1000; j++ { - if err := db.WithTx(context.Background(), func(tx sql.Transaction) error { + if err := db.WithTxImmediate(context.Background(), func(tx sql.Transaction) error { var err error for i := (j * 1000); i < (j*1000)+1000; i++ { if err = writeFn(a[i], tx); err != nil { @@ -156,7 +156,7 @@ func BenchmarkWriteCoalescing(b *testing.B) { db := newDiskSqlite(b) b.ResetTimer() for j := 0; j < b.N/5000; j++ { - if err := db.WithTx(context.Background(), func(tx sql.Transaction) error { + if err := db.WithTxImmediate(context.Background(), func(tx sql.Transaction) error { var err error for i := (j * 5000); i < (j*5000)+5000; i++ { if err = writeFn(a[i], tx); err != nil { diff --git a/mesh/mesh.go b/mesh/mesh.go index d72dc954c3..1d509fc9f7 100644 --- a/mesh/mesh.go +++ b/mesh/mesh.go @@ -95,7 +95,7 @@ func NewMesh( } genesis := types.GetEffectiveGenesis() - if err = db.WithTx(context.Background(), func(dbtx sql.Transaction) error { + if err = db.WithTxImmediate(context.Background(), func(dbtx sql.Transaction) error { if err = layers.SetProcessed(dbtx, genesis); err != nil { return fmt.Errorf("mesh init: %w", err) } @@ -385,7 +385,7 @@ func (msh *Mesh) applyResults(ctx context.Context, results []result.Layer) error return fmt.Errorf("execute block %v/%v: %w", layer.Layer, target, err) } } - if err := msh.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { + if err := msh.cdb.WithTxImmediate(ctx, func(dbtx sql.Transaction) error { if err := layers.SetApplied(dbtx, layer.Layer, target); err != nil { return fmt.Errorf("set applied for %v/%v: %w", layer.Layer, target, err) } @@ -440,7 +440,7 @@ func (msh *Mesh) saveHareOutput(ctx context.Context, lid types.LayerID, bid type certs []certificates.CertValidity err error ) - if err = msh.cdb.WithTx(ctx, func(tx sql.Transaction) error { + if err = msh.cdb.WithTxImmediate(ctx, func(tx sql.Transaction) error { // check if a certificate has been generated or sync'ed. // - node generated the certificate when it collected enough certify messages // - hare outputs are processed in layer order. i.e. when hare fails for a previous layer N, diff --git a/miner/proposal_builder.go b/miner/proposal_builder.go index d0c7873f8f..f0ec54d16a 100644 --- a/miner/proposal_builder.go +++ b/miner/proposal_builder.go @@ -65,7 +65,7 @@ type ProposalBuilder struct { logger *zap.Logger cfg config - db sql.Executor + db sql.StateDatabase localdb sql.Executor atxsdata atxsData clock layerClock @@ -204,7 +204,7 @@ func WithLayerSize(size uint32) Opt { } } -// WithWorkersLimit configures paralelization factor for builder operation when working with +// WithWorkersLimit configures parallelization factor for builder operation when working with // more than one signer. func WithWorkersLimit(limit int) Opt { return func(pb *ProposalBuilder) { @@ -270,7 +270,7 @@ func WithActivesetPreparation(prep ActiveSetPreparation) Opt { // New creates a struct of block builder type. func New( clock layerClock, - db sql.Executor, + db sql.StateDatabase, localdb sql.Executor, atxsdata atxsData, publisher pubsub.Publisher, @@ -449,7 +449,7 @@ func (pb *ProposalBuilder) UpdateActiveSet(target types.EpochID, set []types.ATX pb.activeGen.updateFallback(target, set) } -func (pb *ProposalBuilder) initSharedData(current types.LayerID) error { +func (pb *ProposalBuilder) initSharedData(ctx context.Context, current types.LayerID) error { if pb.shared.epoch != current.GetEpoch() { pb.shared = sharedSession{epoch: current.GetEpoch()} } @@ -476,7 +476,27 @@ func (pb *ProposalBuilder) initSharedData(current types.LayerID) error { pb.shared.active.id = id pb.shared.active.set = set pb.shared.active.weight = weight - return nil + + // Ideally we only persist the active set when we are actually eligible with at least one identity in at least one + // layer, but since at the moment we use a bootstrapped activeset, `activesets.Has` will always return + // true anyways. + // + // Additionally all activesets that are older than 2 epochs are deleted at the beginning of an epoch anyway, but + // maybe we should revisit this when activesets are no longer bootstrapped. + return pb.db.WithTx(ctx, func(tx sql.Transaction) error { + yes, err := activesets.Has(tx, pb.shared.active.id) + if err != nil { + return err + } + if yes { + return nil + } + + return activesets.Add(tx, pb.shared.active.id, &types.EpochActiveSet{ + Epoch: pb.shared.epoch, + Set: pb.shared.active.set, + }) + }) } func (pb *ProposalBuilder) initSignerData(ss *signerSession, lid types.LayerID) error { @@ -548,7 +568,7 @@ func (pb *ProposalBuilder) initSignerData(ss *signerSession, lid types.LayerID) func (pb *ProposalBuilder) build(ctx context.Context, lid types.LayerID) error { buildStartTime := time.Now() - if err := pb.initSharedData(lid); err != nil { + if err := pb.initSharedData(ctx, lid); err != nil { return err } @@ -578,17 +598,6 @@ func (pb *ProposalBuilder) build(ctx context.Context, lid types.LayerID) error { return meshHash }) - persistActiveSetOnce := sync.OnceValue(func() error { - err := activesets.Add(pb.db, pb.shared.active.id, &types.EpochActiveSet{ - Epoch: pb.shared.epoch, - Set: pb.shared.active.set, - }) - if err != nil && !errors.Is(err, sql.ErrObjectExists) { - return err - } - return nil - }) - // Two stage pipeline, with the stages running in parallel. // 1. Initializes signers. Runs limited number of goroutines because the initialization is CPU and DB bound. // 2. Collects eligible signers' sessions from the stage 1 and creates and publishes proposals. @@ -662,14 +671,6 @@ func (pb *ProposalBuilder) build(ctx context.Context, lid types.LayerID) error { ss.latency.hash = time.Since(start) eg2.Go(func() error { - // needs to be saved before publishing, as we will query it in handler - if ss.session.ref == types.EmptyBallotID { - start := time.Now() - if err := persistActiveSetOnce(); err != nil { - return err - } - ss.latency.activeSet = time.Since(start) - } proofs := ss.session.eligibilities.proofs[lid] start = time.Now() diff --git a/miner/proposal_builder_test.go b/miner/proposal_builder_test.go index c7fa9f74bf..2e9c4d368b 100644 --- a/miner/proposal_builder_test.go +++ b/miner/proposal_builder_test.go @@ -3,7 +3,6 @@ package miner import ( "bytes" "context" - "encoding/hex" "errors" "fmt" "math/rand" @@ -1272,19 +1271,3 @@ func BenchmarkDoubleCache(b *testing.B) { require.Equal(b, types.EmptyATXID, found) } - -func BenchmarkDB(b *testing.B) { - db, err := statesql.Open("file:state.sql") - require.NoError(b, err) - defer db.Close() - - bytes, err := hex.DecodeString("00003ce28800fadd692c522f7b1db219f675b49108aec7f818e2c4fd935573f6") - require.NoError(b, err) - nodeID := types.BytesToNodeID(bytes) - var found types.ATXID - b.ResetTimer() - for i := 0; i < b.N; i++ { - found, _ = atxs.GetByEpochAndNodeID(db, 30, nodeID) - } - require.NotEqual(b, types.EmptyATXID, found) -} diff --git a/sql/activesets/activesets.go b/sql/activesets/activesets.go index 58d30339e7..7e95202ad7 100644 --- a/sql/activesets/activesets.go +++ b/sql/activesets/activesets.go @@ -18,7 +18,7 @@ func Add(db sql.Executor, id types.Hash32, set *types.EpochActiveSet) error { (id, epoch, active_set) values (?1, ?2, ?3);`, func(stmt *sql.Statement) { - stmt.BindBytes(1, id[:]) + stmt.BindBytes(1, id.Bytes()) stmt.BindInt64(2, int64(set.Epoch)) stmt.BindBytes(3, codec.MustEncode(set)) }, nil) @@ -100,9 +100,7 @@ func getBlob(ctx context.Context, db sql.Executor, id []byte) ([]byte, error) { func DeleteBeforeEpoch(db sql.Executor, epoch types.EpochID) error { _, err := db.Exec("delete from activesets where epoch < ?1;", - func(stmt *sql.Statement) { - stmt.BindInt64(1, int64(epoch)) - }, + func(stmt *sql.Statement) { stmt.BindInt64(1, int64(epoch)) }, nil, ) if err != nil { @@ -111,10 +109,9 @@ func DeleteBeforeEpoch(db sql.Executor, epoch types.EpochID) error { return nil } -func Has(db sql.Executor, id []byte) (bool, error) { - rows, err := db.Exec( - "select 1 from activesets where id = ?1;", - func(stmt *sql.Statement) { stmt.BindBytes(1, id) }, +func Has(db sql.Executor, id types.Hash32) (bool, error) { + rows, err := db.Exec("select 1 from activesets where id = ?1;", + func(stmt *sql.Statement) { stmt.BindBytes(1, id.Bytes()) }, nil, ) if err != nil { diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 817e3f40d1..124b914d6f 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -433,7 +433,7 @@ func TestGetIDsByEpochCached(t *testing.T) { require.Equal(t, 11, db.QueryCount()) } - require.NoError(t, db.WithTx(context.Background(), func(tx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(tx sql.Transaction) error { atxs.Add(tx, atx5, types.AtxBlob{}) return nil })) @@ -445,7 +445,7 @@ func TestGetIDsByEpochCached(t *testing.T) { require.ElementsMatch(t, []types.ATXID{atx4.ID(), atx5.ID()}, ids3) require.Equal(t, 13, db.QueryCount()) // not incremented after Add - require.Error(t, db.WithTx(context.Background(), func(tx sql.Transaction) error { + require.Error(t, db.WithTxImmediate(context.Background(), func(tx sql.Transaction) error { atxs.Add(tx, atx6, types.AtxBlob{}) return errors.New("fail") // rollback })) diff --git a/sql/database.go b/sql/database.go index 3006e1cda5..4f4224b710 100644 --- a/sql/database.go +++ b/sql/database.go @@ -617,6 +617,8 @@ func (db *sqliteDatabase) getTx(ctx context.Context, initstmt string) (*sqliteTx } tx := &sqliteTx{queryCache: db.queryCache, db: db, conn: conn, freeConn: cancel} if err := tx.begin(initstmt); err != nil { + cancel() + db.pool.Put(conn) return nil, err } return tx, nil @@ -686,7 +688,7 @@ func (db *sqliteDatabase) Tx(ctx context.Context) (Transaction, error) { // WithTx will pass initialized deferred transaction to exec callback. // Will commit only if error is nil. func (db *sqliteDatabase) WithTx(ctx context.Context, exec func(Transaction) error) error { - return db.withTx(ctx, beginImmediate, exec) + return db.withTx(ctx, beginDefault, exec) } // TxImmediate creates immediate transaction. @@ -700,10 +702,7 @@ func (db *sqliteDatabase) TxImmediate(ctx context.Context) (Transaction, error) // WithTxImmediate will pass initialized immediate transaction to exec callback. // Will commit only if error is nil. -func (db *sqliteDatabase) WithTxImmediate( - ctx context.Context, - exec func(Transaction) error, -) error { +func (db *sqliteDatabase) WithTxImmediate(ctx context.Context, exec func(Transaction) error) error { return db.withTx(ctx, beginImmediate, exec) } diff --git a/sql/database_test.go b/sql/database_test.go index 46bc62aff4..d197d5e497 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -520,7 +520,7 @@ func TestDBClosed(t *testing.T) { require.NoError(t, db.Close()) _, err := db.Exec("select 1", nil, nil) require.ErrorIs(t, err, ErrClosed) - err = db.WithTx(context.Background(), func(tx Transaction) error { return nil }) + err = db.WithTxImmediate(context.Background(), func(tx Transaction) error { return nil }) require.ErrorIs(t, err, ErrClosed) } diff --git a/sql/schema.go b/sql/schema.go index 20a949b324..f393d7534f 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -85,7 +85,7 @@ func (s *Schema) SkipMigrations(i ...int) { // Apply applies the schema to the database. func (s *Schema) Apply(db Database) error { - return db.WithTx(context.Background(), func(tx Transaction) error { + return db.WithTxImmediate(context.Background(), func(tx Transaction) error { scanner := bufio.NewScanner(strings.NewReader(s.Script)) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if i := bytes.Index(data, []byte(";")); i >= 0 { @@ -147,7 +147,7 @@ func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState in if m.Order() <= before { continue } - if err := db.WithTx(context.Background(), func(tx Transaction) error { + if err := db.WithTxImmediate(context.Background(), func(tx Transaction) error { if _, ok := s.skipMigration[m.Order()]; !ok { if err := m.Apply(tx, logger); err != nil { for j := i; j >= 0 && s.Migrations[j].Order() > before; j-- { @@ -196,7 +196,7 @@ func (s *Schema) MigrateTempDB(logger *zap.Logger, db Database, before int) erro } if _, ok := s.skipMigration[m.Order()]; !ok { - if err := db.WithTx(context.Background(), func(tx Transaction) error { + if err := db.WithTxImmediate(context.Background(), func(tx Transaction) error { return m.Apply(tx, logger) }); err != nil { return fmt.Errorf("apply %s: %w", m.Name(), err) diff --git a/sql/transactions/iterator_test.go b/sql/transactions/iterator_test.go index 432e2f3c69..8988d9917d 100644 --- a/sql/transactions/iterator_test.go +++ b/sql/transactions/iterator_test.go @@ -64,7 +64,7 @@ func TestIterateResults(t *testing.T) { gen := fixture.NewTransactionResultGenerator() txs := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(context.TODO(), func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { for i := range txs { tx := gen.Next() @@ -148,7 +148,7 @@ func TestIterateSnapshot(t *testing.T) { require.NoError(t, err) gen := fixture.NewTransactionResultGenerator() expect := 10 - require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { for i := 0; i < expect; i++ { tx := gen.Next() @@ -176,7 +176,7 @@ func TestIterateSnapshot(t *testing.T) { }() <-initialized - require.NoError(t, db.WithTx(context.TODO(), func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { for i := 0; i < 10; i++ { tx := gen.Next() diff --git a/sql/transactions/transactions_test.go b/sql/transactions/transactions_test.go index ab4781874f..0bdac033b3 100644 --- a/sql/transactions/transactions_test.go +++ b/sql/transactions/transactions_test.go @@ -232,17 +232,17 @@ func TestApply_AlreadyApplied(t *testing.T) { require.NoError(t, transactions.Add(db, tx, time.Now())) bid := types.RandomBlockID() - require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) // same block applied again - require.Error(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { + require.Error(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) // different block applied again - require.Error(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { + require.Error(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult( dtx, tx.ID, @@ -254,7 +254,7 @@ func TestApply_AlreadyApplied(t *testing.T) { func TestUndoLayers_Empty(t *testing.T) { db := statesql.InMemoryTest(t) - require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, types.LayerID(199)) })) } @@ -273,7 +273,7 @@ func TestApplyAndUndoLayers(t *testing.T) { require.NoError(t, transactions.Add(db, tx, time.Now())) bid := types.RandomBlockID() - require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) applied = append(applied, tx.ID) @@ -285,7 +285,7 @@ func TestApplyAndUndoLayers(t *testing.T) { require.Equal(t, types.APPLIED, mtx.State) } // revert to firstLayer - require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, firstLayer.Add(1)) })) @@ -349,7 +349,7 @@ func TestGetByAddress(t *testing.T) { createTX(t, signer1, signer2Address, 1, 191, 1), } received := time.Now() - require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dbtx sql.Transaction) error { for _, tx := range txs { require.NoError(t, transactions.Add(dbtx, tx, received)) require.NoError(t, transactions.AddResult(dbtx, tx.ID, &types.TransactionResult{Layer: lid})) @@ -418,7 +418,7 @@ func TestAppliedLayer(t *testing.T) { for _, tx := range txs { require.NoError(t, transactions.Add(db, tx, time.Now())) } - require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, txs[0].ID, &types.TransactionResult{Layer: lid, Block: types.BlockID{1, 1}}) })) @@ -429,7 +429,7 @@ func TestAppliedLayer(t *testing.T) { _, err = transactions.GetAppliedLayer(db, txs[1].ID) require.ErrorIs(t, err, sql.ErrNotFound) - require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, lid) })) _, err = transactions.GetAppliedLayer(db, txs[0].ID) @@ -466,7 +466,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { {Address: principals[0], Nonce: txs[0].Nonce}, {Address: principals[1], Nonce: txs[2].Nonce}, }, rst) - require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dbtx sql.Transaction) error { return transactions.AddResult(dbtx, txs[0].ID, &types.TransactionResult{Message: "hey"}) })) rst, err = transactions.AddressesWithPendingTransactions(db) @@ -475,7 +475,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { {Address: principals[0], Nonce: txs[1].Nonce}, {Address: principals[1], Nonce: txs[2].Nonce}, }, rst) - require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { + require.NoError(t, db.WithTxImmediate(context.Background(), func(dbtx sql.Transaction) error { return transactions.AddResult(dbtx, txs[2].ID, &types.TransactionResult{Message: "hey"}) })) rst, err = transactions.AddressesWithPendingTransactions(db) diff --git a/syncer/atxsync/syncer.go b/syncer/atxsync/syncer.go index 99a47d741f..7bcea2a253 100644 --- a/syncer/atxsync/syncer.go +++ b/syncer/atxsync/syncer.go @@ -332,7 +332,7 @@ func (s *Syncer) downloadAtxs( } } - if err := s.localdb.WithTx(context.Background(), func(tx sql.Transaction) error { + if err := s.localdb.WithTxImmediate(context.Background(), func(tx sql.Transaction) error { err := atxsync.SaveRequest(tx, publish, lastSuccess, int64(len(state)), int64(len(downloaded))) if err != nil { return fmt.Errorf("failed to save request time: %w", err) diff --git a/syncer/malsync/syncer.go b/syncer/malsync/syncer.go index f35e2175cb..687d4380ed 100644 --- a/syncer/malsync/syncer.go +++ b/syncer/malsync/syncer.go @@ -341,7 +341,7 @@ func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan } func (s *Syncer) updateState(ctx context.Context) error { - if err := s.localdb.WithTx(ctx, func(tx sql.Transaction) error { + if err := s.localdb.WithTxImmediate(ctx, func(tx sql.Transaction) error { return malsync.UpdateSyncState(tx, s.clock.Now()) }); err != nil { if ctx.Err() != nil { @@ -382,7 +382,9 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up return ctx.Err() case update = <-updates: s.logger.Debug("malfeasance sync update", - log.ZContext(ctx), zap.Int("count", len(update.nodeIDs))) + log.ZContext(ctx), + zap.Int("count", len(update.nodeIDs)), + ) sst.update(update) gotUpdate = true } @@ -392,7 +394,9 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up return ctx.Err() case update = <-updates: s.logger.Debug("malfeasance sync update", - log.ZContext(ctx), zap.Int("count", len(update.nodeIDs))) + log.ZContext(ctx), + zap.Int("count", len(update.nodeIDs)), + ) sst.update(update) gotUpdate = true default: @@ -417,7 +421,8 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up if len(batch) != 0 { s.logger.Debug("retrieving malfeasant identities", log.ZContext(ctx), - zap.Int("count", len(batch))) + zap.Int("count", len(batch)), + ) err := s.fetcher.GetMalfeasanceProofs(ctx, batch) if err != nil { if errors.Is(err, context.Canceled) { diff --git a/syncer/malsync/syncer_test.go b/syncer/malsync/syncer_test.go index b1397cd3e0..f3b9f53f03 100644 --- a/syncer/malsync/syncer_test.go +++ b/syncer/malsync/syncer_test.go @@ -225,8 +225,7 @@ func TestSyncer(t *testing.T) { tester.expectGetProofs(nil) epochStart := tester.clock.Now().Truncate(time.Second) epochEnd := epochStart.Add(10 * time.Minute) - require.NoError(t, - tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) require.ElementsMatch(t, []types.NodeID{ nid("1"), nid("2"), nid("3"), nid("4"), }, maps.Keys(tester.received)) @@ -238,8 +237,7 @@ func TestSyncer(t *testing.T) { }, tester.attempts) tester.clock.Advance(1 * time.Minute) // second call does nothing after recent sync - require.NoError(t, - tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) require.Zero(t, tester.peerErrCount.n) }) t.Run("EnsureInSync with no malfeasant identities", func(t *testing.T) { @@ -295,7 +293,7 @@ func TestSyncer(t *testing.T) { cancel() eg.Wait() }) - t.Run("gettings ids from MinSyncPeers peers is enough", func(t *testing.T) { + t.Run("getting ids from MinSyncPeers peers is enough", func(t *testing.T) { cfg := DefaultConfig() cfg.MinSyncPeers = 2 tester := newTester(t, cfg) @@ -324,8 +322,7 @@ func TestSyncer(t *testing.T) { }, tester.attempts) tester.clock.Advance(1 * time.Minute) // second call does nothing after recent sync - require.NoError(t, - tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) require.Equal(t, 1, tester.peerErrCount.n) }) t.Run("skip hashes after max retries", func(t *testing.T) { @@ -352,8 +349,7 @@ func TestSyncer(t *testing.T) { }, tester.attempts) tester.clock.Advance(1 * time.Minute) // second call does nothing after recent sync - require.NoError(t, - tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) }) t.Run("skip hashes after validation reject", func(t *testing.T) { tester := newTester(t, DefaultConfig()) @@ -379,7 +375,6 @@ func TestSyncer(t *testing.T) { }, tester.attempts) tester.clock.Advance(1 * time.Minute) // second call does nothing after recent sync - require.NoError(t, - tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) }) } diff --git a/txs/cache.go b/txs/cache.go index e23b92a9ff..ee49cedfca 100644 --- a/txs/cache.go +++ b/txs/cache.go @@ -688,7 +688,7 @@ func (c *Cache) ApplyLayer( // commit results before reporting them // TODO(dshulyak) save results in vm - if err := db.WithTx(context.Background(), func(dbtx sql.Transaction) error { + if err := db.WithTxImmediate(context.Background(), func(dbtx sql.Transaction) error { for _, rst := range results { err := transactions.AddResult(dbtx, rst.ID, &rst.TransactionResult) if err != nil { @@ -835,7 +835,7 @@ func checkApplyOrder(logger *zap.Logger, db sql.StateDatabase, toApply types.Lay } func addToProposal(db sql.StateDatabase, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID) error { - return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { + return db.WithTxImmediate(context.Background(), func(dbtx sql.Transaction) error { for _, tid := range tids { if err := transactions.AddToProposal(dbtx, tid, lid, pid); err != nil { return fmt.Errorf("add2prop %w", err) @@ -846,7 +846,7 @@ func addToProposal(db sql.StateDatabase, lid types.LayerID, pid types.ProposalID } func addToBlock(db sql.StateDatabase, lid types.LayerID, bid types.BlockID, tids []types.TransactionID) error { - return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { + return db.WithTxImmediate(context.Background(), func(dbtx sql.Transaction) error { for _, tid := range tids { if err := transactions.AddToBlock(dbtx, tid, lid, bid); err != nil { return fmt.Errorf("add2block %w", err) @@ -857,7 +857,7 @@ func addToBlock(db sql.StateDatabase, lid types.LayerID, bid types.BlockID, tids } func undoLayers(db sql.StateDatabase, from types.LayerID) error { - return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { + return db.WithTxImmediate(context.Background(), func(dbtx sql.Transaction) error { err := transactions.UndoLayers(dbtx, from) if err != nil { return fmt.Errorf("undo %w", err) From 2439e4c5cd03315009b39cdfb6afa5f679bf8509 Mon Sep 17 00:00:00 2001 From: Matthias Fasching <5011972+fasmat@users.noreply.github.com> Date: Wed, 30 Oct 2024 11:57:30 +0000 Subject: [PATCH 2/8] Update CHANGELOG and remove unused metric (#6423) ## Motivation Follow up to #6422: Update changlog and remove unused metric. --- CHANGELOG.md | 3 +++ blocks/generator_test.go | 3 ++- miner/metrics.go | 12 +++++------- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b0d12892e6..25f805d286 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ See [RELEASE](./RELEASE.md) for workflow instructions. * [#6417](https://github.com/spacemeshos/go-spacemesh/pull/6417) Fix initial post being deleted when the node is restarted or times out before the first ATX is published. +* [#6422](https://github.com/spacemeshos/go-spacemesh/pull/6422) Further improved performance of the proposal building + process to avoid late proposals. + ## v1.7.6 ### Upgrade information diff --git a/blocks/generator_test.go b/blocks/generator_test.go index 4edf80a138..fd02248787 100644 --- a/blocks/generator_test.go +++ b/blocks/generator_test.go @@ -88,7 +88,8 @@ func createTestGenerator(tb testing.TB) *testGenerator { tg.mockPatrol, WithGeneratorLogger(lg), WithHareOutputChan(ch), - WithConfig(testConfig())) + WithConfig(testConfig()), + ) return tg } diff --git a/miner/metrics.go b/miner/metrics.go index 2625866fb3..113334a3ba 100644 --- a/miner/metrics.go +++ b/miner/metrics.go @@ -21,12 +21,11 @@ type latencyTracker struct { start time.Time end time.Time - data time.Duration - tortoise time.Duration - hash time.Duration - activeSet time.Duration - txs time.Duration - publish time.Duration + data time.Duration + tortoise time.Duration + hash time.Duration + txs time.Duration + publish time.Duration } func (lt *latencyTracker) total() time.Duration { @@ -35,7 +34,6 @@ func (lt *latencyTracker) total() time.Duration { func (lt *latencyTracker) MarshalLogObject(encoder zapcore.ObjectEncoder) error { encoder.AddDuration("data", lt.data) - encoder.AddDuration("active set", lt.activeSet) encoder.AddDuration("tortoise", lt.tortoise) encoder.AddDuration("hash", lt.hash) encoder.AddDuration("txs", lt.txs) From f70592abee91a6293922eca66ab176e1b9d49ba2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Tue, 24 Sep 2024 09:45:10 +0200 Subject: [PATCH 3/8] Abstract away DB in AtxBuilder --- activation/activation.go | 204 +++++--------- activation/activation_multi_test.go | 44 +-- activation/activation_test.go | 278 ++++++------------- activation/atx_service_db.go | 141 ++++++++++ activation/atx_service_db_test.go | 104 +++++++ activation/builder_v2_test.go | 1 - activation/e2e/activation_test.go | 14 +- activation/e2e/builds_atx_v2_test.go | 13 +- activation/e2e/checkpoint_test.go | 26 +- activation/interface.go | 9 + activation/mocks.go | 140 ++++++++++ common/errors.go | 5 + node/node.go | 18 +- sql/database.go | 3 +- sql/localsql/atxs/atxs.go | 37 +++ sql/localsql/schema/migrations/0010_atxs.sql | 12 + sql/localsql/schema/schema.sql | 12 +- 17 files changed, 670 insertions(+), 391 deletions(-) create mode 100644 activation/atx_service_db.go create mode 100644 activation/atx_service_db_test.go create mode 100644 common/errors.go create mode 100644 sql/localsql/atxs/atxs.go create mode 100644 sql/localsql/schema/migrations/0010_atxs.sql diff --git a/activation/activation.go b/activation/activation.go index 7923f4ef46..a09400a25c 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -18,8 +18,8 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/metrics" "github.com/spacemeshos/go-spacemesh/activation/wire" - "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/codec" + "github.com/spacemeshos/go-spacemesh/common" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" "github.com/spacemeshos/go-spacemesh/log" @@ -27,7 +27,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/localsql/atxs" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -81,11 +81,11 @@ type Config struct { // it is responsible for initializing post, receiving poet proof and orchestrating nipost after which it will // calculate total weight and providing relevant view as proof. type Builder struct { - accountLock sync.RWMutex - coinbaseAccount types.Address - conf Config - db sql.Executor - atxsdata *atxsdata.Data + accountLock sync.RWMutex + coinbaseAccount types.Address + conf Config + atxSvc AtxService + localDB sql.LocalDatabase publisher pubsub.Publisher nipostBuilder nipostBuilder @@ -97,8 +97,6 @@ type Builder struct { poets []PoetService poetCfg PoetConfig poetRetryInterval time.Duration - // delay before PoST in ATX is considered valid (counting from the time it was received) - postValidityDelay time.Duration // ATX versions versions []atxVersion @@ -121,16 +119,12 @@ type positioningAtxFinder struct { id types.ATXID forPublish types.EpochID } + golden types.ATXID + logger *zap.Logger } type BuilderOption func(*Builder) -func WithPostValidityDelay(delay time.Duration) BuilderOption { - return func(b *Builder) { - b.postValidityDelay = delay - } -} - // WithPoetRetryInterval modifies time that builder will have to wait before retrying ATX build process // if it failed due to issues with PoET server. func WithPoetRetryInterval(interval time.Duration) BuilderOption { @@ -160,12 +154,6 @@ func WithPoets(poets ...PoetService) BuilderOption { } } -func WithValidator(v nipostValidator) BuilderOption { - return func(b *Builder) { - b.validator = v - } -} - func WithPostStates(ps PostStates) BuilderOption { return func(b *Builder) { b.postStates = ps @@ -181,10 +169,10 @@ func BuilderAtxVersions(v AtxVersions) BuilderOption { // NewBuilder returns an atx builder that will start a routine that will attempt to create an atx upon each new layer. func NewBuilder( conf Config, - db sql.Executor, - atxsdata *atxsdata.Data, localDB sql.LocalDatabase, + atxService AtxService, publisher pubsub.Publisher, + nipostValidator nipostValidator, nipostBuilder nipostBuilder, layerClock layerClock, syncer syncer, @@ -195,18 +183,21 @@ func NewBuilder( parentCtx: context.Background(), signers: make(map[types.NodeID]*signing.EdSigner), conf: conf, - db: db, - atxsdata: atxsdata, localDB: localDB, publisher: publisher, + atxSvc: atxService, + validator: nipostValidator, nipostBuilder: nipostBuilder, layerClock: layerClock, syncer: syncer, logger: log, poetRetryInterval: defaultPoetRetryInterval, - postValidityDelay: 12 * time.Hour, postStates: NewPostStates(log), versions: []atxVersion{{0, types.AtxV1}}, + posAtxFinder: positioningAtxFinder{ + golden: conf.GoldenATXID, + logger: log, + }, } for _, opt := range opts { opt(b) @@ -349,7 +340,7 @@ func (b *Builder) SmesherIDs() []types.NodeID { func (b *Builder) BuildInitialPost(ctx context.Context, nodeID types.NodeID) error { // Generate the initial POST if we don't have an ATX... - if _, err := atxs.GetLastIDByNodeID(b.db, nodeID); err == nil { + if _, err := b.atxSvc.LastATX(ctx, nodeID); err == nil { return nil } // ...and if we haven't stored an initial post yet. @@ -358,7 +349,7 @@ func (b *Builder) BuildInitialPost(ctx context.Context, nodeID types.NodeID) err case err == nil: b.logger.Info("load initial post from db") return nil - case errors.Is(err, sql.ErrNotFound): + case errors.Is(err, common.ErrNotFound): b.logger.Info("creating initial post") default: return fmt.Errorf("get initial post: %w", err) @@ -530,11 +521,11 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID) // Start building new challenge: // 1. get previous ATX - prevAtx, err := b.GetPrevAtx(nodeID) + prevAtx, err := b.atxSvc.LastATX(ctx, nodeID) switch { case err == nil: currentEpochId = max(currentEpochId, prevAtx.PublishEpoch) - case errors.Is(err, sql.ErrNotFound): + case errors.Is(err, common.ErrNotFound): // no previous ATX case err != nil: return nil, fmt.Errorf("get last ATX: %w", err) @@ -580,11 +571,11 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID) // 4. build new challenge logger.Info("building new NiPOST challenge", zap.Uint32("current_epoch", currentEpochId.Uint32())) - prevAtx, err = b.GetPrevAtx(nodeID) + prevAtx, err = b.atxSvc.LastATX(ctx, nodeID) var challenge *types.NIPostChallenge switch { - case errors.Is(err, sql.ErrNotFound): + case errors.Is(err, common.ErrNotFound): logger.Info("no previous ATX found, creating an initial nipost challenge") challenge, err = b.buildInitialNIPostChallenge(ctx, logger, nodeID, publishEpochId) if err != nil { @@ -620,7 +611,7 @@ func (b *Builder) getExistingChallenge( challenge, err := nipost.Challenge(b.localDB, nodeID) switch { - case errors.Is(err, sql.ErrNotFound): + case errors.Is(err, common.ErrNotFound): return nil, nil case err != nil: @@ -700,14 +691,6 @@ func (b *Builder) buildInitialNIPostChallenge( }, nil } -func (b *Builder) GetPrevAtx(nodeID types.NodeID) (*types.ActivationTx, error) { - id, err := atxs.GetLastIDByNodeID(b.db, nodeID) - if err != nil { - return nil, fmt.Errorf("getting last ATXID: %w", err) - } - return atxs.Get(b.db, id) -} - // SetCoinbase sets the address rewardAddress to be the coinbase account written into the activation transaction // the rewards for blocks made by this miner will go to this address. func (b *Builder) SetCoinbase(rewardAddress types.Address) { @@ -755,6 +738,11 @@ func (b *Builder) PublishActivationTx(ctx context.Context, sig *signing.EdSigner case <-b.layerClock.AwaitLayer(challenge.PublishEpoch.FirstLayer()): } + err = atxs.AddBlob(b.localDB, challenge.PublishEpoch, atx.ID(), sig.NodeID(), codec.MustEncode(atx)) + if err != nil { + b.logger.Warn("failed to persist built ATX into the local DB - regossiping won't work", zap.Error(err)) + } + for { b.logger.Info( "broadcasting ATX", @@ -847,7 +835,7 @@ func (b *Builder) createAtx( case challenge.PrevATXID == types.EmptyATXID: atx.VRFNonce = (*uint64)(&nipostState.VRFNonce) default: - oldNonce, err := atxs.NonceByID(b.db, challenge.PrevATXID) + prevAtx, err := b.atxSvc.Atx(ctx, challenge.PrevATXID) if err != nil { b.logger.Warn("failed to get VRF nonce for ATX", zap.Error(err), @@ -855,12 +843,12 @@ func (b *Builder) createAtx( ) break } - if nipostState.VRFNonce != oldNonce { + if nipostState.VRFNonce != prevAtx.VRFNonce { b.logger.Info( "attaching a new VRF nonce in ATX", log.ZShortStringer("smesherID", sig.NodeID()), zap.Uint64("new nonce", uint64(nipostState.VRFNonce)), - zap.Uint64("old nonce", uint64(oldNonce)), + zap.Uint64("old nonce", uint64(prevAtx.VRFNonce)), ) atx.VRFNonce = (*uint64)(&nipostState.VRFNonce) } @@ -919,54 +907,44 @@ func (b *Builder) broadcast(ctx context.Context, atx scale.Encodable) (int, erro return len(buf), nil } -// searchPositioningAtx returns atx id with the highest tick height. -// Publish epoch is used for caching the positioning atx. -func (b *Builder) searchPositioningAtx( +// find returns atx id with the highest tick height. +// The publish epoch (of the built ATX) is used for: +// - caching the positioning atx, +// - filtering candidates for positioning atx (it must be published in an earlier epoch than built ATX). +// +// It always returns an ATX, falling back to the golden one as the last resort. +func (f *positioningAtxFinder) find( ctx context.Context, - nodeID types.NodeID, + atxs AtxService, publish types.EpochID, -) (types.ATXID, error) { - logger := b.logger.With(log.ZShortStringer("smesherID", nodeID), zap.Uint32("publish epoch", publish.Uint32())) +) types.ATXID { + logger := f.logger.With(zap.Uint32("publish epoch", publish.Uint32())) - b.posAtxFinder.finding.Lock() - defer b.posAtxFinder.finding.Unlock() + f.finding.Lock() + defer f.finding.Unlock() - if found := b.posAtxFinder.found; found != nil && found.forPublish == publish { + if found := f.found; found != nil && found.forPublish == publish { logger.Debug("using cached positioning atx", log.ZShortStringer("atx_id", found.id)) - return found.id, nil + return found.id } - latestPublished, err := atxs.LatestEpoch(b.db) - if err != nil { - return types.EmptyATXID, fmt.Errorf("get latest epoch: %w", err) - } - - logger.Info("searching for positioning atx", zap.Uint32("latest_epoch", latestPublished.Uint32())) - - // positioning ATX publish epoch must be lower than the publish epoch of built ATX - positioningAtxPublished := min(latestPublished, publish-1) - id, err := findFullyValidHighTickAtx( - ctx, - b.atxsdata, - positioningAtxPublished, - b.conf.GoldenATXID, - b.validator, - logger, - VerifyChainOpts.AssumeValidBefore(time.Now().Add(-b.postValidityDelay)), - VerifyChainOpts.WithTrustedID(nodeID), - VerifyChainOpts.WithLogger(b.logger), - ) + id, err := atxs.PositioningATX(ctx, publish-1) if err != nil { - logger.Info("search failed - using golden atx as positioning atx", zap.Error(err)) - id = b.conf.GoldenATXID + logger.Warn("failed to get positioning ATX - falling back to golden", zap.Error(err)) + f.found = &struct { + id types.ATXID + forPublish types.EpochID + }{f.golden, publish} + return f.golden } - b.posAtxFinder.found = &struct { + logger.Debug("found candidate positioning atx", log.ZShortStringer("id", id)) + + f.found = &struct { id types.ATXID forPublish types.EpochID }{id, publish} - - return id, nil + return id } // getPositioningAtx returns the positioning ATX. @@ -978,15 +956,7 @@ func (b *Builder) getPositioningAtx( publish types.EpochID, previous *types.ActivationTx, ) (types.ATXID, error) { - id, err := b.searchPositioningAtx(ctx, nodeID, publish) - if err != nil { - return types.EmptyATXID, err - } - - b.logger.Debug("found candidate positioning atx", - log.ZShortStringer("id", id), - log.ZShortStringer("smesherID", nodeID), - ) + id := b.posAtxFinder.find(ctx, b.atxSvc, publish) if previous == nil { b.logger.Info("selected positioning atx", @@ -1004,7 +974,7 @@ func (b *Builder) getPositioningAtx( return id, nil } - candidate, err := atxs.Get(b.db, id) + candidate, err := b.atxSvc.Atx(ctx, id) if err != nil { return types.EmptyATXID, fmt.Errorf("get candidate pos ATX %s: %w", id.ShortString(), err) } @@ -1024,23 +994,17 @@ func (b *Builder) getPositioningAtx( func (b *Builder) Regossip(ctx context.Context, nodeID types.NodeID) error { epoch := b.layerClock.CurrentLayer().GetEpoch() - atx, err := atxs.GetIDByEpochAndNodeID(b.db, epoch, nodeID) - if errors.Is(err, sql.ErrNotFound) { + id, blob, err := atxs.AtxBlob(b.localDB, epoch, nodeID) + if errors.Is(err, common.ErrNotFound) { return nil } else if err != nil { return err } - var blob sql.Blob - if _, err := atxs.LoadBlob(ctx, b.db, atx.Bytes(), &blob); err != nil { - return fmt.Errorf("get blob %s: %w", atx.ShortString(), err) - } - if len(blob.Bytes) == 0 { - return nil // checkpoint - } - if err := b.publisher.Publish(ctx, pubsub.AtxProtocol, blob.Bytes); err != nil { - return fmt.Errorf("republish %s: %w", atx.ShortString(), err) + + if err := b.publisher.Publish(ctx, pubsub.AtxProtocol, blob); err != nil { + return fmt.Errorf("republishing ATX %s: %w", id, err) } - b.logger.Debug("re-gossipped atx", log.ZShortStringer("smesherID", nodeID), log.ZShortStringer("atx", atx)) + b.logger.Debug("re-gossipped atx", log.ZShortStringer("smesherID", nodeID), log.ZShortStringer("atx ID", id)) return nil } @@ -1053,41 +1017,3 @@ func (b *Builder) version(publish types.EpochID) types.AtxVersion { } return version } - -func findFullyValidHighTickAtx( - ctx context.Context, - atxdata *atxsdata.Data, - publish types.EpochID, - goldenATXID types.ATXID, - validator nipostValidator, - logger *zap.Logger, - opts ...VerifyChainOption, -) (types.ATXID, error) { - var found *types.ATXID - - // iterate trough epochs, to get first valid, not malicious ATX with the biggest height - atxdata.IterateHighTicksInEpoch(publish+1, func(id types.ATXID) (contSearch bool) { - logger.Debug("found candidate for high-tick atx", log.ZShortStringer("id", id)) - if ctx.Err() != nil { - return false - } - // verify ATX-candidate by getting their dependencies (previous Atx, positioning ATX etc.) - // and verifying PoST for every dependency - if err := validator.VerifyChain(ctx, id, goldenATXID, opts...); err != nil { - logger.Debug("rejecting candidate for high-tick atx", zap.Error(err), log.ZShortStringer("id", id)) - return true - } - found = &id - return false - }) - - if ctx.Err() != nil { - return types.ATXID{}, ctx.Err() - } - - if found == nil { - return types.ATXID{}, ErrNotFound - } - - return *found, nil -} diff --git a/activation/activation_multi_test.go b/activation/activation_multi_test.go index ce795e53b5..56736a7c51 100644 --- a/activation/activation_multi_test.go +++ b/activation/activation_multi_test.go @@ -18,7 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/atxs" + localatxs "github.com/spacemeshos/go-spacemesh/sql/localsql/atxs" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -188,46 +188,22 @@ func TestRegossip(t *testing.T) { }) t.Run("success", func(t *testing.T) { - goldenATXID := types.RandomATXID() tab := newTestBuilder(t, 5) - var refAtx *types.ActivationTx - + var ( + smesher types.NodeID + blob []byte + ) for _, sig := range tab.signers { - atx := newInitialATXv1(t, goldenATXID) - atx.PublishEpoch = layer.GetEpoch() - atx.Sign(sig) - vAtx := toAtx(t, atx) - require.NoError(t, atxs.Add(tab.db, vAtx, atx.Blob())) - - if refAtx == nil { - refAtx = vAtx - } + smesher = sig.NodeID() + blob = types.RandomBytes(20) + localatxs.AddBlob(tab.localDb, layer.GetEpoch(), types.RandomATXID(), smesher, blob) } - var blob sql.Blob - ver, err := atxs.LoadBlob(context.Background(), tab.db, refAtx.ID().Bytes(), &blob) - require.NoError(t, err) - require.Equal(t, types.AtxV1, ver) - // atx will be regossiped once (by the smesher) tab.mclock.EXPECT().CurrentLayer().Return(layer) ctx := context.Background() - tab.mpub.EXPECT().Publish(ctx, pubsub.AtxProtocol, blob.Bytes) - require.NoError(t, tab.Regossip(ctx, refAtx.SmesherID)) - }) - - t.Run("checkpointed", func(t *testing.T) { - tab := newTestBuilder(t, 5) - for _, sig := range tab.signers { - atx := atxs.CheckpointAtx{ - ID: types.RandomATXID(), - Epoch: layer.GetEpoch(), - SmesherID: sig.NodeID(), - } - require.NoError(t, atxs.AddCheckpointed(tab.db, &atx)) - tab.mclock.EXPECT().CurrentLayer().Return(layer) - require.NoError(t, tab.Regossip(context.Background(), sig.NodeID())) - } + tab.mpub.EXPECT().Publish(ctx, pubsub.AtxProtocol, blob) + require.NoError(t, tab.Regossip(ctx, smesher)) }) } diff --git a/activation/activation_test.go b/activation/activation_test.go index ba1c6dd7ea..7b3bcd0eaf 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -91,20 +91,26 @@ func newTestBuilder(tb testing.TB, numSigners int, opts ...BuilderOption) *testA mValidator: NewMocknipostValidator(ctrl), } - opts = append(opts, WithValidator(tab.mValidator)) - cfg := Config{ GoldenATXID: tab.goldenATXID, } tab.msync.EXPECT().RegisterForATXSynced().DoAndReturn(closedChan).AnyTimes() - b := NewBuilder( - cfg, + atxService := NewDBAtxService( tab.db, + tab.goldenATXID, atxsdata.New(), + tab.mValidator, + logger, + ) + + b := NewBuilder( + cfg, tab.localDb, + atxService, tab.mpub, + tab.mValidator, tab.mnipost, tab.mclock, tab.msync, @@ -141,7 +147,6 @@ func publishAtxV1( tb, atxs.SetPost(tab.db, watx.ID(), watx.PrevATXID, 0, watx.SmesherID, watx.NumUnits, watx.PublishEpoch), ) - tab.atxsdata.AddFromAtx(toAtx(tb, &watx), false) return &watx } @@ -353,14 +358,17 @@ func TestBuilder_PublishActivationTx_HappyFlow(t *testing.T) { posEpoch := postGenesisEpoch currLayer := posEpoch.FirstLayer() - prevAtx := newInitialATXv1(t, tab.goldenATXID) - prevAtx.Sign(sig) - require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) - tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) + prevAtx := &types.ActivationTx{ + CommitmentATX: &tab.goldenATXID, + Coinbase: tab.Coinbase(), + NumUnits: 100, + TickCount: 10, + SmesherID: sig.NodeID(), + } + require.NoError(t, atxs.Add(tab.db, prevAtx, types.AtxBlob{})) // create and publish ATX tab.mclock.EXPECT().CurrentLayer().Return(currLayer).Times(4) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), prevAtx.ID(), tab.goldenATXID, gomock.Any()) atx1 := publishAtxV1(t, tab, sig.NodeID(), posEpoch, &currLayer, layersPerEpoch) require.NotNil(t, atx1) require.Equal(t, prevAtx.ID(), atx1.PositioningATXID) @@ -368,7 +376,6 @@ func TestBuilder_PublishActivationTx_HappyFlow(t *testing.T) { // create and publish another ATX currLayer = (posEpoch + 1).FirstLayer() tab.mclock.EXPECT().CurrentLayer().Return(currLayer).Times(4) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), atx1.ID(), tab.goldenATXID, gomock.Any()) atx2 := publishAtxV1(t, tab, sig.NodeID(), atx1.PublishEpoch, &currLayer, layersPerEpoch) require.NotNil(t, atx2) require.NotEqual(t, atx1, atx2) @@ -392,7 +399,6 @@ func TestBuilder_Loop_WaitsOnStaleChallenge(t *testing.T) { prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) - tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) tab.mclock.EXPECT().CurrentLayer().Return(currLayer).AnyTimes() tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn( @@ -416,8 +422,6 @@ func TestBuilder_Loop_WaitsOnStaleChallenge(t *testing.T) { return ch }) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - // Act & Verify var eg errgroup.Group eg.Go(func() error { @@ -441,7 +445,6 @@ func TestBuilder_PublishActivationTx_FaultyNet(t *testing.T) { prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) - tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) publishEpoch := posEpoch + 1 tab.mclock.EXPECT().CurrentLayer().DoAndReturn(func() types.LayerID { return currLayer }).AnyTimes() @@ -488,7 +491,6 @@ func TestBuilder_PublishActivationTx_FaultyNet(t *testing.T) { // after successful publish, state is cleaned up tab.mnipost.EXPECT().ResetState(sig.NodeID()).Return(nil) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tab.mpub.EXPECT().Publish(gomock.Any(), pubsub.AtxProtocol, gomock.Any()).DoAndReturn( // second publish succeeds func(_ context.Context, _ string, got []byte) error { @@ -516,7 +518,6 @@ func TestBuilder_PublishActivationTx_UsesExistingChallengeOnLatePublish(t *testi prevAtx.Sign(sig) vPrevAtx := toAtx(t, prevAtx) require.NoError(t, atxs.Add(tab.db, vPrevAtx, prevAtx.Blob())) - tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) publishEpoch := currLayer.GetEpoch() tab.mclock.EXPECT().CurrentLayer().DoAndReturn(func() types.LayerID { return currLayer }).AnyTimes() @@ -536,12 +537,14 @@ func TestBuilder_PublishActivationTx_UsesExistingChallengeOnLatePublish(t *testi NumUnits: DefaultPostSetupOpts().NumUnits, LabelsPerUnit: DefaultPostConfig().LabelsPerUnit, }, nil).AnyTimes() - tab.mnipost.EXPECT().BuildNIPost(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ *signing.EdSigner, _ types.Hash32, _ *types.NIPostChallenge, - ) (*nipost.NIPostState, error) { - currLayer = currLayer.Add(1) - return newNIPostWithPoet(t, []byte("66666")), nil - }) + tab.mnipost.EXPECT(). + BuildNIPost(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn( + func(_ context.Context, _ *signing.EdSigner, _ types.Hash32, _ *types.NIPostChallenge, + ) (*nipost.NIPostState, error) { + currLayer = currLayer.Add(1) + return newNIPostWithPoet(t, []byte("66666")), nil + }) done := make(chan struct{}) close(done) tab.mclock.EXPECT().AwaitLayer(publishEpoch.FirstLayer()).DoAndReturn( @@ -593,7 +596,6 @@ func TestBuilder_PublishActivationTx_RebuildNIPostWhenTargetEpochPassed(t *testi prevAtx.Sign(sig) vPrevAtx := toAtx(t, prevAtx) require.NoError(t, atxs.Add(tab.db, vPrevAtx, prevAtx.Blob())) - tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) publishEpoch := posEpoch + 1 tab.mclock.EXPECT().CurrentLayer().DoAndReturn( @@ -622,7 +624,6 @@ func TestBuilder_PublishActivationTx_RebuildNIPostWhenTargetEpochPassed(t *testi } return done }) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() var built *wire.ActivationTxV1 @@ -654,10 +655,8 @@ func TestBuilder_PublishActivationTx_RebuildNIPostWhenTargetEpochPassed(t *testi posAtx := newInitialATXv1(t, tab.goldenATXID, func(atx *wire.ActivationTxV1) { atx.PublishEpoch = posEpoch }) posAtx.Sign(sig) require.NoError(t, atxs.Add(tab.db, toAtx(t, posAtx), posAtx.Blob())) - tab.atxsdata.AddFromAtx(toAtx(t, posAtx), false) tab.mclock.EXPECT().CurrentLayer().DoAndReturn(func() types.LayerID { return currLayer }).AnyTimes() tab.mnipost.EXPECT().ResetState(sig.NodeID()).Return(nil) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), posAtx.ID(), tab.goldenATXID, gomock.Any()) built2 := publishAtxV1(t, tab, sig.NodeID(), posEpoch, &currLayer, layersPerEpoch) require.NotNil(t, built2) require.NotEqual(t, built.NIPostChallengeV1, built2.NIPostChallengeV1) @@ -843,31 +842,35 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) { r := require.New(t) // Arrange - tab := newTestBuilder(t, 1, WithPoetConfig(PoetConfig{PhaseShift: layerDuration * 4})) + actSvc := NewMockAtxService(gomock.NewController(t)) + tab := newTestBuilder(t, 1, + WithPoetConfig(PoetConfig{PhaseShift: layerDuration * 4}), + ) + tab.atxSvc = actSvc sig := maps.Values(tab.signers)[0] - otherSigner, err := signing.NewEdSigner() - r.NoError(err) - poetBytes := []byte("poet") - currentLayer := postGenesisEpoch.FirstLayer().Add(3) - posAtx := newInitialATXv1(t, tab.goldenATXID) - posAtx.Sign(otherSigner) - vPosAtx := toAtx(t, posAtx) - vPosAtx.TickCount = 100 - r.NoError(atxs.Add(tab.db, vPosAtx, posAtx.Blob())) - tab.atxsdata.AddFromAtx(vPosAtx, false) - - nonce := types.VRFPostIndex(123) - prevAtx := newInitialATXv1(t, tab.goldenATXID, func(atx *wire.ActivationTxV1) { - atx.VRFNonce = (*uint64)(&nonce) - }) - prevAtx.Sign(sig) - vPrevAtx := toAtx(t, prevAtx) - r.NoError(atxs.Add(tab.db, vPrevAtx, prevAtx.Blob())) - tab.atxsdata.AddFromAtx(vPrevAtx, false) + posAtx := &types.ActivationTx{ + PublishEpoch: 1, + TickCount: 100, + SmesherID: types.RandomNodeID(), + } + posAtx.SetID(types.RandomATXID()) + actSvc.EXPECT().PositioningATX(gomock.Any(), gomock.Any()).Return(posAtx.ID(), nil) + actSvc.EXPECT().Atx(gomock.Any(), posAtx.ID()).Return(posAtx, nil) + + prevAtx := &types.ActivationTx{ + PublishEpoch: 1, + TickCount: 10, + SmesherID: types.RandomNodeID(), + VRFNonce: types.VRFPostIndex(123), + } + prevAtx.SetID(types.RandomATXID()) + actSvc.EXPECT().LastATX(gomock.Any(), sig.NodeID()).Return(prevAtx, nil).Times(2) + actSvc.EXPECT().Atx(gomock.Any(), prevAtx.ID()).Return(prevAtx, nil) // Act + currentLayer := prevAtx.PublishEpoch.FirstLayer().Add(3) tab.msync.EXPECT().RegisterForATXSynced().DoAndReturn(closedChan).AnyTimes() tab.mclock.EXPECT().CurrentLayer().Return(currentLayer).AnyTimes() @@ -877,7 +880,7 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) { genesis := time.Now().Add(-time.Duration(currentLayer) * layerDuration) return genesis.Add(layerDuration * time.Duration(layer)) }).AnyTimes() - tab.mclock.EXPECT().AwaitLayer(vPosAtx.PublishEpoch.FirstLayer().Add(layersPerEpoch)).DoAndReturn( + tab.mclock.EXPECT().AwaitLayer(prevAtx.PublishEpoch.FirstLayer().Add(layersPerEpoch)).DoAndReturn( func(layer types.LayerID) <-chan struct{} { ch := make(chan struct{}) close(ch) @@ -889,7 +892,7 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) { tab.mpostClient.EXPECT().Info(gomock.Any()).Return(&types.PostInfo{ NodeID: sig.NodeID(), CommitmentATX: commitmentATX, - Nonce: &nonce, + Nonce: &prevAtx.VRFNonce, NumUnits: DefaultPostSetupOpts().NumUnits, LabelsPerUnit: DefaultPostConfig().LabelsPerUnit, @@ -904,8 +907,6 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) { return newNIPostWithPoet(t, poetBytes), nil }) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tab.mpub.EXPECT(). Publish(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, _ string, msg []byte) error { @@ -929,7 +930,7 @@ func TestBuilder_PublishActivationTx_PrevATXWithoutPrevATX(t *testing.T) { r.NoError(tab.PublishActivationTx(context.Background(), sig)) // state is cleaned up - _, err = nipost.Challenge(tab.localDB, sig.NodeID()) + _, err := nipost.Challenge(tab.localDB, sig.NodeID()) require.ErrorIs(t, err, sql.ErrNotFound) } @@ -937,19 +938,23 @@ func TestBuilder_PublishActivationTx_TargetsEpochBasedOnPosAtx(t *testing.T) { r := require.New(t) // Arrange - tab := newTestBuilder(t, 1, WithPoetConfig(PoetConfig{PhaseShift: layerDuration * 4})) + atxSvc := NewMockAtxService(gomock.NewController(t)) + tab := newTestBuilder(t, 1, + WithPoetConfig(PoetConfig{PhaseShift: layerDuration * 4}), + ) + tab.atxSvc = atxSvc sig := maps.Values(tab.signers)[0] - otherSigner, err := signing.NewEdSigner() - r.NoError(err) - poetBytes := []byte("poet") currentLayer := postGenesisEpoch.FirstLayer().Add(3) - posEpoch := postGenesisEpoch - posAtx := newInitialATXv1(t, tab.goldenATXID) - posAtx.Sign(otherSigner) - r.NoError(atxs.Add(tab.db, toAtx(t, posAtx), posAtx.Blob())) - tab.atxsdata.AddFromAtx(toAtx(t, posAtx), false) + posAtx := &types.ActivationTx{ + PublishEpoch: 2, + TickCount: 100, + SmesherID: types.RandomNodeID(), + } + posAtx.SetID(types.RandomATXID()) + atxSvc.EXPECT().PositioningATX(gomock.Any(), gomock.Any()).Return(posAtx.ID(), nil) + atxSvc.EXPECT().LastATX(gomock.Any(), sig.NodeID()).Return(nil, sql.ErrNotFound).Times(2) // Act & Assert tab.msync.EXPECT().RegisterForATXSynced().DoAndReturn(closedChan).AnyTimes() @@ -987,7 +992,6 @@ func TestBuilder_PublishActivationTx_TargetsEpochBasedOnPosAtx(t *testing.T) { return newNIPostWithPoet(t, poetBytes), nil }) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tab.mpub.EXPECT(). Publish(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, _ string, msg []byte) error { @@ -999,7 +1003,7 @@ func TestBuilder_PublishActivationTx_TargetsEpochBasedOnPosAtx(t *testing.T) { r.Equal(types.EmptyATXID, atx.PrevATXID) r.NotNil(atx.InitialPost) r.Equal(posAtx.ID(), atx.PositioningATXID) - r.Equal(posEpoch+1, atx.PublishEpoch) + r.Equal(posAtx.PublishEpoch+1, atx.PublishEpoch) r.Equal(poetBytes, atx.NIPost.PostMetadata.Challenge) return nil @@ -1029,7 +1033,7 @@ func TestBuilder_PublishActivationTx_TargetsEpochBasedOnPosAtx(t *testing.T) { r.NoError(tab.PublishActivationTx(context.Background(), sig)) // state is cleaned up - _, err = nipost.Challenge(tab.localDB, sig.NodeID()) + _, err := nipost.Challenge(tab.localDB, sig.NodeID()) require.ErrorIs(t, err, sql.ErrNotFound) } @@ -1042,7 +1046,6 @@ func TestBuilder_PublishActivationTx_FailsWhenNIPostBuilderFails(t *testing.T) { prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) - tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) tab.mclock.EXPECT().CurrentLayer().Return(posEpoch.FirstLayer()).AnyTimes() tab.mclock.EXPECT().LayerToTime(gomock.Any()).DoAndReturn( @@ -1055,7 +1058,6 @@ func TestBuilder_PublishActivationTx_FailsWhenNIPostBuilderFails(t *testing.T) { tab.mnipost.EXPECT(). BuildNIPost(gomock.Any(), sig, gomock.Any(), gomock.Any()). Return(nil, nipostErr) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) require.ErrorIs(t, tab.PublishActivationTx(context.Background(), sig), nipostErr) // state is preserved @@ -1100,7 +1102,6 @@ func TestBuilder_RetryPublishActivationTx(t *testing.T) { prevAtx := newInitialATXv1(t, tab.goldenATXID) prevAtx.Sign(sig) require.NoError(t, atxs.Add(tab.db, toAtx(t, prevAtx), prevAtx.Blob())) - tab.atxsdata.AddFromAtx(toAtx(t, prevAtx), false) currLayer := prevAtx.PublishEpoch.FirstLayer() tab.mclock.EXPECT().CurrentLayer().DoAndReturn(func() types.LayerID { return currLayer }).AnyTimes() @@ -1148,7 +1149,6 @@ func TestBuilder_RetryPublishActivationTx(t *testing.T) { ) tab.mnipost.EXPECT().ResetState(sig.NodeID()).Return(nil) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) nonce := types.VRFPostIndex(123) commitmentATX := types.RandomATXID() @@ -1395,74 +1395,20 @@ func TestWaitPositioningAtx(t *testing.T) { } } -// Test if GetPositioningAtx disregards ATXs with invalid POST in their chain. -// It should pick an ATX with valid POST even though it's a lower height. -func TestGetPositioningAtxPicksAtxWithValidChain(t *testing.T) { - tab := newTestBuilder(t, 1) - sig := maps.Values(tab.signers)[0] - - // Invalid chain with high height - sigInvalid, err := signing.NewEdSigner() - require.NoError(t, err) - invalidAtx := newInitialATXv1(t, tab.goldenATXID) - invalidAtx.Sign(sigInvalid) - vInvalidAtx := toAtx(t, invalidAtx) - vInvalidAtx.TickCount = 100 - require.NoError(t, err) - require.NoError(t, atxs.Add(tab.db, vInvalidAtx, invalidAtx.Blob())) - tab.atxsdata.AddFromAtx(vInvalidAtx, false) - - // Valid chain with lower height - sigValid, err := signing.NewEdSigner() - require.NoError(t, err) - validAtx := newInitialATXv1(t, tab.goldenATXID) - validAtx.NumUnits += 10 - validAtx.Sign(sigValid) - vValidAtx := toAtx(t, validAtx) - require.NoError(t, atxs.Add(tab.db, vValidAtx, validAtx.Blob())) - tab.atxsdata.AddFromAtx(vValidAtx, false) - - tab.mValidator.EXPECT(). - VerifyChain(gomock.Any(), invalidAtx.ID(), tab.goldenATXID, gomock.Any()). - Return(errors.New("")) - tab.mValidator.EXPECT(). - VerifyChain(gomock.Any(), validAtx.ID(), tab.goldenATXID, gomock.Any()) - - posAtxID, err := tab.getPositioningAtx(context.Background(), sig.NodeID(), 77, nil) - require.NoError(t, err) - require.Equal(t, posAtxID, vValidAtx.ID()) - - // should use the cached positioning ATX when asked for the same publish epoch - posAtxID, err = tab.getPositioningAtx(context.Background(), sig.NodeID(), 77, nil) - require.NoError(t, err) - require.Equal(t, posAtxID, vValidAtx.ID()) - - // should lookup again when asked for a different publish epoch - tab.mValidator.EXPECT(). - VerifyChain(gomock.Any(), invalidAtx.ID(), tab.goldenATXID, gomock.Any()). - Return(errors.New("")) - tab.mValidator.EXPECT(). - VerifyChain(gomock.Any(), validAtx.ID(), tab.goldenATXID, gomock.Any()) - - posAtxID, err = tab.getPositioningAtx(context.Background(), sig.NodeID(), 99, nil) - require.NoError(t, err) - require.Equal(t, posAtxID, vValidAtx.ID()) -} - func TestGetPositioningAtx(t *testing.T) { t.Parallel() - t.Run("db failed", func(t *testing.T) { + t.Run("picks golden when failed", func(t *testing.T) { t.Parallel() + atxSvc := NewMockAtxService(gomock.NewController(t)) tab := newTestBuilder(t, 1) + tab.atxSvc = atxSvc - db := sql.NewMockExecutor(gomock.NewController(t)) - tab.Builder.db = db - expected := errors.New("db error") - db.EXPECT().Exec(gomock.Any(), gomock.Any(), gomock.Any()).Return(0, expected) + expected := errors.New("expected error") + atxSvc.EXPECT().PositioningATX(gomock.Any(), gomock.Any()).Return(types.ATXID{}, expected) - none, err := tab.getPositioningAtx(context.Background(), types.EmptyNodeID, 99, nil) - require.ErrorIs(t, err, expected) - require.Equal(t, types.ATXID{}, none) + posATX, err := tab.getPositioningAtx(context.Background(), types.EmptyNodeID, 99, nil) + require.NoError(t, err) + require.Equal(t, tab.goldenATXID, posATX) }) t.Run("picks golden if no ATXs", func(t *testing.T) { tab := newTestBuilder(t, 1) @@ -1479,21 +1425,18 @@ func TestGetPositioningAtx(t *testing.T) { require.Equal(t, prev.ID(), atx) }) t.Run("prefers own previous when it has GTE ticks", func(t *testing.T) { + atxSvc := NewMockAtxService(gomock.NewController(t)) tab := newTestBuilder(t, 1) + tab.atxSvc = atxSvc atxInDb := &types.ActivationTx{TickCount: 10} atxInDb.SetID(types.RandomATXID()) - require.NoError(t, atxs.Add(tab.db, atxInDb, types.AtxBlob{})) - tab.atxsdata.AddFromAtx(atxInDb, false) + atxSvc.EXPECT().PositioningATX(gomock.Any(), types.EpochID(98)).Return(atxInDb.ID(), nil) + atxSvc.EXPECT().Atx(context.Background(), atxInDb.ID()).Return(atxInDb, nil).Times(2) prev := &types.ActivationTx{TickCount: 100} prev.SetID(types.RandomATXID()) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), atxInDb.ID(), tab.goldenATXID, gomock.Any()) - found, err := tab.searchPositioningAtx(context.Background(), types.EmptyNodeID, 99) - require.NoError(t, err) - require.Equal(t, atxInDb.ID(), found) - // prev.Height > found.Height selected, err := tab.getPositioningAtx(context.Background(), types.EmptyNodeID, 99, prev) require.NoError(t, err) @@ -1505,67 +1448,6 @@ func TestGetPositioningAtx(t *testing.T) { require.NoError(t, err) require.Equal(t, prev.ID(), selected) }) - t.Run("prefers own previous or golded when positioning ATX selection timout expired", func(t *testing.T) { - tab := newTestBuilder(t, 1) - - atxInDb := &types.ActivationTx{TickCount: 100} - atxInDb.SetID(types.RandomATXID()) - require.NoError(t, atxs.Add(tab.db, atxInDb, types.AtxBlob{})) - tab.atxsdata.AddFromAtx(atxInDb, false) - - prev := &types.ActivationTx{TickCount: 90} - prev.SetID(types.RandomATXID()) - - // no timeout set up - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), atxInDb.ID(), tab.goldenATXID, gomock.Any()) - found, err := tab.getPositioningAtx(context.Background(), types.EmptyNodeID, 99, prev) - require.NoError(t, err) - require.Equal(t, atxInDb.ID(), found) - - tab.posAtxFinder.found = nil - - // timeout set up, prev ATX exists - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - selected, err := tab.getPositioningAtx(ctx, types.EmptyNodeID, 99, prev) - require.NoError(t, err) - require.Equal(t, prev.ID(), selected) - - tab.posAtxFinder.found = nil - - // timeout set up, prev ATX do not exists - ctx, cancel = context.WithCancel(context.Background()) - cancel() - - selected, err = tab.getPositioningAtx(ctx, types.EmptyNodeID, 99, nil) - require.NoError(t, err) - require.Equal(t, tab.goldenATXID, selected) - }) -} - -func TestFindFullyValidHighTickAtx(t *testing.T) { - t.Parallel() - golden := types.RandomATXID() - - t.Run("skips malicious ATXs", func(t *testing.T) { - data := atxsdata.New() - atxMal := &types.ActivationTx{TickCount: 100, SmesherID: types.RandomNodeID()} - atxMal.SetID(types.RandomATXID()) - data.AddFromAtx(atxMal, true) - - atxLower := &types.ActivationTx{TickCount: 10, SmesherID: types.RandomNodeID()} - atxLower.SetID(types.RandomATXID()) - data.AddFromAtx(atxLower, false) - - mValidator := NewMocknipostValidator(gomock.NewController(t)) - mValidator.EXPECT().VerifyChain(gomock.Any(), atxLower.ID(), golden, gomock.Any()) - - lg := zaptest.NewLogger(t) - found, err := findFullyValidHighTickAtx(context.Background(), data, 0, golden, mValidator, lg) - require.NoError(t, err) - require.Equal(t, atxLower.ID(), found) - }) } // Test_Builder_RegenerateInitialPost tests the coverage for the edge case diff --git a/activation/atx_service_db.go b/activation/atx_service_db.go new file mode 100644 index 0000000000..94e40b17ba --- /dev/null +++ b/activation/atx_service_db.go @@ -0,0 +1,141 @@ +package activation + +import ( + "context" + "fmt" + "time" + + "go.uber.org/zap" + + "github.com/spacemeshos/go-spacemesh/atxsdata" + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/log" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" +) + +// dbAtxService implements AtxService by accessing the state database. +type dbAtxService struct { + golden types.ATXID + logger *zap.Logger + db sql.Executor + atxsdata *atxsdata.Data + validator nipostValidator + cfg dbAtxServiceConfig +} + +type dbAtxServiceConfig struct { + // delay before PoST in ATX is considered valid (counting from the time it was received) + postValidityDelay time.Duration +} + +type dbAtxServiceOption func(*dbAtxServiceConfig) + +func WithPostValidityDelay(delay time.Duration) dbAtxServiceOption { + return func(cfg *dbAtxServiceConfig) { + cfg.postValidityDelay = delay + } +} + +func NewDBAtxService( + db sql.Executor, + golden types.ATXID, + atxsdata *atxsdata.Data, + validator nipostValidator, + logger *zap.Logger, + opts ...dbAtxServiceOption, +) *dbAtxService { + cfg := dbAtxServiceConfig{ + postValidityDelay: time.Hour * 12, + } + + for _, opt := range opts { + opt(&cfg) + } + + return &dbAtxService{ + golden: golden, + logger: logger, + db: db, + atxsdata: atxsdata, + validator: validator, + cfg: cfg, + } +} + +func (s *dbAtxService) Atx(_ context.Context, id types.ATXID) (*types.ActivationTx, error) { + return atxs.Get(s.db, id) +} + +func (s *dbAtxService) LastATX(ctx context.Context, id types.NodeID) (*types.ActivationTx, error) { + atxid, err := atxs.GetLastIDByNodeID(s.db, id) + if err != nil { + return nil, fmt.Errorf("getting last ATXID: %w", err) + } + return atxs.Get(s.db, atxid) +} + +func (s *dbAtxService) PositioningATX(ctx context.Context, maxPublish types.EpochID) (types.ATXID, error) { + latestPublished, err := atxs.LatestEpoch(s.db) + if err != nil { + return types.EmptyATXID, fmt.Errorf("get latest epoch: %w", err) + } + s.logger.Info("searching for positioning atx", zap.Uint32("latest_epoch", latestPublished.Uint32())) + + // positioning ATX publish epoch must be lower than the publish epoch of built ATX + positioningAtxPublished := min(latestPublished, maxPublish) + id, err := findFullyValidHighTickAtx( + ctx, + s.atxsdata, + positioningAtxPublished, + s.golden, + s.validator, s.logger, + VerifyChainOpts.AssumeValidBefore(time.Now().Add(-s.cfg.postValidityDelay)), + // VerifyChainOpts.WithTrustedID(nodeID), + VerifyChainOpts.WithLogger(s.logger), + ) + if err != nil { + s.logger.Info("search failed - using golden atx as positioning atx", zap.Error(err)) + id = s.golden + } + + return id, nil +} + +func findFullyValidHighTickAtx( + ctx context.Context, + atxdata *atxsdata.Data, + publish types.EpochID, + goldenATXID types.ATXID, + validator nipostValidator, + logger *zap.Logger, + opts ...VerifyChainOption, +) (types.ATXID, error) { + var found *types.ATXID + + // iterate trough epochs, to get first valid, not malicious ATX with the biggest height + atxdata.IterateHighTicksInEpoch(publish+1, func(id types.ATXID) (contSearch bool) { + logger.Debug("found candidate for high-tick atx", log.ZShortStringer("id", id)) + if ctx.Err() != nil { + return false + } + // verify ATX-candidate by getting their dependencies (previous Atx, positioning ATX etc.) + // and verifying PoST for every dependency + if err := validator.VerifyChain(ctx, id, goldenATXID, opts...); err != nil { + logger.Debug("rejecting candidate for high-tick atx", zap.Error(err), log.ZShortStringer("id", id)) + return true + } + found = &id + return false + }) + + if ctx.Err() != nil { + return types.ATXID{}, ctx.Err() + } + + if found == nil { + return types.ATXID{}, ErrNotFound + } + + return *found, nil +} diff --git a/activation/atx_service_db_test.go b/activation/atx_service_db_test.go new file mode 100644 index 0000000000..f2e916a595 --- /dev/null +++ b/activation/atx_service_db_test.go @@ -0,0 +1,104 @@ +package activation + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "go.uber.org/zap/zaptest" + + "github.com/spacemeshos/go-spacemesh/atxsdata" + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func newTestDbAtxService(t *testing.T) *dbAtxService { + return NewDBAtxService( + statesql.InMemoryTest(t), + types.RandomATXID(), + atxsdata.New(), + NewMocknipostValidator(gomock.NewController(t)), + zaptest.NewLogger(t), + ) +} + +// Test if PositioningAtx disregards ATXs with invalid POST in their chain. +// It should pick an ATX with valid POST even though it's a lower height. +func TestGetPositioningAtxPicksAtxWithValidChain(t *testing.T) { + atxSvc := newTestDbAtxService(t) + + // Invalid chain with high height + sigInvalid, err := signing.NewEdSigner() + require.NoError(t, err) + invalidAtx := newInitialATXv1(t, atxSvc.golden) + invalidAtx.Sign(sigInvalid) + vInvalidAtx := toAtx(t, invalidAtx) + vInvalidAtx.TickCount = 100 + require.NoError(t, err) + require.NoError(t, atxs.Add(atxSvc.db, vInvalidAtx, invalidAtx.Blob())) + atxSvc.atxsdata.AddFromAtx(vInvalidAtx, false) + + // Valid chain with lower height + sigValid, err := signing.NewEdSigner() + require.NoError(t, err) + validAtx := newInitialATXv1(t, atxSvc.golden) + validAtx.NumUnits += 10 + validAtx.Sign(sigValid) + vValidAtx := toAtx(t, validAtx) + require.NoError(t, atxs.Add(atxSvc.db, vValidAtx, validAtx.Blob())) + atxSvc.atxsdata.AddFromAtx(vValidAtx, false) + + atxSvc.validator.(*MocknipostValidator).EXPECT(). + VerifyChain(gomock.Any(), invalidAtx.ID(), atxSvc.golden, gomock.Any()). + Return(errors.New("this is invalid")) + atxSvc.validator.(*MocknipostValidator).EXPECT(). + VerifyChain(gomock.Any(), validAtx.ID(), atxSvc.golden, gomock.Any()) + + posAtxID, err := atxSvc.PositioningATX(context.Background(), validAtx.PublishEpoch) + require.NoError(t, err) + require.Equal(t, vValidAtx.ID(), posAtxID) + + // look in a later epoch, it should return the same one (there is no newer one). + atxSvc.validator.(*MocknipostValidator).EXPECT(). + VerifyChain(gomock.Any(), invalidAtx.ID(), atxSvc.golden, gomock.Any()). + Return(errors.New("")) + atxSvc.validator.(*MocknipostValidator).EXPECT(). + VerifyChain(gomock.Any(), validAtx.ID(), atxSvc.golden, gomock.Any()) + + posAtxID, err = atxSvc.PositioningATX(context.Background(), validAtx.PublishEpoch+1) + require.NoError(t, err) + require.Equal(t, vValidAtx.ID(), posAtxID) + + // it returns the golden ATX if couldn't find a better one + posAtxID, err = atxSvc.PositioningATX(context.Background(), validAtx.PublishEpoch-1) + require.NoError(t, err) + require.Equal(t, atxSvc.golden, posAtxID) +} + +func TestFindFullyValidHighTickAtx(t *testing.T) { + t.Parallel() + golden := types.RandomATXID() + + t.Run("skips malicious ATXs", func(t *testing.T) { + data := atxsdata.New() + atxMal := &types.ActivationTx{TickCount: 100, SmesherID: types.RandomNodeID()} + atxMal.SetID(types.RandomATXID()) + data.AddFromAtx(atxMal, true) + + atxLower := &types.ActivationTx{TickCount: 10, SmesherID: types.RandomNodeID()} + atxLower.SetID(types.RandomATXID()) + data.AddFromAtx(atxLower, false) + + mValidator := NewMocknipostValidator(gomock.NewController(t)) + mValidator.EXPECT().VerifyChain(gomock.Any(), atxLower.ID(), golden, gomock.Any()) + + lg := zaptest.NewLogger(t) + found, err := findFullyValidHighTickAtx(context.Background(), data, 0, golden, mValidator, lg) + require.NoError(t, err) + require.Equal(t, atxLower.ID(), found) + }) +} diff --git a/activation/builder_v2_test.go b/activation/builder_v2_test.go index 0054147f4e..d7170c12bf 100644 --- a/activation/builder_v2_test.go +++ b/activation/builder_v2_test.go @@ -93,7 +93,6 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) { posEpoch += 1 layer = posEpoch.FirstLayer() tab.mclock.EXPECT().CurrentLayer().Return(layer).Times(4) - tab.mValidator.EXPECT().VerifyChain(gomock.Any(), atx1.ID(), tab.goldenATXID, gomock.Any()) var atx2 wire.ActivationTxV2 publishAtx(t, tab, sig.NodeID(), posEpoch, &layer, layersPerEpoch, func(_ context.Context, _ string, got []byte) error { diff --git a/activation/e2e/activation_test.go b/activation/e2e/activation_test.go index d0c86e5785..d4d65a2a5b 100644 --- a/activation/e2e/activation_test.go +++ b/activation/e2e/activation_test.go @@ -189,18 +189,26 @@ func Test_BuilderWithMultipleClients(t *testing.T) { ).Times(totalAtxs) t.Cleanup(func() { assert.NoError(t, verifier.Close()) }) - tab := activation.NewBuilder( - conf, + + atxService := activation.NewDBAtxService( db, + conf.GoldenATXID, data, + validator, + logger, + ) + + tab := activation.NewBuilder( + conf, localDB, + atxService, mpub, + validator, nb, clock, syncedSyncer(t), logger, activation.WithPoetConfig(poetCfg), - activation.WithValidator(validator), activation.WithPoets(client), ) for _, sig := range signers { diff --git a/activation/e2e/builds_atx_v2_test.go b/activation/e2e/builds_atx_v2_test.go index 93832608eb..1621329847 100644 --- a/activation/e2e/builds_atx_v2_test.go +++ b/activation/e2e/builds_atx_v2_test.go @@ -200,18 +200,25 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) { ).Times(2), ) - tab := activation.NewBuilder( - conf, + atxService := activation.NewDBAtxService( db, + conf.GoldenATXID, atxsdata, + validator, + logger, + ) + + tab := activation.NewBuilder( + conf, localDB, + atxService, mpub, + validator, nb, clock, syncedSyncer(t), logger, activation.WithPoetConfig(poetCfg), - activation.WithValidator(validator), activation.BuilderAtxVersions(atxVersions), ) tab.Register(sig) diff --git a/activation/e2e/checkpoint_test.go b/activation/e2e/checkpoint_test.go index b7a2596cae..e06ce214af 100644 --- a/activation/e2e/checkpoint_test.go +++ b/activation/e2e/checkpoint_test.go @@ -123,18 +123,25 @@ func TestCheckpoint_PublishingSoloATXs(t *testing.T) { activation.WithAtxVersions(atxVersions), ) - tab := activation.NewBuilder( - activation.Config{GoldenATXID: goldenATX}, + atxService := activation.NewDBAtxService( db, + goldenATX, atxdata, + validator, + logger, + ) + + tab := activation.NewBuilder( + activation.Config{GoldenATXID: goldenATX}, localDB, + atxService, mpub, + validator, nb, clock, syncer, logger, activation.WithPoetConfig(poetCfg), - activation.WithValidator(validator), activation.BuilderAtxVersions(atxVersions), ) tab.Register(sig) @@ -223,18 +230,25 @@ func TestCheckpoint_PublishingSoloATXs(t *testing.T) { ) require.NoError(t, err) - tab = activation.NewBuilder( - activation.Config{GoldenATXID: goldenATX}, + atxService = activation.NewDBAtxService( newDB, + goldenATX, atxdata, + validator, + logger, + ) + + tab = activation.NewBuilder( + activation.Config{GoldenATXID: goldenATX}, localDB, + atxService, mpub, + validator, nb, clock, syncer, logger, activation.WithPoetConfig(poetCfg), - activation.WithValidator(validator), activation.BuilderAtxVersions(atxVersions), ) tab.Register(sig) diff --git a/activation/interface.go b/activation/interface.go index c9c3359091..1fa673948f 100644 --- a/activation/interface.go +++ b/activation/interface.go @@ -108,6 +108,15 @@ type atxProvider interface { GetAtx(id types.ATXID) (*types.ActivationTx, error) } +// AtxService provides ATXs needed by the ATX Builder. +type AtxService interface { + Atx(ctx context.Context, id types.ATXID) (*types.ActivationTx, error) + LastATX(ctx context.Context, nodeID types.NodeID) (*types.ActivationTx, error) + // PositioningATX returns atx id with the highest tick height. + // The maxPublish epoch is the maximum publish epoch of the returned ATX. + PositioningATX(ctx context.Context, maxPublish types.EpochID) (types.ATXID, error) +} + // PostSetupProvider defines the functionality required for Post setup. // This interface is used by the atx builder and currently implemented by the PostSetupManager. // Eventually most of the functionality will be moved to the PoSTClient. diff --git a/activation/mocks.go b/activation/mocks.go index 985f6a05f3..4cd3041f9d 100644 --- a/activation/mocks.go +++ b/activation/mocks.go @@ -1217,6 +1217,146 @@ func (c *MockatxProviderGetAtxCall) DoAndReturn(f func(types.ATXID) (*types.Acti return c } +// MockAtxService is a mock of AtxService interface. +type MockAtxService struct { + ctrl *gomock.Controller + recorder *MockAtxServiceMockRecorder +} + +// MockAtxServiceMockRecorder is the mock recorder for MockAtxService. +type MockAtxServiceMockRecorder struct { + mock *MockAtxService +} + +// NewMockAtxService creates a new mock instance. +func NewMockAtxService(ctrl *gomock.Controller) *MockAtxService { + mock := &MockAtxService{ctrl: ctrl} + mock.recorder = &MockAtxServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAtxService) EXPECT() *MockAtxServiceMockRecorder { + return m.recorder +} + +// Atx mocks base method. +func (m *MockAtxService) Atx(ctx context.Context, id types.ATXID) (*types.ActivationTx, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Atx", ctx, id) + ret0, _ := ret[0].(*types.ActivationTx) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Atx indicates an expected call of Atx. +func (mr *MockAtxServiceMockRecorder) Atx(ctx, id any) *MockAtxServiceAtxCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Atx", reflect.TypeOf((*MockAtxService)(nil).Atx), ctx, id) + return &MockAtxServiceAtxCall{Call: call} +} + +// MockAtxServiceAtxCall wrap *gomock.Call +type MockAtxServiceAtxCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockAtxServiceAtxCall) Return(arg0 *types.ActivationTx, arg1 error) *MockAtxServiceAtxCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockAtxServiceAtxCall) Do(f func(context.Context, types.ATXID) (*types.ActivationTx, error)) *MockAtxServiceAtxCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockAtxServiceAtxCall) DoAndReturn(f func(context.Context, types.ATXID) (*types.ActivationTx, error)) *MockAtxServiceAtxCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// LastATX mocks base method. +func (m *MockAtxService) LastATX(ctx context.Context, nodeID types.NodeID) (*types.ActivationTx, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastATX", ctx, nodeID) + ret0, _ := ret[0].(*types.ActivationTx) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LastATX indicates an expected call of LastATX. +func (mr *MockAtxServiceMockRecorder) LastATX(ctx, nodeID any) *MockAtxServiceLastATXCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastATX", reflect.TypeOf((*MockAtxService)(nil).LastATX), ctx, nodeID) + return &MockAtxServiceLastATXCall{Call: call} +} + +// MockAtxServiceLastATXCall wrap *gomock.Call +type MockAtxServiceLastATXCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockAtxServiceLastATXCall) Return(arg0 *types.ActivationTx, arg1 error) *MockAtxServiceLastATXCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockAtxServiceLastATXCall) Do(f func(context.Context, types.NodeID) (*types.ActivationTx, error)) *MockAtxServiceLastATXCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockAtxServiceLastATXCall) DoAndReturn(f func(context.Context, types.NodeID) (*types.ActivationTx, error)) *MockAtxServiceLastATXCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// PositioningATX mocks base method. +func (m *MockAtxService) PositioningATX(ctx context.Context, maxPublish types.EpochID) (types.ATXID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PositioningATX", ctx, maxPublish) + ret0, _ := ret[0].(types.ATXID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PositioningATX indicates an expected call of PositioningATX. +func (mr *MockAtxServiceMockRecorder) PositioningATX(ctx, maxPublish any) *MockAtxServicePositioningATXCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PositioningATX", reflect.TypeOf((*MockAtxService)(nil).PositioningATX), ctx, maxPublish) + return &MockAtxServicePositioningATXCall{Call: call} +} + +// MockAtxServicePositioningATXCall wrap *gomock.Call +type MockAtxServicePositioningATXCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockAtxServicePositioningATXCall) Return(arg0 types.ATXID, arg1 error) *MockAtxServicePositioningATXCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockAtxServicePositioningATXCall) Do(f func(context.Context, types.EpochID) (types.ATXID, error)) *MockAtxServicePositioningATXCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockAtxServicePositioningATXCall) DoAndReturn(f func(context.Context, types.EpochID) (types.ATXID, error)) *MockAtxServicePositioningATXCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // MockpostSetupProvider is a mock of postSetupProvider interface. type MockpostSetupProvider struct { ctrl *gomock.Controller diff --git a/common/errors.go b/common/errors.go new file mode 100644 index 0000000000..4342b6fa07 --- /dev/null +++ b/common/errors.go @@ -0,0 +1,5 @@ +package common + +import "errors" + +var ErrNotFound = errors.New("not found") diff --git a/node/node.go b/node/node.go index 11076d235b..97013612ee 100644 --- a/node/node.go +++ b/node/node.go @@ -1078,22 +1078,30 @@ func (app *App) initServices(ctx context.Context) error { GoldenATXID: goldenATXID, RegossipInterval: app.Config.RegossipAtxInterval, } - atxBuilder := activation.NewBuilder( - builderConfig, + + atxBuilderLog := app.addLogger(ATXBuilderLogger, lg).Zap() + atxService := activation.NewDBAtxService( app.db, + goldenATXID, app.atxsdata, + app.validator, + atxBuilderLog, + activation.WithPostValidityDelay(app.Config.PostValidDelay), + ) + atxBuilder := activation.NewBuilder( + builderConfig, app.localDB, + atxService, app.host, + app.validator, nipostBuilder, app.clock, newSyncer, - app.addLogger(ATXBuilderLogger, lg).Zap(), + atxBuilderLog, activation.WithContext(ctx), activation.WithPoetConfig(app.Config.POET), // TODO(dshulyak) makes no sense. how we ended using it? activation.WithPoetRetryInterval(app.Config.HARE3.PreroundDelay), - activation.WithValidator(app.validator), - activation.WithPostValidityDelay(app.Config.PostValidDelay), activation.WithPostStates(postStates), activation.WithPoets(poetClients...), activation.BuilderAtxVersions(app.Config.AtxVersions), diff --git a/sql/database.go b/sql/database.go index 4f4224b710..5a6de38162 100644 --- a/sql/database.go +++ b/sql/database.go @@ -18,6 +18,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" + "github.com/spacemeshos/go-spacemesh/common" "github.com/spacemeshos/go-spacemesh/common/types" ) @@ -27,7 +28,7 @@ var ( // ErrNoConnection is returned if pooled connection is not available. ErrNoConnection = errors.New("database: no free connection") // ErrNotFound is returned if requested record is not found. - ErrNotFound = errors.New("database: not found") + ErrNotFound = fmt.Errorf("database: %w", common.ErrNotFound) // ErrObjectExists is returned if database constraints didn't allow to insert an object. ErrObjectExists = errors.New("database: object exists") // ErrConflict is returned if database constraints didn't allow to update an object. diff --git a/sql/localsql/atxs/atxs.go b/sql/localsql/atxs/atxs.go new file mode 100644 index 0000000000..6c0d351d5a --- /dev/null +++ b/sql/localsql/atxs/atxs.go @@ -0,0 +1,37 @@ +package atxs + +import ( + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sql" +) + +func AddBlob(db sql.LocalDatabase, epoch types.EpochID, id types.ATXID, nodeID types.NodeID, blob []byte) error { + _, err := db.Exec("INSERT INTO atx_blobs (epoch, id, pubkey, atx) VALUES (?1, ?2, ?3, ?4)", + func(s *sql.Statement) { + s.BindInt64(1, int64(epoch)) + s.BindBytes(2, id[:]) + s.BindBytes(3, nodeID[:]) + s.BindBytes(4, blob) + }, nil) + return err +} + +func AtxBlob(db sql.LocalDatabase, epoch types.EpochID, nodeID types.NodeID) (id types.ATXID, blob []byte, err error) { + rows, err := db.Exec("select id, atx from atx_blobs where epoch = ?1 and pubkey = ?2", + func(s *sql.Statement) { + s.BindInt64(1, int64(epoch)) + s.BindBytes(2, nodeID[:]) + }, + func(s *sql.Statement) bool { + s.ColumnBytes(0, id[:]) + blob = make([]byte, s.ColumnLen(1)) + s.ColumnBytes(1, blob) + return false + }, + ) + if rows == 0 { + return id, blob, sql.ErrNotFound + } + + return id, blob, err +} diff --git a/sql/localsql/schema/migrations/0010_atxs.sql b/sql/localsql/schema/migrations/0010_atxs.sql new file mode 100644 index 0000000000..19e9798333 --- /dev/null +++ b/sql/localsql/schema/migrations/0010_atxs.sql @@ -0,0 +1,12 @@ +--- Table for storing blobs of published ATX for regossiping purposes. +CREATE TABLE atx_blobs +( + id CHAR(32), + pubkey CHAR(32) NOT NULL, + epoch INT NOT NULL, + atx BLOB, + version INTEGER +); + +CREATE UNIQUE INDEX atx_blobs_id ON atx_blobs (id); +CREATE UNIQUE INDEX atx_blobs_epoch_pubkey ON atx_blobs (epoch, pubkey); diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql index 0b6a1c00f7..d614213034 100755 --- a/sql/localsql/schema/schema.sql +++ b/sql/localsql/schema/schema.sql @@ -1,4 +1,14 @@ -PRAGMA user_version = 9; +PRAGMA user_version = 10; +CREATE TABLE atx_blobs +( + id CHAR(32), + pubkey CHAR(32) NOT NULL, + epoch INT NOT NULL, + atx BLOB, + version INTEGER +); +CREATE UNIQUE INDEX atx_blobs_epoch_pubkey ON atx_blobs (epoch, pubkey); +CREATE UNIQUE INDEX atx_blobs_id ON atx_blobs (id); CREATE TABLE atx_sync_requests ( epoch INT NOT NULL, From 32818b068d663c8300293c71e3db35cbc2ae694f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Thu, 3 Oct 2024 13:57:56 +0200 Subject: [PATCH 4/8] review feedback --- activation/activation.go | 24 ++++++++------------ activation/atx_service_db.go | 9 +++++++- activation/validation.go | 9 ++++---- activation/validation_test.go | 2 +- node/node.go | 6 +++++ sql/localsql/schema/migrations/0010_atxs.sql | 3 +-- sql/localsql/schema/schema.sql | 3 +-- 7 files changed, 32 insertions(+), 24 deletions(-) diff --git a/activation/activation.go b/activation/activation.go index a09400a25c..88b79ab5ab 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -113,14 +113,16 @@ type Builder struct { stop context.CancelFunc } +type foundPosAtx struct { + id types.ATXID + forPublish types.EpochID +} + type positioningAtxFinder struct { finding sync.Mutex - found *struct { - id types.ATXID - forPublish types.EpochID - } - golden types.ATXID - logger *zap.Logger + found *foundPosAtx + golden types.ATXID + logger *zap.Logger } type BuilderOption func(*Builder) @@ -931,19 +933,13 @@ func (f *positioningAtxFinder) find( id, err := atxs.PositioningATX(ctx, publish-1) if err != nil { logger.Warn("failed to get positioning ATX - falling back to golden", zap.Error(err)) - f.found = &struct { - id types.ATXID - forPublish types.EpochID - }{f.golden, publish} + f.found = &foundPosAtx{f.golden, publish} return f.golden } logger.Debug("found candidate positioning atx", log.ZShortStringer("id", id)) - f.found = &struct { - id types.ATXID - forPublish types.EpochID - }{id, publish} + f.found = &foundPosAtx{id, publish} return id } diff --git a/activation/atx_service_db.go b/activation/atx_service_db.go index 94e40b17ba..e84713aa6e 100644 --- a/activation/atx_service_db.go +++ b/activation/atx_service_db.go @@ -27,6 +27,7 @@ type dbAtxService struct { type dbAtxServiceConfig struct { // delay before PoST in ATX is considered valid (counting from the time it was received) postValidityDelay time.Duration + trusted []types.NodeID } type dbAtxServiceOption func(*dbAtxServiceConfig) @@ -37,6 +38,12 @@ func WithPostValidityDelay(delay time.Duration) dbAtxServiceOption { } } +func WithTrustedIDs(ids ...types.NodeID) dbAtxServiceOption { + return func(cfg *dbAtxServiceConfig) { + cfg.trusted = ids + } +} + func NewDBAtxService( db sql.Executor, golden types.ATXID, @@ -91,7 +98,7 @@ func (s *dbAtxService) PositioningATX(ctx context.Context, maxPublish types.Epoc s.golden, s.validator, s.logger, VerifyChainOpts.AssumeValidBefore(time.Now().Add(-s.cfg.postValidityDelay)), - // VerifyChainOpts.WithTrustedID(nodeID), + VerifyChainOpts.WithTrustedIDs(s.cfg.trusted...), VerifyChainOpts.WithLogger(s.logger), ) if err != nil { diff --git a/activation/validation.go b/activation/validation.go index d6d070d895..b31a3ec786 100644 --- a/activation/validation.go +++ b/activation/validation.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "time" "github.com/spacemeshos/merkle-tree" @@ -360,7 +361,7 @@ func (v *Validator) PositioningAtx( type verifyChainOpts struct { assumedValidTime time.Time - trustedNodeID types.NodeID + trustedNodeID []types.NodeID logger *zap.Logger } @@ -377,8 +378,8 @@ func (verifyChainOptsNs) AssumeValidBefore(val time.Time) VerifyChainOption { } } -// WithTrustedID configures the validator to assume that ATXs created by the given node ID are valid. -func (verifyChainOptsNs) WithTrustedID(val types.NodeID) VerifyChainOption { +// WithTrustedIDs configures the validator to assume that ATXs created by the given node IDs are valid. +func (verifyChainOptsNs) WithTrustedIDs(val ...types.NodeID) VerifyChainOption { return func(o *verifyChainOpts) { o.trustedNodeID = val } @@ -533,7 +534,7 @@ func (v *Validator) verifyChainWithOpts( zap.Time("valid_before", opts.assumedValidTime), ) return nil - case atx.SmesherID == opts.trustedNodeID: + case slices.Contains(opts.trustedNodeID, atx.SmesherID): log.Debug("not verifying ATX chain", zap.Stringer("atx_id", id), zap.String("reason", "trusted")) return nil } diff --git a/activation/validation_test.go b/activation/validation_test.go index 59278548dd..41ebbc5642 100644 --- a/activation/validation_test.go +++ b/activation/validation_test.go @@ -543,7 +543,7 @@ func TestVerifyChainDeps(t *testing.T) { ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) validator := NewValidator(db, nil, DefaultPostConfig(), config.ScryptParams{}, v) - err = validator.VerifyChain(ctx, vAtx.ID(), goldenATXID, VerifyChainOpts.WithTrustedID(signer.NodeID())) + err = validator.VerifyChain(ctx, vAtx.ID(), goldenATXID, VerifyChainOpts.WithTrustedIDs(signer.NodeID())) require.NoError(t, err) }) diff --git a/node/node.go b/node/node.go index 97013612ee..d5fc4089ed 100644 --- a/node/node.go +++ b/node/node.go @@ -1080,6 +1080,11 @@ func (app *App) initServices(ctx context.Context) error { } atxBuilderLog := app.addLogger(ATXBuilderLogger, lg).Zap() + trustedIDs := make([]types.NodeID, 0, len(app.signers)) + for _, sig := range app.signers { + trustedIDs = append(trustedIDs, sig.NodeID()) + } + atxService := activation.NewDBAtxService( app.db, goldenATXID, @@ -1087,6 +1092,7 @@ func (app *App) initServices(ctx context.Context) error { app.validator, atxBuilderLog, activation.WithPostValidityDelay(app.Config.PostValidDelay), + activation.WithTrustedIDs(trustedIDs...), ) atxBuilder := activation.NewBuilder( builderConfig, diff --git a/sql/localsql/schema/migrations/0010_atxs.sql b/sql/localsql/schema/migrations/0010_atxs.sql index 19e9798333..97ead34ce4 100644 --- a/sql/localsql/schema/migrations/0010_atxs.sql +++ b/sql/localsql/schema/migrations/0010_atxs.sql @@ -1,12 +1,11 @@ --- Table for storing blobs of published ATX for regossiping purposes. CREATE TABLE atx_blobs ( - id CHAR(32), + id CHAR(32) PRIMARY KEY, pubkey CHAR(32) NOT NULL, epoch INT NOT NULL, atx BLOB, version INTEGER ); -CREATE UNIQUE INDEX atx_blobs_id ON atx_blobs (id); CREATE UNIQUE INDEX atx_blobs_epoch_pubkey ON atx_blobs (epoch, pubkey); diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql index d614213034..0d920e9cf7 100755 --- a/sql/localsql/schema/schema.sql +++ b/sql/localsql/schema/schema.sql @@ -1,14 +1,13 @@ PRAGMA user_version = 10; CREATE TABLE atx_blobs ( - id CHAR(32), + id CHAR(32) PRIMARY KEY, pubkey CHAR(32) NOT NULL, epoch INT NOT NULL, atx BLOB, version INTEGER ); CREATE UNIQUE INDEX atx_blobs_epoch_pubkey ON atx_blobs (epoch, pubkey); -CREATE UNIQUE INDEX atx_blobs_id ON atx_blobs (id); CREATE TABLE atx_sync_requests ( epoch INT NOT NULL, From ca7cf2f2b370d1e34fe87a698b7ece25436e1624 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Mon, 7 Oct 2024 11:49:22 +0200 Subject: [PATCH 5/8] review feedback: remove common error --- activation/activation.go | 16 ++++++---------- activation/activation_test.go | 2 +- activation/atx_service_db.go | 12 ++++++++++-- activation/interface.go | 9 +++++++++ common/errors.go | 5 ----- sql/database.go | 3 +-- 6 files changed, 27 insertions(+), 20 deletions(-) delete mode 100644 common/errors.go diff --git a/activation/activation.go b/activation/activation.go index 88b79ab5ab..3e092df3be 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -19,7 +19,6 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/metrics" "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/codec" - "github.com/spacemeshos/go-spacemesh/common" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" "github.com/spacemeshos/go-spacemesh/log" @@ -31,10 +30,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) -var ( - ErrNotFound = errors.New("not found") - errNilVrfNonce = errors.New("nil VRF nonce") -) +var errNilVrfNonce = errors.New("nil VRF nonce") // PoetConfig is the configuration to interact with the poet server. type PoetConfig struct { @@ -351,7 +347,7 @@ func (b *Builder) BuildInitialPost(ctx context.Context, nodeID types.NodeID) err case err == nil: b.logger.Info("load initial post from db") return nil - case errors.Is(err, common.ErrNotFound): + case errors.Is(err, sql.ErrNotFound): b.logger.Info("creating initial post") default: return fmt.Errorf("get initial post: %w", err) @@ -527,7 +523,7 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID) switch { case err == nil: currentEpochId = max(currentEpochId, prevAtx.PublishEpoch) - case errors.Is(err, common.ErrNotFound): + case errors.Is(err, ErrNotFound): // no previous ATX case err != nil: return nil, fmt.Errorf("get last ATX: %w", err) @@ -577,7 +573,7 @@ func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID) var challenge *types.NIPostChallenge switch { - case errors.Is(err, common.ErrNotFound): + case errors.Is(err, ErrNotFound): logger.Info("no previous ATX found, creating an initial nipost challenge") challenge, err = b.buildInitialNIPostChallenge(ctx, logger, nodeID, publishEpochId) if err != nil { @@ -613,7 +609,7 @@ func (b *Builder) getExistingChallenge( challenge, err := nipost.Challenge(b.localDB, nodeID) switch { - case errors.Is(err, common.ErrNotFound): + case errors.Is(err, sql.ErrNotFound): return nil, nil case err != nil: @@ -991,7 +987,7 @@ func (b *Builder) getPositioningAtx( func (b *Builder) Regossip(ctx context.Context, nodeID types.NodeID) error { epoch := b.layerClock.CurrentLayer().GetEpoch() id, blob, err := atxs.AtxBlob(b.localDB, epoch, nodeID) - if errors.Is(err, common.ErrNotFound) { + if errors.Is(err, sql.ErrNotFound) { return nil } else if err != nil { return err diff --git a/activation/activation_test.go b/activation/activation_test.go index 7b3bcd0eaf..f706a301fe 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -954,7 +954,7 @@ func TestBuilder_PublishActivationTx_TargetsEpochBasedOnPosAtx(t *testing.T) { } posAtx.SetID(types.RandomATXID()) atxSvc.EXPECT().PositioningATX(gomock.Any(), gomock.Any()).Return(posAtx.ID(), nil) - atxSvc.EXPECT().LastATX(gomock.Any(), sig.NodeID()).Return(nil, sql.ErrNotFound).Times(2) + atxSvc.EXPECT().LastATX(gomock.Any(), sig.NodeID()).Return(nil, ErrNotFound).Times(2) // Act & Assert tab.msync.EXPECT().RegisterForATXSynced().DoAndReturn(closedChan).AnyTimes() diff --git a/activation/atx_service_db.go b/activation/atx_service_db.go index e84713aa6e..a2cf4fd5fb 100644 --- a/activation/atx_service_db.go +++ b/activation/atx_service_db.go @@ -2,6 +2,7 @@ package activation import ( "context" + "errors" "fmt" "time" @@ -71,12 +72,19 @@ func NewDBAtxService( } func (s *dbAtxService) Atx(_ context.Context, id types.ATXID) (*types.ActivationTx, error) { - return atxs.Get(s.db, id) + atx, err := atxs.Get(s.db, id) + if errors.Is(err, sql.ErrNotFound) { + return nil, ErrNotFound + } + return atx, err } func (s *dbAtxService) LastATX(ctx context.Context, id types.NodeID) (*types.ActivationTx, error) { atxid, err := atxs.GetLastIDByNodeID(s.db, id) - if err != nil { + switch { + case errors.Is(err, sql.ErrNotFound): + return nil, ErrNotFound + case err != nil: return nil, fmt.Errorf("getting last ATXID: %w", err) } return atxs.Get(s.db, atxid) diff --git a/activation/interface.go b/activation/interface.go index 1fa673948f..7a5066a8d7 100644 --- a/activation/interface.go +++ b/activation/interface.go @@ -19,6 +19,8 @@ import ( //go:generate mockgen -typed -package=activation -destination=./mocks.go -source=./interface.go +var ErrNotFound = errors.New("not found") + type AtxReceiver interface { OnAtx(*types.ActivationTx) } @@ -110,9 +112,16 @@ type atxProvider interface { // AtxService provides ATXs needed by the ATX Builder. type AtxService interface { + // Get ATX with given ID + // + // Returns ErrNotFound if couldn't get the ATX. Atx(ctx context.Context, id types.ATXID) (*types.ActivationTx, error) + // Get the last ATX of the given identitity. + // + // Returns ErrNotFound if couldn't get the ATX. LastATX(ctx context.Context, nodeID types.NodeID) (*types.ActivationTx, error) // PositioningATX returns atx id with the highest tick height. + // // The maxPublish epoch is the maximum publish epoch of the returned ATX. PositioningATX(ctx context.Context, maxPublish types.EpochID) (types.ATXID, error) } diff --git a/common/errors.go b/common/errors.go deleted file mode 100644 index 4342b6fa07..0000000000 --- a/common/errors.go +++ /dev/null @@ -1,5 +0,0 @@ -package common - -import "errors" - -var ErrNotFound = errors.New("not found") diff --git a/sql/database.go b/sql/database.go index 5a6de38162..4f4224b710 100644 --- a/sql/database.go +++ b/sql/database.go @@ -18,7 +18,6 @@ import ( "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" - "github.com/spacemeshos/go-spacemesh/common" "github.com/spacemeshos/go-spacemesh/common/types" ) @@ -28,7 +27,7 @@ var ( // ErrNoConnection is returned if pooled connection is not available. ErrNoConnection = errors.New("database: no free connection") // ErrNotFound is returned if requested record is not found. - ErrNotFound = fmt.Errorf("database: %w", common.ErrNotFound) + ErrNotFound = errors.New("database: not found") // ErrObjectExists is returned if database constraints didn't allow to insert an object. ErrObjectExists = errors.New("database: object exists") // ErrConflict is returned if database constraints didn't allow to update an object. From 667e3f0a9f23547597de59459629835ef53609ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Mon, 14 Oct 2024 15:35:57 +0200 Subject: [PATCH 6/8] review feedback --- activation/activation.go | 6 ++-- activation/activation_multi_test.go | 2 +- activation/atx_service_db.go | 8 +---- activation/atx_service_db_test.go | 6 ++-- sql/localsql/{atxs => localatxs}/atxs.go | 2 +- sql/localsql/localatxs/atxs_test.go | 41 ++++++++++++++++++++++++ sql/localsql/schema/schema.sql | 2 +- 7 files changed, 50 insertions(+), 17 deletions(-) rename sql/localsql/{atxs => localatxs}/atxs.go (98%) create mode 100644 sql/localsql/localatxs/atxs_test.go diff --git a/activation/activation.go b/activation/activation.go index 3e092df3be..0b9d720b32 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -26,7 +26,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/localsql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/localsql/localatxs" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -736,7 +736,7 @@ func (b *Builder) PublishActivationTx(ctx context.Context, sig *signing.EdSigner case <-b.layerClock.AwaitLayer(challenge.PublishEpoch.FirstLayer()): } - err = atxs.AddBlob(b.localDB, challenge.PublishEpoch, atx.ID(), sig.NodeID(), codec.MustEncode(atx)) + err = localatxs.AddBlob(b.localDB, challenge.PublishEpoch, atx.ID(), sig.NodeID(), codec.MustEncode(atx)) if err != nil { b.logger.Warn("failed to persist built ATX into the local DB - regossiping won't work", zap.Error(err)) } @@ -986,7 +986,7 @@ func (b *Builder) getPositioningAtx( func (b *Builder) Regossip(ctx context.Context, nodeID types.NodeID) error { epoch := b.layerClock.CurrentLayer().GetEpoch() - id, blob, err := atxs.AtxBlob(b.localDB, epoch, nodeID) + id, blob, err := localatxs.AtxBlob(b.localDB, epoch, nodeID) if errors.Is(err, sql.ErrNotFound) { return nil } else if err != nil { diff --git a/activation/activation_multi_test.go b/activation/activation_multi_test.go index 56736a7c51..064dc4a98d 100644 --- a/activation/activation_multi_test.go +++ b/activation/activation_multi_test.go @@ -18,7 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" - localatxs "github.com/spacemeshos/go-spacemesh/sql/localsql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/localsql/localatxs" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) diff --git a/activation/atx_service_db.go b/activation/atx_service_db.go index a2cf4fd5fb..26e11ad7ae 100644 --- a/activation/atx_service_db.go +++ b/activation/atx_service_db.go @@ -99,7 +99,7 @@ func (s *dbAtxService) PositioningATX(ctx context.Context, maxPublish types.Epoc // positioning ATX publish epoch must be lower than the publish epoch of built ATX positioningAtxPublished := min(latestPublished, maxPublish) - id, err := findFullyValidHighTickAtx( + return findFullyValidHighTickAtx( ctx, s.atxsdata, positioningAtxPublished, @@ -109,12 +109,6 @@ func (s *dbAtxService) PositioningATX(ctx context.Context, maxPublish types.Epoc VerifyChainOpts.WithTrustedIDs(s.cfg.trusted...), VerifyChainOpts.WithLogger(s.logger), ) - if err != nil { - s.logger.Info("search failed - using golden atx as positioning atx", zap.Error(err)) - id = s.golden - } - - return id, nil } func findFullyValidHighTickAtx( diff --git a/activation/atx_service_db_test.go b/activation/atx_service_db_test.go index f2e916a595..9ed32db65c 100644 --- a/activation/atx_service_db_test.go +++ b/activation/atx_service_db_test.go @@ -73,10 +73,8 @@ func TestGetPositioningAtxPicksAtxWithValidChain(t *testing.T) { require.NoError(t, err) require.Equal(t, vValidAtx.ID(), posAtxID) - // it returns the golden ATX if couldn't find a better one - posAtxID, err = atxSvc.PositioningATX(context.Background(), validAtx.PublishEpoch-1) - require.NoError(t, err) - require.Equal(t, atxSvc.golden, posAtxID) + _, err = atxSvc.PositioningATX(context.Background(), validAtx.PublishEpoch-1) + require.ErrorIs(t, err, ErrNotFound) } func TestFindFullyValidHighTickAtx(t *testing.T) { diff --git a/sql/localsql/atxs/atxs.go b/sql/localsql/localatxs/atxs.go similarity index 98% rename from sql/localsql/atxs/atxs.go rename to sql/localsql/localatxs/atxs.go index 6c0d351d5a..de9eaadee0 100644 --- a/sql/localsql/atxs/atxs.go +++ b/sql/localsql/localatxs/atxs.go @@ -1,4 +1,4 @@ -package atxs +package localatxs import ( "github.com/spacemeshos/go-spacemesh/common/types" diff --git a/sql/localsql/localatxs/atxs_test.go b/sql/localsql/localatxs/atxs_test.go new file mode 100644 index 0000000000..23e77d402e --- /dev/null +++ b/sql/localsql/localatxs/atxs_test.go @@ -0,0 +1,41 @@ +package localatxs_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/localsql/localatxs" +) + +func Test_Blobs(t *testing.T) { + t.Run("not found", func(t *testing.T) { + db := localsql.InMemoryTest(t) + _, _, err := localatxs.AtxBlob(db, types.EpochID(0), types.NodeID{}) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("found", func(t *testing.T) { + db := localsql.InMemoryTest(t) + epoch := types.EpochID(2) + atxid := types.RandomATXID() + nodeID := types.RandomNodeID() + blob := types.RandomBytes(10) + err := localatxs.AddBlob(db, epoch, atxid, nodeID, blob) + require.NoError(t, err) + gotID, gotBlob, err := localatxs.AtxBlob(db, epoch, nodeID) + require.NoError(t, err) + require.Equal(t, atxid, gotID) + require.Equal(t, blob, gotBlob) + + // different ID + _, _, err = localatxs.AtxBlob(db, epoch, types.RandomNodeID()) + require.ErrorIs(t, err, sql.ErrNotFound) + + // different epoch + _, _, err = localatxs.AtxBlob(db, types.EpochID(3), nodeID) + require.ErrorIs(t, err, sql.ErrNotFound) + }) +} diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql index 0d920e9cf7..a3a7333b18 100755 --- a/sql/localsql/schema/schema.sql +++ b/sql/localsql/schema/schema.sql @@ -7,7 +7,6 @@ CREATE TABLE atx_blobs atx BLOB, version INTEGER ); -CREATE UNIQUE INDEX atx_blobs_epoch_pubkey ON atx_blobs (epoch, pubkey); CREATE TABLE atx_sync_requests ( epoch INT NOT NULL, @@ -89,4 +88,5 @@ CREATE TABLE prepared_activeset data BLOB NOT NULL, PRIMARY KEY (kind, epoch) ) WITHOUT ROWID; +CREATE UNIQUE INDEX atx_blobs_epoch_pubkey ON atx_blobs (epoch, pubkey); CREATE UNIQUE INDEX idx_poet_certificates ON poet_certificates (node_id, certifier_id); From ec8a7d8b97b2b990e5c9a9f6eb28183d52aa2412 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Thu, 17 Oct 2024 12:37:52 +0200 Subject: [PATCH 7/8] review feedback: use singleflight --- activation/activation.go | 48 ++++++++++++++++++----------------- activation/activation_test.go | 16 ++++++++++++ 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/activation/activation.go b/activation/activation.go index 0b9d720b32..46f7d9ae34 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -15,6 +15,7 @@ import ( "go.uber.org/zap/zapcore" "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" + "golang.org/x/sync/singleflight" "github.com/spacemeshos/go-spacemesh/activation/metrics" "github.com/spacemeshos/go-spacemesh/activation/wire" @@ -115,9 +116,8 @@ type foundPosAtx struct { } type positioningAtxFinder struct { - finding sync.Mutex - found *foundPosAtx - golden types.ATXID + finding singleflight.Group + found foundPosAtx logger *zap.Logger } @@ -193,7 +193,6 @@ func NewBuilder( postStates: NewPostStates(log), versions: []atxVersion{{0, types.AtxV1}}, posAtxFinder: positioningAtxFinder{ - golden: conf.GoldenATXID, logger: log, }, } @@ -915,28 +914,26 @@ func (f *positioningAtxFinder) find( ctx context.Context, atxs AtxService, publish types.EpochID, -) types.ATXID { +) (types.ATXID, error) { logger := f.logger.With(zap.Uint32("publish epoch", publish.Uint32())) - f.finding.Lock() - defer f.finding.Unlock() - - if found := f.found; found != nil && found.forPublish == publish { - logger.Debug("using cached positioning atx", log.ZShortStringer("atx_id", found.id)) - return found.id - } + atx, err, _ := f.finding.Do(publish.String(), func() (any, error) { + if f.found.forPublish == publish { + logger.Debug("using cached positioning atx", log.ZShortStringer("atx_id", f.found.id)) + return f.found.id, nil + } - id, err := atxs.PositioningATX(ctx, publish-1) - if err != nil { - logger.Warn("failed to get positioning ATX - falling back to golden", zap.Error(err)) - f.found = &foundPosAtx{f.golden, publish} - return f.golden - } + id, err := atxs.PositioningATX(ctx, publish-1) + if err != nil { + return types.EmptyATXID, err + } - logger.Debug("found candidate positioning atx", log.ZShortStringer("id", id)) + logger.Debug("found candidate positioning atx", log.ZShortStringer("id", id)) + f.found = foundPosAtx{id, publish} + return id, nil + }) - f.found = &foundPosAtx{id, publish} - return id + return atx.(types.ATXID), err } // getPositioningAtx returns the positioning ATX. @@ -948,7 +945,11 @@ func (b *Builder) getPositioningAtx( publish types.EpochID, previous *types.ActivationTx, ) (types.ATXID, error) { - id := b.posAtxFinder.find(ctx, b.atxSvc, publish) + id, err := b.posAtxFinder.find(ctx, b.atxSvc, publish) + if err != nil { + b.logger.Warn("failed to find positioning ATX - falling back to golden", zap.Error(err)) + id = b.conf.GoldenATXID + } if previous == nil { b.logger.Info("selected positioning atx", @@ -968,7 +969,8 @@ func (b *Builder) getPositioningAtx( candidate, err := b.atxSvc.Atx(ctx, id) if err != nil { - return types.EmptyATXID, fmt.Errorf("get candidate pos ATX %s: %w", id.ShortString(), err) + b.logger.Warn("failed to get candidate pos ATX - falling back to previous", zap.Error(err)) + return previous.ID(), nil } if previous.TickHeight() >= candidate.TickHeight() { diff --git a/activation/activation_test.go b/activation/activation_test.go index f706a301fe..78cb0b0eac 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -1410,6 +1410,22 @@ func TestGetPositioningAtx(t *testing.T) { require.NoError(t, err) require.Equal(t, tab.goldenATXID, posATX) }) + t.Run("picks previous when querying candidate fails and previous is available", func(t *testing.T) { + t.Parallel() + atxSvc := NewMockAtxService(gomock.NewController(t)) + tab := newTestBuilder(t, 1) + tab.atxSvc = atxSvc + + atxID := types.RandomATXID() + atxSvc.EXPECT().PositioningATX(gomock.Any(), types.EpochID(98)).Return(atxID, nil) + atxSvc.EXPECT().Atx(context.Background(), atxID).Return(nil, errors.New("failed")) + + previous := types.ActivationTx{} + previous.SetID(types.RandomATXID()) + posATX, err := tab.getPositioningAtx(context.Background(), types.EmptyNodeID, 99, &previous) + require.NoError(t, err) + require.Equal(t, previous.ID(), posATX) + }) t.Run("picks golden if no ATXs", func(t *testing.T) { tab := newTestBuilder(t, 1) atx, err := tab.getPositioningAtx(context.Background(), types.EmptyNodeID, 99, nil) From 55e311bcdd1df08ad172b423bc73f7fd7bc5bbd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Tue, 29 Oct 2024 16:07:16 +0100 Subject: [PATCH 8/8] regenerate code --- activation/mocks.go | 1 + 1 file changed, 1 insertion(+) diff --git a/activation/mocks.go b/activation/mocks.go index 4cd3041f9d..4e1fe49cef 100644 --- a/activation/mocks.go +++ b/activation/mocks.go @@ -1221,6 +1221,7 @@ func (c *MockatxProviderGetAtxCall) DoAndReturn(f func(types.ATXID) (*types.Acti type MockAtxService struct { ctrl *gomock.Controller recorder *MockAtxServiceMockRecorder + isgomock struct{} } // MockAtxServiceMockRecorder is the mock recorder for MockAtxService.