From c983b22b8c6f8b7b5a573a9827356d360adfe14a Mon Sep 17 00:00:00 2001 From: Brennan Lamey <66885902+brennanjl@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:47:18 -0500 Subject: [PATCH] pg types and full changeset foundation This reworks the pg type system, allowing appropriately typed variables at higher levels (outside pg package). This also enables FULL replication identity for all tables, even those with a primary key, which is to allow full changeset collection to support the upcoming network migrations system. --------- Co-authored-by: Jonathan Chappelow * avoid sentry table seq updates for non-2pc txns * txapp: remove contextual params from mempool struct --- cmd/kwil-admin/cmds/snapshot/create.go | 61 +- cmd/kwild/server/build.go | 12 +- common/common.go | 19 +- common/sql/sql.go | 31 +- core/types/schema.go | 24 +- core/types/transactions/result.go | 14 + internal/abci/abci.go | 52 +- internal/abci/abci_test.go | 11 +- internal/abci/interfaces.go | 6 +- internal/abci/meta/meta.go | 40 +- internal/abci/meta/meta_test.go | 4 +- internal/engine/execution/procedure.go | 120 ++- .../engine/integration/deployment_test.go | 2 +- internal/engine/integration/execution_test.go | 6 +- internal/engine/integration/procedure_test.go | 64 +- internal/engine/integration/schema_test.go | 2 +- internal/engine/integration/sql_test.go | 2 +- internal/migrations/interfaces.go | 18 + internal/migrations/migrations.go | 103 ++ internal/migrations/migrator.go | 249 +++++ internal/migrations/sql.go | 102 ++ internal/services/grpc/txsvc/v1/pricing.go | 2 +- internal/services/grpc/txsvc/v1/service.go | 7 +- internal/services/jsonrpc/adminsvc/service.go | 11 +- internal/services/jsonrpc/usersvc/service.go | 9 +- internal/sql/pg/conn.go | 20 +- internal/sql/pg/db.go | 105 +- internal/sql/pg/db_live_test.go | 262 ++++- internal/sql/pg/query.go | 231 +---- internal/sql/pg/repl.go | 50 +- internal/sql/pg/repl_changeset.go | 501 ++++++++++ internal/sql/pg/repl_test.go | 4 +- internal/sql/pg/replmon.go | 53 +- internal/sql/pg/sql.go | 67 +- internal/sql/pg/trigger_repl1.sql | 34 - internal/sql/pg/trigger_repl2.sql | 28 - internal/sql/pg/tx.go | 14 +- internal/sql/pg/types.go | 936 ++++++++++++++++++ internal/sql/pg/types_test.go | 38 + internal/txapp/interfaces.go | 2 +- internal/txapp/mempool.go | 45 +- internal/txapp/mempool_test.go | 45 +- internal/txapp/routes.go | 35 + internal/txapp/routes_test.go | 61 +- internal/txapp/txapp.go | 45 +- internal/voting/broadcast/broadcast.go | 21 +- internal/voting/broadcast/broadcast_test.go | 10 +- internal/voting/events.go | 4 +- internal/voting/voting.go | 17 +- testing/testing.go | 2 +- 50 files changed, 2911 insertions(+), 690 deletions(-) create mode 100644 internal/migrations/interfaces.go create mode 100644 internal/migrations/migrations.go create mode 100644 internal/migrations/migrator.go create mode 100644 internal/migrations/sql.go create mode 100644 internal/sql/pg/repl_changeset.go delete mode 100644 internal/sql/pg/trigger_repl1.sql delete mode 100644 internal/sql/pg/trigger_repl2.sql create mode 100644 internal/sql/pg/types.go create mode 100644 internal/sql/pg/types_test.go diff --git a/cmd/kwil-admin/cmds/snapshot/create.go b/cmd/kwil-admin/cmds/snapshot/create.go index 8feae6be9..08474643f 100644 --- a/cmd/kwil-admin/cmds/snapshot/create.go +++ b/cmd/kwil-admin/cmds/snapshot/create.go @@ -7,6 +7,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" "fmt" "io" "math/big" @@ -16,6 +17,7 @@ import ( "strconv" "strings" + 
"github.com/kwilteam/kwil-db/cmd/common/display" "github.com/kwilteam/kwil-db/common/chain" "github.com/spf13/cobra" @@ -34,8 +36,6 @@ kwil-admin snapshot create --dbname kwildb --user user1 --password pass1 --host # Snapshot and genesis files will be created in the snapshot directory ls /path/to/snapshot/dir genesis.json kwildb-snapshot.sql.gz` - - createDatabase = `CREATE DATABASE ` ) /* @@ -59,10 +59,16 @@ func createCmd() *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { snapshotDir, err := expandPath(snapshotDir) if err != nil { - return fmt.Errorf("failed to expand snapshot directory path: %w", err) + return display.PrintErr(cmd, fmt.Errorf("failed to expand snapshot directory path: %v", err)) } - return pgDump(cmd.Context(), dbName, dbUser, dbPass, dbHost, dbPort, maxRowSize, snapshotDir) + logs, err := pgDump(cmd.Context(), dbName, dbUser, dbPass, dbHost, dbPort, maxRowSize, snapshotDir) + if err != nil { + return display.PrintErr(cmd, fmt.Errorf("failed to create database snapshot: %v", err)) + } + + r := &createSnapshotRes{Logs: logs} + return display.PrintCmd(cmd, r) }, } @@ -77,6 +83,18 @@ func createCmd() *cobra.Command { return cmd } +type createSnapshotRes struct { + Logs []string `json:"logs"` +} + +func (c *createSnapshotRes) MarshalJSON() ([]byte, error) { + return json.Marshal(c) +} + +func (c *createSnapshotRes) MarshalText() (text []byte, err error) { + return []byte(fmt.Sprintf("Snapshot created successfully\n%s", strings.Join(c.Logs, "\n"))), nil +} + func expandPath(path string) (string, error) { if strings.HasPrefix(path, "~") { home, err := os.UserHomeDir() @@ -88,17 +106,19 @@ func expandPath(path string) (string, error) { return filepath.Abs(path) } -func pgDump(ctx context.Context, dbName, dbUser, dbPass, dbHost, dbPort string, maxRowSize int, snapshotDir string) (err error) { +// PGDump uses pg_dump to create a snapshot of the database. +// It returns messages to log and an error if any. 
+func pgDump(ctx context.Context, dbName, dbUser, dbPass, dbHost, dbPort string, maxRowSize int, snapshotDir string) (logs []string, err error) { // Check if the snapshot directory exists, if not create it err = os.MkdirAll(snapshotDir, 0755) if err != nil { - return fmt.Errorf("failed to create snapshot directory: %w", err) + return nil, fmt.Errorf("failed to create snapshot directory: %w", err) } dumpFile := filepath.Join(snapshotDir, "kwildb-snapshot.sql.gz") outputFile, err := os.Create(dumpFile) if err != nil { - return fmt.Errorf("failed to create dump file: %w", err) + return nil, fmt.Errorf("failed to create dump file: %w", err) } // delete the dump file if an error occurs anywhere during the snapshot process defer func() { @@ -154,12 +174,12 @@ func pgDump(ctx context.Context, dbName, dbUser, dbPass, dbHost, dbPort string, var stderr bytes.Buffer pgDumpOutput, err := pgDumpCmd.StdoutPipe() if err != nil { - return fmt.Errorf("failed to get stdout pipe: %w", err) + return nil, fmt.Errorf("failed to get stdout pipe: %w", err) } pgDumpCmd.Stderr = &stderr if err := pgDumpCmd.Start(); err != nil { - return fmt.Errorf("failed to start pg_dump command: %w", err) + return nil, fmt.Errorf("failed to start pg_dump command: %w", err) } defer pgDumpOutput.Close() @@ -187,7 +207,7 @@ func pgDump(ctx context.Context, dbName, dbUser, dbPass, dbHost, dbPort string, inVotersBlock = false n, err := multiWriter.Write([]byte(line + "\n")) if err != nil { - return fmt.Errorf("failed to write to gzip writer: %w", err) + return nil, fmt.Errorf("failed to write to gzip writer: %w", err) } totalBytes += int64(n) continue @@ -195,16 +215,16 @@ func pgDump(ctx context.Context, dbName, dbUser, dbPass, dbHost, dbPort string, strs := strings.Split(line, "\t") if len(strs) != 3 { - return fmt.Errorf("invalid voter line: %s", line) + return nil, fmt.Errorf("invalid voter line: %s", line) } voterID, err := hex.DecodeString(strs[1][3:]) // Remove the leading \\x if err != nil { - return fmt.Errorf("failed to decode voter ID: %w", err) + return nil, fmt.Errorf("failed to decode voter ID: %w", err) } power, err := strconv.ParseInt(strs[2], 10, 64) if err != nil { - return fmt.Errorf("failed to parse power: %w", err) + return nil, fmt.Errorf("failed to parse power: %w", err) } genCfg.Validators = append(genCfg.Validators, &chain.GenesisValidator{ @@ -221,7 +241,7 @@ func pgDump(ctx context.Context, dbName, dbUser, dbPass, dbHost, dbPort string, } else if strings.HasPrefix(line, "SET") || strings.HasPrefix(line, "SELECT") || strings.HasPrefix(line[1:], "connect") { // Skip SET and SELECT and connect statements continue - } else if strings.HasPrefix(line, createDatabase) { + } else if strings.HasPrefix(line, `CREATE DATABASE `) { // Skip CREATE DATABASE statement } else { if strings.HasPrefix(line, "COPY kwild_voting.voters") && strings.Contains(line, "FROM stdin;") { @@ -231,7 +251,7 @@ func pgDump(ctx context.Context, dbName, dbUser, dbPass, dbHost, dbPort string, // Write the sanitized line to the gzip writer n, err := multiWriter.Write([]byte(line + "\n")) if err != nil { - return fmt.Errorf("failed to write to gzip writer: %w", err) + return nil, fmt.Errorf("failed to write to gzip writer: %w", err) } totalBytes += int64(n) } @@ -239,12 +259,12 @@ func pgDump(ctx context.Context, dbName, dbUser, dbPass, dbHost, dbPort string, } if err := scanner.Err(); err != nil { - return fmt.Errorf("failed to scan pg_dump output: %w", err) + return nil, fmt.Errorf("failed to scan pg_dump output: %w", err) } // Close the writer 
when pg_dump completes to signal EOF to sed if err := pgDumpCmd.Wait(); err != nil { - return fmt.Errorf(stderr.String()) + return nil, fmt.Errorf("%s", stderr.String()) // not Errorf(stderr.String()): pg_dump output may contain % verbs } gzipWriter.Flush() @@ -254,10 +274,9 @@ func pgDump(ctx context.Context, dbName, dbUser, dbPass, dbHost, dbPort string, // Write the genesis config to a file genesisFile := filepath.Join(snapshotDir, "genesis.json") if err := genCfg.SaveAs(genesisFile); err != nil { - return fmt.Errorf("failed to save genesis config: %w", err) + return nil, fmt.Errorf("failed to save genesis config: %w", err) } - fmt.Println("Snapshot created at: ", dumpFile, " Total bytes written: ", totalBytes) - fmt.Println("Genesis config created at: ", genesisFile, " Genesis hash: ", fmt.Sprintf("%x", hash)) - return nil + return []string{fmt.Sprintf("Snapshot created at: %s, Total bytes written: %d", dumpFile, totalBytes), + fmt.Sprintf("Genesis config created at: %s, Genesis hash: %s", genesisFile, fmt.Sprintf("%x", hash))}, nil } diff --git a/cmd/kwild/server/build.go b/cmd/kwild/server/build.go index c33140231..77109123f 100644 --- a/cmd/kwild/server/build.go +++ b/cmd/kwild/server/build.go @@ -215,7 +215,7 @@ func buildServer(d *coreDependencies, closers *closeFuncs) *Server { rpcSvcLogger := increaseLogLevel("user-json-svc", &d.log, d.cfg.Logging.RPCLevel) rpcServerLogger := increaseLogLevel("user-jsonrpc-server", &d.log, d.cfg.Logging.RPCLevel) - jsonRPCTxSvc := usersvc.NewService(db, e, wrappedCmtClient, txApp, + jsonRPCTxSvc := usersvc.NewService(db, e, wrappedCmtClient, txApp, abciApp, *rpcSvcLogger, usersvc.WithReadTxTimeout(time.Duration(d.cfg.AppCfg.ReadTxTimeout))) jsonRPCServer, err := rpcserver.NewServer(d.cfg.AppCfg.JSONRPCListenAddress, *rpcServerLogger, rpcserver.WithTimeout(time.Duration(d.cfg.AppCfg.RPCTimeout)), @@ -228,7 +228,7 @@ func buildServer(d *coreDependencies, closers *closeFuncs) *Server { // admin service and server signer := buildSigner(d) - jsonAdminSvc := adminsvc.NewService(db, wrappedCmtClient, txApp, signer, d.cfg, + jsonAdminSvc := adminsvc.NewService(db, wrappedCmtClient, txApp, abciApp, signer, d.cfg, d.genesisCfg.ChainID, *d.log.Named("admin-json-svc")) jsonRPCAdminServer := buildJRPCAdminServer(d) jsonRPCAdminServer.RegisterSvc(jsonAdminSvc) @@ -236,7 +236,7 @@ func buildServer(d *coreDependencies, closers *closeFuncs) *Server { jsonRPCAdminServer.RegisterSvc(&funcsvc.Service{}) // legacy tx service and grpc server - txsvc := buildTxSvc(d, db, e, wrappedCmtClient, txApp) + txsvc := buildTxSvc(d, db, e, wrappedCmtClient, txApp, abciApp) grpcServer := buildGrpcServer(d, txsvc) return &Server{ @@ -377,7 +377,7 @@ func buildAbci(d *coreDependencies, db *pg.DB, txApp abci.TxApp, snapshotter *st } func buildEventBroadcaster(d *coreDependencies, ev broadcast.EventStore, b broadcast.Broadcaster, txapp *txapp.TxApp) *broadcast.EventBroadcaster { - return broadcast.NewEventBroadcaster(ev, b, txapp, buildSigner(d), d.genesisCfg.ChainID, d.genesisCfg.ConsensusParams.Votes.MaxVotesPerTx) + return broadcast.NewEventBroadcaster(ev, b, txapp, buildSigner(d), d.genesisCfg.ChainID, d.genesisCfg.ConsensusParams.Votes.MaxVotesPerTx, *d.log.Named("event-broadcaster")) } func buildEventStore(d *coreDependencies, closers *closeFuncs) *voting.EventStore { @@ -396,8 +396,8 @@ func buildEventStore(d *coreDependencies, closers *closeFuncs) *voting.EventStor return e } -func buildTxSvc(d *coreDependencies, db *pg.DB, txsvc txSvc.EngineReader, cometBftClient txSvc.BlockchainTransactor, nodeApp txSvc.NodeApplication) *txSvc.Service { - 
return txSvc.NewService(db, txsvc, cometBftClient, nodeApp, +func buildTxSvc(d *coreDependencies, db *pg.DB, txsvc txSvc.EngineReader, cometBftClient txSvc.BlockchainTransactor, nodeApp txSvc.NodeApplication, pricer txSvc.Pricer) *txSvc.Service { + return txSvc.NewService(db, txsvc, cometBftClient, nodeApp, pricer, txSvc.WithLogger(*d.log.Named("tx-service")), txSvc.WithReadTxTimeout(time.Duration(d.cfg.AppCfg.ReadTxTimeout)), ) diff --git a/common/common.go b/common/common.go index 708c38514..c7108a2aa 100644 --- a/common/common.go +++ b/common/common.go @@ -132,14 +132,13 @@ type NetworkParameters struct { VoteExpiry int64 // DisabledGasCosts dictates whether gas costs are disabled. DisabledGasCosts bool -} - -// Copy returns a deep copy of the network parameters. -func (n *NetworkParameters) Copy() *NetworkParameters { - return &NetworkParameters{ - MaxBlockSize: n.MaxBlockSize, - JoinExpiry: n.JoinExpiry, - VoteExpiry: n.VoteExpiry, - DisabledGasCosts: n.DisabledGasCosts, - } + // InMigration is true if the network is being migrated to a new network. + // Once this is set to true, it can never be set to false. If true, + // new databases cannot be created, old databases cannot be deleted, + // balances cannot be transferred, + // and the vote store is paused. + InMigration bool + // MaxVotesPerTx is the maximum number of votes that can be included in a + // single transaction. + MaxVotesPerTx int64 } diff --git a/common/sql/sql.go b/common/sql/sql.go index e5c7d24d9..5a41ee610 100644 --- a/common/sql/sql.go +++ b/common/sql/sql.go @@ -5,6 +5,7 @@ package sql import ( "context" "errors" + "io" ) var ( @@ -58,15 +59,16 @@ type Tx interface { // create transactions, which may be closed or create additional nested // transactions. // -// Some implementations may also be an OuterTxMaker and/or a ReadTxMaker. Embed +// Some implementations may also be a PreparedTxMaker and/or a ReadTxMaker. Embed // with those interfaces to compose the minimal interface required. type DB interface { Executor TxMaker } -// ReadTxMaker can make read-only transactions. -// Many read-only transactions can be made at once. +// ReadTxMaker can make read-only transactions. This is necessarily an outermost +// transaction since nested transactions inherit their access mode from their +// parent. Many read-only transactions can be made at once. type ReadTxMaker interface { BeginReadTx(ctx context.Context) (Tx, error) } @@ -78,22 +80,25 @@ type DelayedReadTxMaker interface { BeginDelayedReadTx() Tx } -// OuterTx is the outermost database transaction. +// PreparedTx is an outermost database transaction that uses two-phase commit +// with the Precommit method. // -// NOTE: An OuterTx may be used where only a Tx or DB is required since those -// interfaces are a subset of the OuterTx method set. -type OuterTx interface { +// NOTE: A PreparedTx may be used where only a Tx or DB is required since those +// interfaces are a subset of the PreparedTx method set. +// Precommit takes a writer to write the full changeset to. +// If the writer is nil, the changeset will not be written. 
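+// A sketch of the intended two-phase flow, as the consensus engine uses it +// (tx and changesetWriter are illustrative names, not part of this package): +// +//	commitID, err := tx.Precommit(ctx, changesetWriter) // phase one: prepare, stream changeset +//	if err == nil { +//		err = tx.Commit(ctx) // phase two: commit +//	}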
+type PreparedTx interface { Tx - Precommit(ctx context.Context) ([]byte, error) + Precommit(ctx context.Context, writer io.Writer) ([]byte, error) } -// OuterTxMaker is the special kind of transaction that creates a transaction -// that has a Precommit method (see OuterTx), which supports obtaining a commit +// PreparedTxMaker creates the special kind of transaction that has a +// Precommit method (see PreparedTx), which supports obtaining a commit // ID using a (two-phase) prepared transaction prior to Commit. This is a -// different method name so that an implementation may satisfy both OuterTxMaker +// different method name so that an implementation may satisfy both PreparedTxMaker // and TxMaker. -type OuterTxMaker interface { - BeginOuterTx(ctx context.Context) (OuterTx, error) +type PreparedTxMaker interface { + BeginPreparedTx(ctx context.Context) (PreparedTx, error) } // SnapshotTxMaker is an interface that creates a transaction for taking a diff --git a/core/types/schema.go b/core/types/schema.go index 58c2ce611..e6319d0fc 100644 --- a/core/types/schema.go +++ b/core/types/schema.go @@ -1254,21 +1254,35 @@ var ( IntType = &DataType{ Name: intStr, } - TextType = &DataType{ + IntArrayType = ArrayType(IntType) + TextType = &DataType{ Name: textStr, } - BoolType = &DataType{ + TextArrayType = ArrayType(TextType) + BoolType = &DataType{ Name: boolStr, } - BlobType = &DataType{ + BoolArrayType = ArrayType(BoolType) + BlobType = &DataType{ Name: blobStr, } - UUIDType = &DataType{ + BlobArrayType = ArrayType(BlobType) + UUIDType = &DataType{ Name: uuidStr, } - Uint256Type = &DataType{ + UUIDArrayType = ArrayType(UUIDType) + // DecimalType contains 1,0 metadata. + // For type detection, users should prefer to compare a datatype + // name with the DecimalStr constant. 
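+	// For example (dt being an illustrative *DataType value): +	// +	//	if dt.Name == DecimalStr { /* any decimal, regardless of metadata */ }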
+ DecimalType = &DataType{ + Name: DecimalStr, + Metadata: [2]uint16{1, 0}, // the minimum precision and scale + } + DecimalArrayType = ArrayType(DecimalType) + Uint256Type = &DataType{ Name: uint256Str, } + Uint256ArrayType = ArrayType(Uint256Type) // NullType is a special type used internally NullType = &DataType{ Name: nullStr, diff --git a/core/types/transactions/result.go b/core/types/transactions/result.go index 743ff4a39..bf8fd930c 100644 --- a/core/types/transactions/result.go +++ b/core/types/transactions/result.go @@ -44,6 +44,8 @@ const ( CodeDatasetMissing TxCode = 110 CodeDatasetExists TxCode = 120 + CodeNetworkInMigration TxCode = 200 + CodeUnknownError TxCode = math.MaxUint32 ) @@ -71,6 +73,18 @@ func (c TxCode) String() string { return "insufficient fee" case CodeInvalidAmount: return "invalid amount" + case CodeInvalidSender: + return "invalid sender" + case CodeInvalidSchema: + return "invalid schema" + case CodeDatasetMissing: + return "dataset missing" + case CodeDatasetExists: + return "dataset exists" + case CodeNetworkInMigration: + return "network in migration" + case CodeUnknownError: + return "unknown error" default: return "unknown tx error" } diff --git a/internal/abci/abci.go b/internal/abci/abci.go index b7dc1b906..994bcdfb5 100644 --- a/internal/abci/abci.go +++ b/internal/abci/abci.go @@ -35,6 +35,7 @@ import ( abciTypes "github.com/cometbft/cometbft/abci/types" "github.com/cometbft/cometbft/crypto/ed25519" + "github.com/cometbft/cometbft/crypto/tmhash" "go.uber.org/zap" ) @@ -68,7 +69,7 @@ func NewAbciApp(ctx context.Context, cfg *AbciConfig, snapshotter SnapshotModule } app.forks.FromMap(cfg.ForkHeights) - tx, err := db.BeginOuterTx(ctx) + tx, err := db.BeginTx(ctx) if err != nil { return nil, fmt.Errorf("failed to begin outer tx: %w", err) } @@ -152,6 +153,7 @@ func NewAbciApp(ctx context.Context, cfg *AbciConfig, snapshotter SnapshotModule JoinExpiry: app.consensusParams.Validator.JoinExpiry, VoteExpiry: app.consensusParams.Votes.VoteExpiry, DisabledGasCosts: app.consensusParams.WithoutGasCosts, + MaxVotesPerTx: app.consensusParams.Votes.MaxVotesPerTx, }) if err != nil { return nil, fmt.Errorf("failed to store network params: %w", err) @@ -162,6 +164,7 @@ func NewAbciApp(ctx context.Context, cfg *AbciConfig, snapshotter SnapshotModule JoinExpiry: app.consensusParams.Validator.JoinExpiry, VoteExpiry: app.consensusParams.Votes.VoteExpiry, DisabledGasCosts: app.consensusParams.WithoutGasCosts, + MaxVotesPerTx: app.consensusParams.Votes.MaxVotesPerTx, } } else if err != nil { return nil, fmt.Errorf("failed to load network params: %w", err) @@ -213,9 +216,9 @@ type AbciApp struct { db DB // consensusTx is the outermost transaction that wraps all other transactions // that can modify state. It should be set in FinalizeBlock and committed in Commit. - consensusTx sql.OuterTx + consensusTx sql.PreparedTx // genesisTx is the transaction that is used at genesis, and in the first block. 
- genesisTx sql.OuterTx + genesisTx sql.PreparedTx // appHash is the hash of the application state appHash []byte // height is the current block height @@ -365,7 +368,15 @@ func (a *AbciApp) CheckTx(ctx context.Context, incoming *abciTypes.RequestCheckT } defer readTx.Rollback(ctx) // always rollback since we are read-only - err = a.txApp.ApplyMempool(ctx, readTx, tx) + err = a.txApp.ApplyMempool(&common.TxContext{ + Ctx: ctx, + BlockContext: &common.BlockContext{ + ChainContext: a.chainContext, + Height: a.height + 1, // height increments at the start of FinalizeBlock, + Proposer: nil, // we don't know the proposer here + }, + TxID: cometTXID(incoming.Tx), + }, readTx, tx) if err != nil { if errors.Is(err, transactions.ErrInvalidNonce) { code = codeInvalidNonce @@ -398,6 +409,11 @@ func (a *AbciApp) CheckTx(ctx context.Context, incoming *abciTypes.RequestCheckT return &abciTypes.ResponseCheckTx{Code: code.Uint32()}, nil } +// cometTXID gets the cometbft transaction ID. +func cometTXID(tx []byte) []byte { + return tmhash.Sum(tx) +} + // FinalizeBlock is on the consensus connection. Note that according to CometBFT // docs, "ResponseFinalizeBlock.app_hash is included as the Header.AppHash in // the next block." @@ -412,7 +428,7 @@ func (a *AbciApp) FinalizeBlock(ctx context.Context, req *abciTypes.RequestFinal a.genesisTx = nil } else { var err error - a.consensusTx, err = a.db.BeginOuterTx(ctx) + a.consensusTx, err = a.db.BeginPreparedTx(ctx) if err != nil { return nil, fmt.Errorf("begin outer tx failed: %w", err) } @@ -431,7 +447,7 @@ func (a *AbciApp) FinalizeBlock(ctx context.Context, req *abciTypes.RequestFinal VoteExpiry: a.consensusParams.Votes.VoteExpiry, DisabledGasCosts: a.consensusParams.WithoutGasCosts, } - oldNetworkParams := networkParams.Copy() + oldNetworkParams := *networkParams initialValidators, err := a.txApp.GetValidators(ctx, a.consensusTx) if err != nil { @@ -531,7 +547,7 @@ func (a *AbciApp) FinalizeBlock(ctx context.Context, req *abciTypes.RequestFinal // Broadcast any events that have not been broadcasted yet if a.broadcastFn != nil && len(proposerPubKey) > 0 { - err := a.broadcastFn(ctx, a.consensusTx, proposerPubKey) + err := a.broadcastFn(ctx, a.consensusTx, &blockCtx) if err != nil { return nil, fmt.Errorf("failed to broadcast events: %w", err) } @@ -545,7 +561,7 @@ func (a *AbciApp) FinalizeBlock(ctx context.Context, req *abciTypes.RequestFinal } // store any changes to the network params - err = meta.StoreDiff(ctx, a.consensusTx, oldNetworkParams, networkParams) + err = meta.StoreDiff(ctx, a.consensusTx, &oldNetworkParams, networkParams) if err != nil { return nil, fmt.Errorf("failed to store network params diff: %w", err) } @@ -562,7 +578,7 @@ func (a *AbciApp) FinalizeBlock(ctx context.Context, req *abciTypes.RequestFinal } // we now get the apphash by calling precommit on the transaction - appHash, err := a.consensusTx.Precommit(ctx) + appHash, err := a.consensusTx.Precommit(ctx, nil) if err != nil { return nil, fmt.Errorf("failed to precommit transaction: %w", err) } @@ -650,7 +666,7 @@ func (a *AbciApp) Commit(ctx context.Context, _ *abciTypes.RequestCommit) (*abci // opening a new transaction. This could leave us in a state where data is // committed but the apphash is not, which would essentially nuke the chain. 
ctx0 := context.Background() // badly timed shutdown MUST NOT cancel now, we need consistency with consensus tx commit - tx, err := a.db.BeginOuterTx(ctx0) + tx, err := a.db.BeginTx(ctx0) if err != nil { return nil, fmt.Errorf("failed to begin outer tx: %w", err) } @@ -743,7 +759,7 @@ func (a *AbciApp) InitChain(ctx context.Context, req *abciTypes.RequestInitChain logger.Debug("", zap.String("ChainId", req.ChainId)) // maybe verify a.cfg.ChainID against the one in the request var err error - a.genesisTx, err = a.db.BeginOuterTx(ctx) + a.genesisTx, err = a.db.BeginPreparedTx(ctx) if err != nil { return nil, fmt.Errorf("begin outer tx failed: %w", err) } @@ -796,14 +812,14 @@ func (a *AbciApp) InitChain(ctx context.Context, req *abciTypes.RequestInitChain return nil, fmt.Errorf("expected NULL app hash, got %x", appHash) } - startParams := a.chainContext.NetworkParameters.Copy() + startParams := *a.chainContext.NetworkParameters if err := a.txApp.GenesisInit(ctx, a.genesisTx, vldtrs, genesisAllocs, req.InitialHeight, a.chainContext); err != nil { return nil, fmt.Errorf("txApp.GenesisInit failed: %w", err) } // persist any diff to the network params - err = meta.StoreDiff(ctx, a.genesisTx, startParams, a.chainContext.NetworkParameters) + err = meta.StoreDiff(ctx, a.genesisTx, &startParams, a.chainContext.NetworkParameters) if err != nil { return nil, fmt.Errorf("failed to store network params diff: %w", err) } @@ -1349,7 +1365,7 @@ func (a *AbciApp) Query(ctx context.Context, req *abciTypes.RequestQuery) (*abci return &abciTypes.ResponseQuery{}, nil } -type EventBroadcaster func(ctx context.Context, db sql.DB, proposer []byte) error +type EventBroadcaster func(ctx context.Context, db sql.DB, block *common.BlockContext) error func (a *AbciApp) SetEventBroadcaster(fn EventBroadcaster) { a.broadcastFn = fn @@ -1407,3 +1423,11 @@ func (a *AbciApp) Close() error { } return nil } + +// Price estimates the price for a transaction. +// Consumers who do not have information about the current chain parameters, or +// who want a guarantee that they have the most up-to-date parameters without +// reading from the DB, can use this method. 
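+// A sketch of an RPC-layer caller, which now reaches this via the services' +// Pricer interface (readTx being an illustrative read-only transaction): +// +//	price, err := app.Price(ctx, readTx, tx)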
+func (a *AbciApp) Price(ctx context.Context, db sql.DB, tx *transactions.Transaction) (*big.Int, error) { + return a.txApp.Price(ctx, db, tx, a.chainContext) +} diff --git a/internal/abci/abci_test.go b/internal/abci/abci_test.go index 36dcdb221..096342c6e 100644 --- a/internal/abci/abci_test.go +++ b/internal/abci/abci_test.go @@ -3,6 +3,7 @@ package abci import ( "bytes" "context" + "io" "math/big" "testing" @@ -396,7 +397,7 @@ func (m *mockTxApp) AccountInfo(ctx context.Context, db sql.DB, acctID []byte, g return big.NewInt(0), 0, nil } -func (m *mockTxApp) ApplyMempool(ctx context.Context, db sql.DB, tx *transactions.Transaction) error { +func (m *mockTxApp) ApplyMempool(ctx *common.TxContext, db sql.DB, tx *transactions.Transaction) error { return nil } @@ -435,9 +436,13 @@ func (m *mockTxApp) Reload(ctx context.Context, db sql.DB) error { return nil } +func (m *mockTxApp) Price(ctx context.Context, db sql.DB, tx *transactions.Transaction, c *common.ChainContext) (*big.Int, error) { + return big.NewInt(0), nil +} + type mockDB struct{} -func (m *mockDB) BeginOuterTx(ctx context.Context) (sql.OuterTx, error) { +func (m *mockDB) BeginPreparedTx(ctx context.Context) (sql.PreparedTx, error) { return &mockTx{}, nil } @@ -477,6 +482,6 @@ func (m *mockTx) BeginTx(ctx context.Context) (sql.Tx, error) { return &mockTx{}, nil } -func (m *mockTx) Precommit(ctx context.Context) ([]byte, error) { +func (m *mockTx) Precommit(ctx context.Context, w io.Writer) ([]byte, error) { return nil, nil } diff --git a/internal/abci/interfaces.go b/internal/abci/interfaces.go index 377874a0b..26700ee80 100644 --- a/internal/abci/interfaces.go +++ b/internal/abci/interfaces.go @@ -46,7 +46,7 @@ type StateSyncModule interface { // and managing a mempool type TxApp interface { AccountInfo(ctx context.Context, db sql.DB, acctID []byte, getUnconfirmed bool) (balance *big.Int, nonce int64, err error) - ApplyMempool(ctx context.Context, db sql.DB, tx *transactions.Transaction) error + ApplyMempool(ctx *common.TxContext, db sql.DB, tx *transactions.Transaction) error Begin(ctx context.Context, height int64) error Commit(ctx context.Context) Execute(ctx txapp.TxContext, db sql.DB, tx *transactions.Transaction) *txapp.TxResponse @@ -56,6 +56,7 @@ type TxApp interface { ProposerTxs(ctx context.Context, db sql.DB, txNonce uint64, maxTxsSize int64, block *common.BlockContext) ([][]byte, error) Reload(ctx context.Context, db sql.DB) error UpdateValidator(ctx context.Context, db sql.DB, validator []byte, power int64) error + Price(ctx context.Context, db sql.DB, tx *transactions.Transaction, chainCtx *common.ChainContext) (*big.Int, error) } // ConsensusParams returns kwil specific consensus parameters. @@ -72,7 +73,8 @@ type ConsensusParams interface { // from within a transaction. A DB can create read transactions or the special // two-phase outer write transaction. type DB interface { - sql.OuterTxMaker + sql.TxMaker // for out-of-consensus writes e.g. 
setup and meta table writes + sql.PreparedTxMaker sql.ReadTxMaker sql.SnapshotTxMaker } diff --git a/internal/abci/meta/meta.go b/internal/abci/meta/meta.go index 6681ef1fd..b51ff0747 100644 --- a/internal/abci/meta/meta.go +++ b/internal/abci/meta/meta.go @@ -155,6 +155,22 @@ func StoreParams(ctx context.Context, db sql.TxMaker, params *common.NetworkPara return err } + buf = make([]byte, 1) + if params.InMigration { + buf[0] = 1 + } + _, err = tx.Execute(ctx, upsertParam, inMigration, buf) + if err != nil { + return err + } + + buf = make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(params.MaxVotesPerTx)) + _, err = tx.Execute(ctx, upsertParam, maxVotesPerTx, buf) + if err != nil { + return err + } + return tx.Commit(ctx) } @@ -195,8 +211,8 @@ func LoadParams(ctx context.Context, db sql.Executor) (*common.NetworkParameters return nil, ErrParamsNotFound } - if len(res.Rows) != 4 { - return nil, fmt.Errorf("expected four rows, got %d", len(res.Rows)) + if len(res.Rows) != 6 { + return nil, fmt.Errorf("internal bug: expected 6 rows, got %d", len(res.Rows)) } params := &common.NetworkParameters{} @@ -224,6 +240,10 @@ func LoadParams(ctx context.Context, db sql.Executor) (*common.NetworkParameters params.VoteExpiry = int64(binary.LittleEndian.Uint64(value)) case disabledGasKey: params.DisabledGasCosts = value[0] == 1 + case inMigration: + params.InMigration = value[0] == 1 + case maxVotesPerTx: + params.MaxVotesPerTx = int64(binary.LittleEndian.Uint64(value)) default: return nil, fmt.Errorf("internal bug: unknown param name: %s", param) } @@ -261,6 +281,20 @@ func diff(original, new *common.NetworkParameters) map[string][]byte { d[disabledGasKey] = buf } + if original.InMigration != new.InMigration { + buf := make([]byte, 1) + if new.InMigration { + buf[0] = 1 + } + d[inMigration] = buf + } + + if original.MaxVotesPerTx != new.MaxVotesPerTx { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(new.MaxVotesPerTx)) + d[maxVotesPerTx] = buf + } + return d } @@ -269,4 +303,6 @@ const ( joinExpiryKey = `join_expiry` voteExpiryKey = `vote_expiry` disabledGasKey = `disabled_gas_costs` + inMigration = `in_migration` + maxVotesPerTx = `max_votes_per_tx` ) diff --git a/internal/abci/meta/meta_test.go b/internal/abci/meta/meta_test.go index 85ef132e3..efe7586e7 100644 --- a/internal/abci/meta/meta_test.go +++ b/internal/abci/meta/meta_test.go @@ -30,7 +30,7 @@ func Test_NetworkParams(t *testing.T) { require.NoError(t, err) defer db.Close() - tx, err := db.BeginOuterTx(ctx) + tx, err := db.BeginTx(ctx) require.NoError(t, err) defer tx.Rollback(ctx) // always rollback to reset the test @@ -48,6 +48,7 @@ func Test_NetworkParams(t *testing.T) { JoinExpiry: 100, VoteExpiry: 100, DisabledGasCosts: true, + MaxVotesPerTx: 100, } err = meta.StoreParams(ctx, tx, param) @@ -62,6 +63,7 @@ func Test_NetworkParams(t *testing.T) { param2.MaxBlockSize = 2000 param2.JoinExpiry = 200 param2.DisabledGasCosts = false + param2.InMigration = true err = meta.StoreDiff(ctx, tx, param, param2) require.NoError(t, err) diff --git a/internal/engine/execution/procedure.go b/internal/engine/execution/procedure.go index b53485d65..b50964fd8 100644 --- a/internal/engine/execution/procedure.go +++ b/internal/engine/execution/procedure.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "maps" - "reflect" "strings" "github.com/kwilteam/kwil-db/common" @@ -398,64 +397,65 @@ func makeExecutables(params []*generate.InlineExpression) []evaluatable { if record[0] == nil { return nil, nil } + // TODO: I am currently making 
changes to PG that will remove the need for this (I think) // there is an edge case here where if the value is an array, it needs to be of the exact array type. // For example, pgx only understands []string, and not []any, however it will return arrays to us as // []any. If the returned type here is an array, we need to convert it to an array of the correct type. - typeOf := reflect.TypeOf(record[0]) - if typeOf.Kind() == reflect.Slice && typeOf.Elem().Kind() != reflect.Uint8 { - // if it is an array, we need to convert it to the correct type. - // if of length 0, we can simply set it to a text array - if len(record[0].([]any)) == 0 { - return []string{}, nil - } - - switch v := record[0].([]any)[0].(type) { - case string: - textArr := make([]string, len(record[0].([]any))) - for i, val := range record[0].([]any) { - textArr[i] = val.(string) - } - return textArr, nil - case int64: - intArr := make([]int64, len(record[0].([]any))) - for i, val := range record[0].([]any) { - intArr[i] = val.(int64) - } - return intArr, nil - case []byte: - blobArr := make([][]byte, len(record[0].([]any))) - for i, val := range record[0].([]any) { - blobArr[i] = val.([]byte) - } - return blobArr, nil - case bool: - boolArr := make([]bool, len(record[0].([]any))) - for i, val := range record[0].([]any) { - boolArr[i] = val.(bool) - } - return boolArr, nil - case *types.UUID: - uuidArr := make(types.UUIDArray, len(record[0].([]any))) - for i, val := range record[0].([]any) { - uuidArr[i] = val.(*types.UUID) - } - return uuidArr, nil - case *types.Uint256: - uint256Arr := make(types.Uint256Array, len(record[0].([]any))) - for i, val := range record[0].([]any) { - uint256Arr[i] = val.(*types.Uint256) - } - return uint256Arr, nil - case *decimal.Decimal: - decArr := make(decimal.DecimalArray, len(record[0].([]any))) - for i, val := range record[0].([]any) { - decArr[i] = val.(*decimal.Decimal) - } - return decArr, nil - default: - return nil, fmt.Errorf("unsupported in-line array type %T", v) - } - } + // typeOf := reflect.TypeOf(record[0]) + // if typeOf.Kind() == reflect.Slice && typeOf.Elem().Kind() != reflect.Uint8 { + // // if it is an array, we need to convert it to the correct type. 
+ // // if of length 0, we can simply set it to a text array + // if len(record[0].([]any)) == 0 { + // return []string{}, nil + // } + + // switch v := record[0].([]any)[0].(type) { + // case string: + // textArr := make([]string, len(record[0].([]any))) + // for i, val := range record[0].([]any) { + // textArr[i] = val.(string) + // } + // return textArr, nil + // case int64: + // intArr := make([]int64, len(record[0].([]any))) + // for i, val := range record[0].([]any) { + // intArr[i] = val.(int64) + // } + // return intArr, nil + // case []byte: + // blobArr := make([][]byte, len(record[0].([]any))) + // for i, val := range record[0].([]any) { + // blobArr[i] = val.([]byte) + // } + // return blobArr, nil + // case bool: + // boolArr := make([]bool, len(record[0].([]any))) + // for i, val := range record[0].([]any) { + // boolArr[i] = val.(bool) + // } + // return boolArr, nil + // case *types.UUID: + // uuidArr := make(types.UUIDArray, len(record[0].([]any))) + // for i, val := range record[0].([]any) { + // uuidArr[i] = val.(*types.UUID) + // } + // return uuidArr, nil + // case *types.Uint256: + // uint256Arr := make(types.Uint256Array, len(record[0].([]any))) + // for i, val := range record[0].([]any) { + // uint256Arr[i] = val.(*types.Uint256) + // } + // return uint256Arr, nil + // case *decimal.Decimal: + // decArr := make(decimal.DecimalArray, len(record[0].([]any))) + // for i, val := range record[0].([]any) { + // decArr[i] = val.(*decimal.Decimal) + // } + // return decArr, nil + // default: + // return nil, fmt.Errorf("unsupported in-line array type %T", v) + // } + // } return record[0], nil }) @@ -600,7 +600,7 @@ func (p *preparedProcedure) shapeReturn(result *sql.ResultSet) error { continue } - arr, ok := row[i].([]any) + arr, ok := row[i].(decimal.DecimalArray) if !ok { return fmt.Errorf("shapeReturn: expected decimal array, got %T", row[i]) } @@ -609,11 +609,7 @@ func (p *preparedProcedure) shapeReturn(result *sql.ResultSet) error { if v == nil { continue } - dec, ok := v.(*decimal.Decimal) - if !ok { - return fmt.Errorf("shapeReturn: expected decimal, got %T", dec) - } - err := dec.SetPrecisionAndScale(col.Type.Metadata[0], col.Type.Metadata[1]) + err := v.SetPrecisionAndScale(col.Type.Metadata[0], col.Type.Metadata[1]) if err != nil { return err } diff --git a/internal/engine/integration/deployment_test.go b/internal/engine/integration/deployment_test.go index 689be2d60..9190ed4be 100644 --- a/internal/engine/integration/deployment_test.go +++ b/internal/engine/integration/deployment_test.go @@ -206,7 +206,7 @@ func Test_Deployment(t *testing.T) { ctx := context.Background() - tx, err := db.BeginOuterTx(ctx) + tx, err := db.BeginTx(ctx) require.NoError(t, err) defer tx.Rollback(ctx) diff --git a/internal/engine/integration/execution_test.go b/internal/engine/integration/execution_test.go index 3d556d8f6..9681d7683 100644 --- a/internal/engine/integration/execution_test.go +++ b/internal/engine/integration/execution_test.go @@ -5,6 +5,7 @@ package integration_test import ( "context" "fmt" + "os" "testing" "github.com/kwilteam/kwil-db/common" @@ -583,13 +584,14 @@ func Test_Engine(t *testing.T) { ctx := context.Background() - tx, err := db.BeginOuterTx(ctx) + tx, err := db.BeginPreparedTx(ctx) require.NoError(t, err) defer tx.Rollback(ctx) test.ses1(t, global, tx) - id, err := tx.Precommit(ctx) // not needed, but test how txApp would use the engine + w := os.Stdout + id, err := tx.Precommit(ctx, w) // not needed, but test how txApp would use the engine 
require.NoError(t, err) require.NotEmpty(t, id) diff --git a/internal/engine/integration/procedure_test.go b/internal/engine/integration/procedure_test.go index 11dc9cba5..c0a55d6fc 100644 --- a/internal/engine/integration/procedure_test.go +++ b/internal/engine/integration/procedure_test.go @@ -61,7 +61,7 @@ func Test_Procedures(t *testing.T) { } `, inputs: []any{[]int64{1, 2, 3}}, - outputs: [][]any{{[]any{int64(2), int64(4), int64(6)}}}, // returns 1 row, 1 column, with an array of ints + outputs: [][]any{{[]int64{int64(2), int64(4), int64(6)}}}, // returns 1 row, 1 column, with an array of ints }, { name: "is (null)", @@ -277,62 +277,62 @@ func Test_Procedures(t *testing.T) { $c := $a/$b; return [$a, $b, $c]; }`, - outputs: [][]any{{[]any{mustDecimal("2.5", 2, 1), mustDecimal("3.5", 2, 1), mustDecimal("0.7", 2, 1)}}}, + outputs: [][]any{{decimal.DecimalArray{mustDecimal("2.5", 2, 1), mustDecimal("3.5", 2, 1), mustDecimal("0.7", 2, 1)}}}, }, { name: "decimal", procedure: `procedure d() public view { - $i := 100.423; - $j decimal(16,8) := 46728954.23743892; - $k := $i::decimal(16,8) + $j; - if $k != 46729054.66043892 { - error('decimal failed'); - } - if $k::text != '46729054.66043892' { - error('decimal text failed'); - } - if ($k::decimal(16,2))::text != '46729054.66' { - error('decimal 2 failed'); - } - }`, + $i := 100.423; + $j decimal(16,8) := 46728954.23743892; + $k := $i::decimal(16,8) + $j; + if $k != 46729054.66043892 { + error('decimal failed'); + } + if $k::text != '46729054.66043892' { + error('decimal text failed'); + } + if ($k::decimal(16,2))::text != '46729054.66' { + error('decimal 2 failed'); + } + }`, }, { name: "early empty return", procedure: `procedure return_early() public view { - $exit := true; - if $exit { - return; - } - error('should not reach here'); - }`, + $exit := true; + if $exit { + return; + } + error('should not reach here'); + }`, }, { name: "private procedure", procedure: `procedure private_proc() private view { - error('should not reach here'); - }`, + error('should not reach here'); + }`, err: execution.ErrPrivate, }, { name: "owner procedure - success", procedure: `procedure owner_proc() public owner view returns (is_owner bool) { - return true; - }`, + return true; + }`, outputs: [][]any{{true}}, }, { name: "owner procedure - fail", procedure: `procedure owner_proc() public owner view returns (is_owner bool) { - return false; - }`, + return false; + }`, err: execution.ErrOwnerOnly, caller: "some_other_wallet", }, { name: "mutative procedure in read-only tx", procedure: `procedure mutative() public { - return; - }`, + return; + }`, err: execution.ErrMutativeProcedure, readOnly: true, }, @@ -352,7 +352,7 @@ func Test_Procedures(t *testing.T) { $arr[2] := 4; return $arr; }`, - outputs: [][]any{{[]any{int64(1), int64(4), int64(3)}}}, + outputs: [][]any{{[]int64{int64(1), int64(4), int64(3)}}}, }, } @@ -366,7 +366,7 @@ func Test_Procedures(t *testing.T) { ctx := context.Background() - tx, err := db.BeginOuterTx(ctx) + tx, err := db.BeginTx(ctx) require.NoError(t, err) defer tx.Rollback(ctx) @@ -616,7 +616,7 @@ func Test_ForeignProcedures(t *testing.T) { ctx := context.Background() - tx, err := db.BeginOuterTx(ctx) + tx, err := db.BeginTx(ctx) require.NoError(t, err) defer tx.Rollback(ctx) diff --git a/internal/engine/integration/schema_test.go b/internal/engine/integration/schema_test.go index 4084623b5..392411aaa 100644 --- a/internal/engine/integration/schema_test.go +++ b/internal/engine/integration/schema_test.go @@ -213,7 +213,7 @@ func 
Test_Schemas(t *testing.T) { ctx := context.Background() - tx, err := db.BeginOuterTx(ctx) + tx, err := db.BeginTx(ctx) require.NoError(t, err) defer tx.Rollback(ctx) diff --git a/internal/engine/integration/sql_test.go b/internal/engine/integration/sql_test.go index 7abf8aae3..904650364 100644 --- a/internal/engine/integration/sql_test.go +++ b/internal/engine/integration/sql_test.go @@ -212,7 +212,7 @@ func Test_SQL(t *testing.T) { ctx := context.Background() - tx, err := db.BeginOuterTx(ctx) + tx, err := db.BeginTx(ctx) require.NoError(t, err) defer tx.Rollback(ctx) diff --git a/internal/migrations/interfaces.go b/internal/migrations/interfaces.go new file mode 100644 index 000000000..dd56746bd --- /dev/null +++ b/internal/migrations/interfaces.go @@ -0,0 +1,18 @@ +package migrations + +import ( + "context" + + "github.com/kwilteam/kwil-db/common/sql" + "github.com/kwilteam/kwil-db/internal/statesync" +) + +type Snapshotter interface { + CreateSnapshot(ctx context.Context, height uint64, snapshotID string) error + LoadSnapshotChunk(height uint64, format uint32, chunkIdx uint32) ([]byte, error) + ListSnapshots() []*statesync.Snapshot +} + +type Database interface { + sql.SnapshotTxMaker } diff --git a/internal/migrations/migrations.go b/internal/migrations/migrations.go new file mode 100644 index 000000000..dbcea6eeb --- /dev/null +++ b/internal/migrations/migrations.go @@ -0,0 +1,103 @@ +// Package migrations implements a long-running migrations protocol for Kwil. +// This allows networks to upgrade to new networks over long periods of time, +// without any downtime. +// +// The process is as follows: +// +// 1. A network votes to create a new network. If enough votes are attained, the process is started. +// +// 2. Once the process is started, each validator should create a new node to run the new network, which will +// connect to their current node. This new node will forward all changes from the old network to the new network. +// +// 3. The two networks will run in parallel until the old network reaches the scheduled shutdown block. At this point, +// the new network will take over and the old network will be shut down. +// +// The old network cannot deploy databases, drop them, transfer balances, vote on any resolutions, or change validator power. +// +// For more information on conflict resolution, see https://github.com/kwilteam/kwil-db/wiki/Long%E2%80%90Running-Network-Migrations package migrations + +import ( + "context" + "fmt" + "math/big" + + "github.com/kwilteam/kwil-db/common" + "github.com/kwilteam/kwil-db/core/types/serialize" + "github.com/kwilteam/kwil-db/extensions/resolutions" +) + +// MigrationDeclaration declares a new migration. It is used to agree on the terms of a migration, +// and is voted on using Kwil's vote store. +type MigrationDeclaration struct { + // ActivationPeriod is the number of blocks before the migration is activated. + // It starts after the migration is approved via the voting system. + // The intention is to allow validators to prepare for the migration. + ActivationPeriod int64 + // Duration is the number of blocks the migration will take to complete. + Duration int64 + // ChainID is the new chain ID that the network will migrate to. + // A new chain ID should always be used for a new network, to avoid + // cross-network replay attacks. + ChainID string + // Timestamp is the time the migration was created. It is set by the migration + // creator. 
The primary purpose of it is to guarantee uniqueness of the serialized + // MigrationDeclaration, since that is a requirement for the voting system. + Timestamp string +} + +// MarshalBinary marshals the MigrationDeclaration into a binary format. +func (md *MigrationDeclaration) MarshalBinary() ([]byte, error) { + return serialize.Encode(md) +} + +// UnmarshalBinary unmarshals the MigrationDeclaration from a binary format. +func (md *MigrationDeclaration) UnmarshalBinary(data []byte) error { + return serialize.Decode(data, md) +} + +// MigrationResolution is the definition for the network migration vote type in Kwil's +// voting system. +var MigrationResolution = resolutions.ResolutionConfig{ + ConfirmationThreshold: big.NewRat(2, 3), + ExpirationPeriod: 100800, // 1 week + ResolveFunc: func(ctx context.Context, app *common.App, resolution *resolutions.Resolution, block *common.BlockContext) error { + // The resolve func is responsible for: + // - Pausing all deploys and drops + // - Pausing all validator transactions + // - Pausing all votes + alreadyHasMigration, err := migrationActive(ctx, app.DB) + if err != nil { + return err + } + + if alreadyHasMigration { + return fmt.Errorf("failed to start migration: only one migration can be active at a time") + } + + mig := &MigrationDeclaration{} + if err := mig.UnmarshalBinary(resolution.Body); err != nil { + return err + } + + // the start height for the migration is whatever the height the migration + // resolution passed + the activation period, which allows validators to prepare + // for the migration. End height is the same, + the duration of the migration. + active := &activeMigration{ + StartHeight: block.Height + mig.ActivationPeriod, + EndHeight: block.Height + mig.ActivationPeriod + mig.Duration, + ChainID: mig.ChainID, + } + + err = createMigration(ctx, app.DB, active) + if err != nil { + return err + } + + block.ChainContext.NetworkParameters.InMigration = true + + // TODO: there are certainly other things we need to do on activation. I am unsure how to handle this. + // For example, we need to snapshot the network at the activation block. + return nil + }, +} diff --git a/internal/migrations/migrator.go b/internal/migrations/migrator.go new file mode 100644 index 000000000..c92ddc5ed --- /dev/null +++ b/internal/migrations/migrator.go @@ -0,0 +1,249 @@ +package migrations + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/kwilteam/kwil-db/common" + "github.com/kwilteam/kwil-db/core/log" + "github.com/kwilteam/kwil-db/internal/statesync" +) + +// migrator is responsible for managing the migrations. +// It is responsible for tracking any in-process migrations, snapshotting +// at the appropriate height, persisting changesets for the migration for each +// block as it occurs, and making that data available via RPC for the new node. +// Similarly, if the local process is the new node, it is responsible for reading +// changesets from the external node and applying them to the local database. +type Migrator struct { + // mu is the mutex for the migrator. + mu sync.RWMutex + // activeMigration is the migration that is currently in progress. + // It is nil if there is no migration in progress. + activeMigration *activeMigration + // snapshotter creates snapshots of the state. + snapshotter Snapshotter + // DB is a connection to the database. + // It should connect to the same Postgres database as kwild, + // but should be a different connection pool. 
+ DB Database + + // lastChangeset is the height of the last changeset that was stored. + // If no changesets have been stored, it is -1. + lastChangeset int64 + + // Logger is the logger for the migrator. + Logger log.Logger + + // dir is the directory where the migration data is stored. + // It is expected to be a full path. + dir string +} + +// activeMigration is an in-process migration. +type activeMigration struct { + // StartHeight is the height at which the migration starts. + StartHeight int64 + // EndHeight is the height at which the migration ends. + EndHeight int64 + // ChainID is the chain ID of the migration. + ChainID string +} + +// TODO: when we implement the constructor, we need to add a note that the snapshotter should +// not be the normal snapshotter, but instead its own instance. This is needed to ensure we do not +// delete the migration snapshot. + +// NotifyHeight notifies the migrator that a new block has been committed. +// It is called at the end of the block being applied, but before the block is +// committed to the database, in between tx.PreCommit and tx.Commit. +func (m *Migrator) NotifyHeight(ctx context.Context, block *common.BlockContext) error { + m.mu.Lock() + defer m.mu.Unlock() + // if there is no active migration, there is nothing to do + if m.activeMigration == nil { + return nil + } + + // if not in a migration, we can return early + if block.Height < m.activeMigration.StartHeight { + return nil + } + + if block.Height > m.activeMigration.EndHeight { + panic("internal bug: block height is greater than end height of migration") + } + + /* + I previously thought to make this run asynchronously, since PG dump can take a significant amount of time, + however I decided against it, because nodes are required to agree on the height of the old chain during the + migration on the new chain. I'm not sure of a way to guarantee this besides literally enforcing that the old + chain runs the migration synchronously as part of consensus. + + NOTE: https://github.com/kwilteam/kwil-db/pull/837#discussion_r1648036539 + */ + + // if the current block height is the height at which the migration starts, then + // we should snapshot the current DB and begin the migration. Since NotifyHeight is called + // during PreCommit, the state changes from the current block won't be included in the snapshot, + // and will instead need to be recorded as the first changeset of the migration. + if block.Height == m.activeMigration.StartHeight { + tx, snapshotId, err := m.DB.BeginSnapshotTx(ctx) + if err != nil { + return err + } + + err = m.snapshotter.CreateSnapshot(ctx, uint64(block.Height), snapshotId) + if err != nil { + err2 := tx.Rollback(ctx) + if err2 != nil { + // we can mostly ignore this error, since the original err will halt the node anyway + m.Logger.Errorf("failed to rollback transaction: %s", err2.Error()) + } + return err + } + + err = tx.Rollback(ctx) + if err != nil { + return err + } + } + + if block.Height == m.activeMigration.EndHeight { + // an error here will halt the node. + // there might be a more elegant way to handle this, but for now, this is fine. + return fmt.Errorf(`NETWORK HALTED: migration to chain "%s" has completed`, m.activeMigration.ChainID) + } + + // if we reach here, we are in a block that must be migrated. 
+ // TODO: get changeset + var cs Changeset + err := m.storeChangeset(&cs) + if err != nil { + return err + } + + m.lastChangeset = block.Height + return nil +} + +var ErrNoActiveMigration = fmt.Errorf("no active migration") + +// MigrationMetadata holds metadata about a migration, informing +// consumers of what information the current node has available +// for the migration. +type MigrationMetadata struct { + // GenesisSnapshot holds information about the genesis snapshot. + GenesisSnapshot *statesync.Snapshot + // LastChangeset is the height of the last changeset that was stored. + // Nodes are expected to have all changesets from LastChangeset to + // Snapshot.Height. If LastChangeset is -1, then no changesets have + // been stored yet. + LastChangeset int64 +} + +// GetMigrationMetadata gets the metadata for the genesis snapshot, +// as well as the available changesets. +func (m *Migrator) GetMigrationMetadata() (*MigrationMetadata, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + if m.activeMigration == nil { + return nil, ErrNoActiveMigration + } + + snapshots := m.snapshotter.ListSnapshots() + if len(snapshots) == 0 { + return nil, fmt.Errorf("migration is active, but no snapshots found. The node might still be creating the snapshot") + } + if len(snapshots) > 1 { + return nil, fmt.Errorf("migration is active, but more than one snapshot found. This should not happen, and is likely a bug") + } + + return &MigrationMetadata{ + GenesisSnapshot: snapshots[0], + LastChangeset: m.lastChangeset, + }, nil +} + +// GetGenesisChunk gets the genesis chunk for the migration. +func (m *Migrator) GetGenesisChunk(height int64, format uint32, chunkIdx uint32) ([]byte, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + if m.activeMigration == nil { + return nil, ErrNoActiveMigration + } + return m.snapshotter.LoadSnapshotChunk(uint64(height), format, chunkIdx) +} + +// GetChangeset gets the changeset for a block in the migration. +func (m *Migrator) GetChangeset(height int64) (*Changeset, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + if m.activeMigration == nil { + return nil, ErrNoActiveMigration + } + if height < m.activeMigration.StartHeight { + return nil, fmt.Errorf("requested changeset height is before the start of the migration") + } + if height > m.activeMigration.EndHeight { + return nil, fmt.Errorf("requested changeset height is after the end of the migration") + } + if height > m.lastChangeset { + return nil, fmt.Errorf("requested changeset height has not been recorded by the node yet") + } + + return m.loadChangeset(height) +} + +// storeChangeset persists a changeset to the migration directory. +func (m *Migrator) storeChangeset(c *Changeset) error { + bts, err := json.Marshal(c) + if err != nil { + return err + } + + file, err := os.Create(filepath.Join(m.dir, formatChangesetFilename(c.Height))) + if err != nil { + return err + } + + _, err = file.Write(bts) + if err != nil { + return err + } + + return file.Close() +} + +// loadChangeset loads a changeset from the migration directory. +func (m *Migrator) loadChangeset(height int64) (*Changeset, error) { + file, err := os.Open(filepath.Join(m.dir, formatChangesetFilename(height))) + if err != nil { + // we don't have to have special checks for non-existence, since + // we should check that prior to calling this function. 
+ return nil, err + } + + var c Changeset + err = json.NewDecoder(file).Decode(&c) + if err != nil { + return nil, err + } + + return &c, nil +} + +type Changeset struct { + Height int64 `json:"height"` +} + +func formatChangesetFilename(height int64) string { + return fmt.Sprintf("changeset-%d.json", height) +} diff --git a/internal/migrations/sql.go b/internal/migrations/sql.go new file mode 100644 index 000000000..27c4505ff --- /dev/null +++ b/internal/migrations/sql.go @@ -0,0 +1,102 @@ +package migrations + +import ( + "context" + "fmt" + + "github.com/kwilteam/kwil-db/common/sql" +) + +// InitializeMigrationSchema initializes the migration schema in the database. +func InitializeMigrationSchema(ctx context.Context, db sql.DB) error { + _, err := db.Execute(ctx, tableMigrationsSQL) + return err +} + +const migrationsSchemaName = `kwild_migrations` + +var ( + // tableMigrationsSQL is the sql table used to store the current migration state. + // Only one migration can be active at a time. + // Primary key should always be 1, to help us ensure there are no bugs in the code. + tableMigrationsSQL = `CREATE TABLE IF NOT EXISTS ` + migrationsSchemaName + `.migration ( + id INT PRIMARY KEY, + start_height INT NOT NULL, + end_height INT NOT NULL, + chain_id TEXT NOT NULL + )` + + // getMigrationSQL is the sql query used to get the current migration. + getMigrationSQL = `SELECT start_height, end_height, chain_id FROM ` + migrationsSchemaName + `.migration;` + // migrationIsActiveSQL is the sql query used to check if a migration is active. + migrationIsActiveSQL = `SELECT EXISTS(SELECT 1 FROM ` + migrationsSchemaName + `.migration);` + // createMigrationSQL is the sql query used to create a new migration. + createMigrationSQL = `INSERT INTO ` + migrationsSchemaName + `.migration (id, start_height, end_height, chain_id) VALUES ($1, $2, $3, $4);` +) + +// GetMigrationState gets the current migration state from the database. +// TODO: unexport this. I am exporting it b/c it is currently unused, and the linter +// nolint directive is not working. +func GetMigrationState(ctx context.Context, db sql.Executor) (*activeMigration, error) { + res, err := db.Execute(ctx, getMigrationSQL) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("network does not have an active migration") + } + if err != nil { + return nil, err + } + + if len(res.Rows) != 1 { + // should never happen + return nil, fmt.Errorf("internal bug: expected one row for migrations, got %d", len(res.Rows)) + } + + // parse the migration declaration + md := &activeMigration{} + + row := res.Rows[0] + var ok bool + md.StartHeight, ok = row[0].(int64) + if !ok { + return nil, fmt.Errorf("internal bug: start height is not an int64") + } + + md.EndHeight, ok = row[1].(int64) + if !ok { + return nil, fmt.Errorf("internal bug: end height is not an int64") + } + + md.ChainID, ok = row[2].(string) + if !ok { + return nil, fmt.Errorf("internal bug: chain ID is not a string") + } + + return md, nil +} + +// migrationActive checks if a migration is active. 
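+// The resolve function in MigrationResolution uses it as a guard, roughly: +// +//	active, err := migrationActive(ctx, app.DB) +//	if err != nil { +//		return err +//	} +//	if active { +//		return fmt.Errorf("only one migration can be active at a time") +//	}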
+func migrationActive(ctx context.Context, db sql.Executor) (bool, error) { + res, err := db.Execute(ctx, migrationIsActiveSQL) + if err != nil { + return false, err + } + + if len(res.Rows) != 1 { + // should never happen + return false, fmt.Errorf("internal bug: expected one row for migrations, got %d", len(res.Rows)) + } + + row := res.Rows[0] + active, ok := row[0].(bool) + if !ok { + return false, fmt.Errorf("internal bug: migration active is not a bool") + } + + return active, nil +} + +// createMigration creates a new migration state in the database. +func createMigration(ctx context.Context, db sql.Executor, md *activeMigration) error { + _, err := db.Execute(ctx, createMigrationSQL, 1, md.StartHeight, md.EndHeight, md.ChainID) + return err +} diff --git a/internal/services/grpc/txsvc/v1/pricing.go b/internal/services/grpc/txsvc/v1/pricing.go index f8c1829aa..2d4725253 100644 --- a/internal/services/grpc/txsvc/v1/pricing.go +++ b/internal/services/grpc/txsvc/v1/pricing.go @@ -22,7 +22,7 @@ func (s *Service) EstimatePrice(ctx context.Context, req *txpb.EstimatePriceRequ } defer readTx.Rollback(ctx) - price, err := s.nodeApp.Price(ctx, readTx, tx) + price, err := s.pricer.Price(ctx, readTx, tx) if err != nil { return nil, fmt.Errorf("failed to estimate price: %w", err) } diff --git a/internal/services/grpc/txsvc/v1/service.go b/internal/services/grpc/txsvc/v1/service.go index 44c553522..c9213abc9 100644 --- a/internal/services/grpc/txsvc/v1/service.go +++ b/internal/services/grpc/txsvc/v1/service.go @@ -29,16 +29,18 @@ type Service struct { db sql.ReadTxMaker // this should only ever make a read-only tx nodeApp NodeApplication // so we don't have to do ABCIQuery (indirect) + pricer Pricer chainClient BlockchainTransactor } func NewService(db sql.ReadTxMaker, engine EngineReader, - chainClient BlockchainTransactor, nodeApp NodeApplication, opts ...TxSvcOpt) *Service { + chainClient BlockchainTransactor, nodeApp NodeApplication, pricer Pricer, opts ...TxSvcOpt) *Service { s := &Service{ log: log.NewNoOp(), readTxTimeout: defaultReadTxTimeout, engine: engine, nodeApp: nodeApp, + pricer: pricer, chainClient: chainClient, db: db, } @@ -65,5 +67,8 @@ type BlockchainTransactor interface { type NodeApplication interface { AccountInfo(ctx context.Context, db sql.DB, identifier []byte, getUncommitted bool) (balance *big.Int, nonce int64, err error) +} + +type Pricer interface { Price(ctx context.Context, db sql.DB, tx *transactions.Transaction) (*big.Int, error) } diff --git a/internal/services/jsonrpc/adminsvc/service.go b/internal/services/jsonrpc/adminsvc/service.go index eab09b353..e73f23031 100644 --- a/internal/services/jsonrpc/adminsvc/service.go +++ b/internal/services/jsonrpc/adminsvc/service.go @@ -34,19 +34,23 @@ type BlockchainTransactor interface { } type TxApp interface { - Price(ctx context.Context, db sql.DB, tx *transactions.Transaction) (*big.Int, error) // AccountInfo returns the unconfirmed account info for the given identifier. // If unconfirmed is true, the account found in the mempool is returned. // Otherwise, the account found in the blockchain is returned. AccountInfo(ctx context.Context, db sql.DB, identifier []byte, unconfirmed bool) (balance *big.Int, nonce int64, err error) } +type Pricer interface { + Price(ctx context.Context, db sql.DB, tx *transactions.Transaction) (*big.Int, error) +} + type Service struct { log log.Logger blockchain BlockchainTransactor // node is the local node that can accept transactions. 
TxApp TxApp db sql.DelayedReadTxMaker + pricer Pricer cfg *config.KwildConfig chainID string @@ -124,13 +128,14 @@ func (svc *Service) Handlers() map[jsonrpc.Method]rpcserver.MethodHandler { } // NewService constructs a new Service. -func NewService(db sql.DelayedReadTxMaker, blockchain BlockchainTransactor, txApp TxApp, signer auth.Signer, cfg *config.KwildConfig, +func NewService(db sql.DelayedReadTxMaker, blockchain BlockchainTransactor, txApp TxApp, pricer Pricer, signer auth.Signer, cfg *config.KwildConfig, chainID string, logger log.Logger) *Service { return &Service{ blockchain: blockchain, TxApp: txApp, signer: signer, chainID: chainID, + pricer: pricer, cfg: cfg, log: logger, db: db, @@ -196,7 +201,7 @@ func (svc *Service) sendTx(ctx context.Context, payload transactions.Payload) (* return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "unable to create transaction", nil) } - fee, err := svc.TxApp.Price(ctx, readTx, tx) + fee, err := svc.pricer.Price(ctx, readTx, tx) if err != nil { return nil, jsonrpc.NewError(jsonrpc.ErrorTxInternal, "unable to price transaction", nil) } diff --git a/internal/services/jsonrpc/usersvc/service.go b/internal/services/jsonrpc/usersvc/service.go index 00812a577..217db6aa9 100644 --- a/internal/services/jsonrpc/usersvc/service.go +++ b/internal/services/jsonrpc/usersvc/service.go @@ -36,6 +36,7 @@ type Service struct { db sql.DelayedReadTxMaker // this should only ever make a read-only tx nodeApp NodeApplication // so we don't have to do ABCIQuery (indirect) chainClient BlockchainTransactor + pricer Pricer } type serviceCfg struct { @@ -57,7 +58,7 @@ const defaultReadTxTimeout = 5 * time.Second // NewService creates a new instance of the user RPC service. func NewService(db sql.DelayedReadTxMaker, engine EngineReader, chainClient BlockchainTransactor, - nodeApp NodeApplication, logger log.Logger, opts ...Opt) *Service { + nodeApp NodeApplication, pricer Pricer, logger log.Logger, opts ...Opt) *Service { cfg := &serviceCfg{ readTxTimeout: defaultReadTxTimeout, } @@ -69,6 +70,7 @@ func NewService(db sql.DelayedReadTxMaker, engine EngineReader, chainClient Bloc readTxTimeout: cfg.readTxTimeout, engine: engine, nodeApp: nodeApp, + pricer: pricer, chainClient: chainClient, db: db, } @@ -190,6 +192,9 @@ type BlockchainTransactor interface { type NodeApplication interface { AccountInfo(ctx context.Context, db sql.DB, identifier []byte, getUncommitted bool) (balance *big.Int, nonce int64, err error) +} + +type Pricer interface { Price(ctx context.Context, db sql.DB, tx *transactions.Transaction) (*big.Int, error) } @@ -300,7 +305,7 @@ func (svc *Service) EstimatePrice(ctx context.Context, req *userjson.EstimatePri readTx := svc.db.BeginDelayedReadTx() defer readTx.Rollback(ctx) - price, err := svc.nodeApp.Price(ctx, readTx, req.Tx) + price, err := svc.pricer.Price(ctx, readTx, req.Tx) if err != nil { svc.log.Error("failed to estimate price", log.Error(err)) // why not tell the client though? return nil, jsonrpc.NewError(jsonrpc.ErrorTxInternal, "failed to estimate price", nil) diff --git a/internal/sql/pg/conn.go b/internal/sql/pg/conn.go index 9b19f8261..33b7a1276 100644 --- a/internal/sql/pg/conn.go +++ b/internal/sql/pg/conn.go @@ -77,6 +77,9 @@ type Pool struct { // how postgres itself reserves connections with the reserved_connections // and superuser_reserved_connections system settings. // https://www.postgresql.org/docs/current/runtime-config-connection.html#GUC-RESERVED-CONNECTIONS + // oidTypes maps an OID to the datatype it represents. 
Since Kwil has data types such as uint256, + // which are registered as Postgres Domains, each pg instance will have its own random OID for it. + idTypes map[uint32]*datatype } // PoolConfig combines a connection config with additional options for a pool of @@ -168,15 +171,26 @@ func NewPool(ctx context.Context, cfg *PoolConfig) (*Pool, error) { return nil, err } + // acquire a writer to determine the OID of the custom types + writerConn, err := writer.Acquire(ctx) + if err != nil { + return nil, err + } + defer writerConn.Release() + oidTypes := oidTypesMap(writerConn.Conn().TypeMap()) + pool := &Pool{ readers: db, writer: writer, reserved: reserved, + idTypes: oidTypes, } return pool, db.Ping(ctx) } +// registerTypes ensures that the custom types used by Kwil are registered with +// the pgx connection. func registerTypes(ctx context.Context, conn *pgx.Conn) error { err := ensureUint256Domain(ctx, conn) if err != nil { @@ -206,7 +220,7 @@ func registerTypes(ctx context.Context, conn *pgx.Conn) error { // executed in a transaction with read only access mode to ensure there can be // no modifications. func (p *Pool) Query(ctx context.Context, stmt string, args ...any) (*sql.ResultSet, error) { - return queryTx(ctx, p.readers, stmt, args...) + return queryTx(ctx, p.idTypes, p.readers, stmt, args...) } // WARNING: The Execute method is for completeness and helping tests, but is not @@ -218,7 +232,7 @@ func (p *Pool) Execute(ctx context.Context, stmt string, args ...any) (*sql.Resu var res *sql.ResultSet err := p.writer.AcquireFunc(ctx, func(c *pgxpool.Conn) error { var err error - res, err = query(ctx, &cqWrapper{c.Conn()}, stmt, args...) + res, err = query(ctx, p.idTypes, &cqWrapper{c.Conn()}, stmt, args...) return err }) if err != nil { @@ -247,6 +261,7 @@ func (p *Pool) BeginTx(ctx context.Context) (sql.Tx, error) { return &nestedTx{ Tx: tx, accessMode: sql.ReadWrite, + oidTypes: p.idTypes, }, nil } @@ -262,5 +277,6 @@ func (p *Pool) BeginReadTx(ctx context.Context) (sql.Tx, error) { return &nestedTx{ Tx: tx, accessMode: sql.ReadOnly, + oidTypes: p.idTypes, }, nil } diff --git a/internal/sql/pg/db.go b/internal/sql/pg/db.go index c083133a4..b58497843 100644 --- a/internal/sql/pg/db.go +++ b/internal/sql/pg/db.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "strings" "sync" @@ -50,11 +51,12 @@ type DB struct { autoCommit bool // skip the explicit transaction (begin/commit automatically) tx pgx.Tx // interface txid string // uid of the prepared transaction + seq int64 // NOTE: this was initially designed for a single ongoing write transaction, // held in the tx field, and the (*DB).Execute method using it *implicitly*. // We have moved toward using the Execute method of the transaction returned - // by BeginTx/BeginOuterTx/BeginReadTx, and we can potentially allow + // by BeginTx/BeginPreparedTx/BeginReadTx, and we can potentially allow // multiple uncommitted write transactions to support a 2 phase commit of // different stores using the same pg.DB instance. This will take refactoring // of the DB and concrete transaction type methods. @@ -154,8 +156,8 @@ func NewDB(ctx context.Context, cfg *DBConfig) (*DB, error) { // Ensure all tables that are created with no primary key or unique index // are altered to have "full replication identity" for UPDATE and DELETES. 
-	if err = ensureTriggerReplIdentity(ctx, conn); err != nil {
-		return nil, fmt.Errorf("failed to create replication identity trigger: %w", err)
+	if err = ensureFullReplicaIdentityTrigger(ctx, conn); err != nil {
+		return nil, fmt.Errorf("failed to create full replication identity trigger: %w", err)
 	}
 
 	// Create the publication that is required for logical replication.
@@ -176,7 +178,7 @@ func NewDB(ctx context.Context, cfg *DBConfig) (*DB, error) {
 		okSchema = defaultSchemaFilter
 	}
 
-	repl, err := newReplMon(ctx, cfg.Host, cfg.Port, cfg.User, cfg.Pass, cfg.DBName, okSchema)
+	repl, err := newReplMon(ctx, cfg.Host, cfg.Port, cfg.User, cfg.Pass, cfg.DBName, okSchema, pool.idTypes)
 	if err != nil {
 		return nil, err
 	}
@@ -199,6 +201,7 @@ func NewDB(ctx context.Context, cfg *DBConfig) (*DB, error) {
 		repl:   repl,
 		cancel: cancel,
 		ctx:    runCtx,
+		seq:    -1,
 	}
 
 	// Supervise the replication stream monitor. If it dies (repl.done chan
@@ -255,18 +258,17 @@ func (db *DB) AutoCommit(auto bool) {
 // For {accounts,validators}.Datasets / registry.DB
 var _ sql.Executor = (*DB)(nil)
 
-var _ sql.OuterTxMaker = (*DB)(nil) // for dataset Registry
-
-// BeginTx makes the DB's singular transaction, which is used automatically by
-// consumers of the Query and Execute methods. This is the mode of operation
-// used by Kwil to have one system coordinating transaction lifetime, with one
-// or more other systems implicitly using the transaction for their queries.
-//
-// The returned transaction is also capable of creating nested transactions.
-// This functionality is used to prevent user dataset query errors from rolling
-// back the outermost transaction.
-func (db *DB) BeginOuterTx(ctx context.Context) (sql.OuterTx, error) {
-	tx, err := db.beginWriterTx(ctx)
+var _ sql.PreparedTxMaker = (*DB)(nil) // for dataset Registry
+
+// beginTx starts a read-write transaction, returning a dbTx. It will be for a
+// prepared transaction if sequenced is true, incrementing the seq value of the
+// internal sentry table to allow obtaining a commit ID with Precommit. If
+// sequenced is false, the precommit method will return an error (use commit only). If
+// sequenced is true and the precommit method is not used (straight to commit),
+// it will work as intended, but the replication monitor will warn about an
+// unexpected sequence update in the transaction.
+func (db *DB) beginTx(ctx context.Context, sequenced bool) (*dbTx, error) {
+	tx, err := db.beginWriterTx(ctx, sequenced)
 	if err != nil {
 		return nil, err
 	}
@@ -274,6 +276,7 @@ func (db *DB) BeginOuterTx(ctx context.Context) (sql.OuterTx, error) {
 	ntx := &nestedTx{
 		Tx:         tx,
 		accessMode: sql.ReadWrite,
+		oidTypes:   db.pool.idTypes,
 	}
 	return &dbTx{
 		nestedTx:   ntx,
@@ -282,11 +285,29 @@ func (db *DB) BeginOuterTx(ctx context.Context) (sql.OuterTx, error) {
 	}, nil
 }
 
+// BeginPreparedTx makes the DB's singular transaction, which is used automatically
+// by consumers of the Query and Execute methods. This is the mode of operation
+// used by Kwil to have one system coordinating transaction lifetime, with one
+// or more other systems implicitly using the transaction for their queries.
+//
+// This method creates a sequenced transaction, and it should be committed with
+// a prepared transaction (two-phase commit) using Precommit. Use BeginTx for a
+// regular outer transaction without sequencing or a prepared transaction.
+//
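The pieces above combine into one caller-side flow: begin a sequenced transaction, execute statements, collect the commit ID and changeset with Precommit, then finish with Commit. A minimal sketch of that flow, not code from the patch, assuming the bytes and context imports and using a hypothetical statement:

```go
// commitWithChangeset shows the sequenced two-phase commit flow end to end.
func commitWithChangeset(ctx context.Context, db *DB) ([]byte, error) {
	tx, err := db.BeginPreparedTx(ctx)
	if err != nil {
		return nil, err
	}
	defer tx.Rollback(ctx) // no-op once Commit has succeeded

	// hypothetical statement; any writes in this tx land in the changeset
	if _, err := tx.Execute(ctx, `UPDATE some_schema.some_table SET v = v + 1`); err != nil {
		return nil, err
	}

	var changeset bytes.Buffer
	commitID, err := tx.Precommit(ctx, &changeset) // PREPARE TRANSACTION; changeset streamed in
	if err != nil {
		return nil, err
	}
	if err := tx.Commit(ctx); err != nil { // COMMIT PREPARED
		return nil, err
	}
	// commitID is the 8-byte seq plus commit hash; changeset.Bytes() holds the diff
	return commitID, nil
}
```

+// The returned transaction is also capable of creating nested transactions.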
+// This functionality is used to prevent user dataset query errors from rolling +// back the outermost transaction. +func (db *DB) BeginPreparedTx(ctx context.Context) (sql.PreparedTx, error) { + return db.beginTx(ctx, true) // sequenced, expose Precommit +} + var _ sql.TxMaker = (*DB)(nil) var _ sql.DB = (*DB)(nil) +// BeginTx starts a regular read-write transaction. For a sequenced two-phase +// transaction, use BeginPreparedTx. func (db *DB) BeginTx(ctx context.Context) (sql.Tx, error) { - return db.BeginOuterTx(ctx) // slice off the Precommit method from sql.OuterTx + return db.beginTx(ctx, false) // slice off the Precommit method from sql.PreparedTx } // ReadTx creates a read-only transaction for the database. @@ -338,6 +359,7 @@ func (db *DB) beginReadTx(ctx context.Context, iso pgx.TxIsoLevel) (sql.Tx, erro ntx := &nestedTx{ Tx: tx, accessMode: sql.ReadOnly, + oidTypes: db.pool.idTypes, } return &readTx{ @@ -363,6 +385,7 @@ func (db *DB) BeginReservedReadTx(ctx context.Context) (sql.Tx, error) { return &nestedTx{ Tx: tx, accessMode: sql.ReadOnly, + oidTypes: db.pool.idTypes, }, nil } @@ -396,7 +419,7 @@ func (txw *writeTxWrapper) Rollback(ctx context.Context) error { // beginWriterTx is the critical section of BeginTx. // It creates a new transaction on the write connection, and stores it in the // DB's tx field. It is not exported, and is only called from BeginTx. -func (db *DB) beginWriterTx(ctx context.Context) (pgx.Tx, error) { +func (db *DB) beginWriterTx(ctx context.Context, sequenced bool) (pgx.Tx, error) { db.mtx.Lock() defer db.mtx.Unlock() @@ -424,36 +447,44 @@ func (db *DB) beginWriterTx(ctx context.Context) (pgx.Tx, error) { release: writer.Release, } - return db.tx, nil -} - -// precommit finalizes the transaction with a prepared transaction and returns -// the ID of the commit. The transaction is not yet committed. -func (db *DB) precommit(ctx context.Context) ([]byte, error) { - db.mtx.Lock() - defer db.mtx.Unlock() - - if db.tx == nil { - return nil, errors.New("no tx exists") + if !sequenced { + db.seq = -1 // should already be + return db.tx, nil } // Do the seq update in sentry table. This ensures a replication message // sequence is emitted from this transaction, and that the data returned // from it includes the expected seq value. - seq, err := incrementSeq(ctx, db.tx) + seq, err := incrementSeq(ctx, tx) if err != nil { return nil, err } logger.Debugf("updated seq to %d", seq) + db.seq = seq + + return db.tx, nil +} + +// precommit finalizes the transaction with a prepared transaction and returns +// the ID of the commit. The transaction is not yet committed. It takes an io.Writer +// to write the changeset to, and returns the commit ID. If the io.Writer is nil, +// it won't write the changeset anywhere. +func (db *DB) precommit(ctx context.Context, writer io.Writer) ([]byte, error) { + db.mtx.Lock() + defer db.mtx.Unlock() + + if db.tx == nil || db.seq == -1 { + return nil, errors.New("no tx exists") + } - resChan, ok := db.repl.recvID(seq) + resChan, ok := db.repl.recvID(db.seq, writer) if !ok { // commitID will not be available, error. There is no recovery presently. 
return nil, errors.New("replication connection is down") } db.txid = random.String(10) sqlPrepareTx := fmt.Sprintf(`PREPARE TRANSACTION '%s'`, db.txid) - if _, err = db.tx.Exec(ctx, sqlPrepareTx); err != nil { + if _, err := db.tx.Exec(ctx, sqlPrepareTx); err != nil { return nil, err } @@ -489,6 +520,7 @@ func (db *DB) commit(ctx context.Context) error { // Allow commit without two-phase prepare err := db.tx.Commit(ctx) db.tx = nil + db.seq = -1 return err } @@ -498,6 +530,7 @@ func (db *DB) commit(ctx context.Context) error { } sqlRollback := fmt.Sprintf(`ROLLBACK PREPARED '%s'`, db.txid) db.txid = "" + db.seq = -1 if _, err := db.tx.Exec(ctx, sqlRollback); err != nil { logger.Warnf("ROLLBACK PREPARED failed: %v", err) } @@ -518,6 +551,7 @@ func (db *DB) commit(ctx context.Context) error { // prepare will try to rollback this old prepared txn. db.tx = nil db.txid = "" + db.seq = -1 return nil } @@ -535,6 +569,7 @@ func (db *DB) rollback(ctx context.Context) error { defer func() { db.tx = nil db.txid = "" + db.seq = -1 }() // If precommit not yet done, do a regular rollback. @@ -595,7 +630,7 @@ func (db *DB) Execute(ctx context.Context, stmt string, args ...any) (*sql.Resul if db.autoCommit { return nil, errors.New("tx already created, cannot use auto commit") } - return query(ctx, db.tx, stmt, args...) + return query(ctx, db.pool.idTypes, db.tx, stmt, args...) } if !db.autoCommit { return nil, sql.ErrNoTransaction @@ -616,11 +651,11 @@ func (db *DB) Execute(ctx context.Context, stmt string, args ...any) (*sql.Resul return err } var ok bool - resChan, ok = db.repl.recvID(seq) + resChan, ok = db.repl.recvID(seq, nil) // nil changeset writer since we are in auto-commit mode if !ok { return errors.New("replication connection is down") } - res, err = query(ctx, tx, stmt, args...) + res, err = query(ctx, db.pool.idTypes, tx, stmt, args...) return err }, ) diff --git a/internal/sql/pg/db_live_test.go b/internal/sql/pg/db_live_test.go index ea3f1dc06..e469896cc 100644 --- a/internal/sql/pg/db_live_test.go +++ b/internal/sql/pg/db_live_test.go @@ -3,7 +3,9 @@ package pg import ( + "bytes" "context" + "fmt" "strings" "testing" "time" @@ -177,6 +179,11 @@ func TestSelectLiteralType(t *testing.T) { } }) + err = registerTypes(ctx, conn) + if err != nil { + t.Fatal(err) + } + // var arg any = int64(1) // args := []any{arg, arg} argMap := map[string]any{ @@ -232,7 +239,7 @@ func TestSelectLiteralType(t *testing.T) { // Now with our high level func and mode. args2 := append([]any{QueryModeInferredArgTypes}, args...) - results, err := query(ctx, &cqWrapper{conn}, stmt, args2...) + results, err := query(ctx, oidTypesMap(conn.TypeMap()), &cqWrapper{conn}, stmt, args2...) if err != nil { t.Fatal(err) } @@ -256,7 +263,7 @@ func TestNestedTx(t *testing.T) { } // Start the outer transaction. 
- tx, err := db.BeginOuterTx(ctx) + tx, err := db.BeginPreparedTx(ctx) if err != nil { t.Fatal(err) } @@ -328,7 +335,7 @@ func TestNestedTx(t *testing.T) { t.Fatal(err) } - id, err := tx.Precommit(ctx) + id, err := tx.Precommit(ctx, nil) if err != nil { t.Fatal(err) } @@ -430,47 +437,47 @@ func TestTypeRoundtrip(t *testing.T) { { typ: "int8[]", val: []int64{1, 2, 3}, - want: []any{int64(1), int64(2), int64(3)}, + want: []int64{int64(1), int64(2), int64(3)}, }, { typ: "bool[]", val: []bool{true, false, true}, - want: []any{true, false, true}, + want: []bool{true, false, true}, }, { typ: "text[]", val: []string{"a", "b", "c"}, - want: []any{"a", "b", "c"}, + want: []string{"a", "b", "c"}, }, { typ: "bytea[]", val: [][]byte{[]byte("a"), []byte("b"), []byte("c")}, - want: []any{[]byte("a"), []byte("b"), []byte("c")}, + want: [][]byte{[]byte("a"), []byte("b"), []byte("c")}, }, { typ: "uuid[]", val: types.UUIDArray{types.NewUUIDV5([]byte("2")), types.NewUUIDV5([]byte("3"))}, - want: []any{types.NewUUIDV5([]byte("2")), types.NewUUIDV5([]byte("3"))}, + want: types.UUIDArray{types.NewUUIDV5([]byte("2")), types.NewUUIDV5([]byte("3"))}, }, { typ: "decimal(6,4)[]", val: decimal.DecimalArray{mustDecimal("12.4223"), mustDecimal("22.4425"), mustDecimal("23.7423")}, - want: []any{mustDecimal("12.4223"), mustDecimal("22.4425"), mustDecimal("23.7423")}, + want: decimal.DecimalArray{mustDecimal("12.4223"), mustDecimal("22.4425"), mustDecimal("23.7423")}, }, { - typ: "decimal(3,0)[]", + typ: "uint256[]", val: types.Uint256Array{types.Uint256FromInt(100), types.Uint256FromInt(200), types.Uint256FromInt(300)}, - want: []any{mustDecimal("100"), mustDecimal("200"), mustDecimal("300")}, + want: types.Uint256Array{types.Uint256FromInt(100), types.Uint256FromInt(200), types.Uint256FromInt(300)}, }, { typ: "text[]", val: []string{}, - want: []any{}, + want: []string{}, }, { typ: "int8[]", val: []int64{}, - want: []any{}, + want: []int64{}, }, { typ: "nil", @@ -480,13 +487,13 @@ func TestTypeRoundtrip(t *testing.T) { { typ: "[]uuid", val: []any{"3146857c-8671-4f4e-99bd-fcc621f9d3d1", "3146857c-8671-4f4e-99bd-fcc621f9d3d1"}, - want: []any{"3146857c-8671-4f4e-99bd-fcc621f9d3d1", "3146857c-8671-4f4e-99bd-fcc621f9d3d1"}, + want: []string{"3146857c-8671-4f4e-99bd-fcc621f9d3d1", "3146857c-8671-4f4e-99bd-fcc621f9d3d1"}, skipTbl: true, }, { typ: "int8[]", val: []string{"1", "2"}, - want: []any{int64(1), int64(2)}, + want: []int64{int64(1), int64(2)}, skipInferred: true, }, } { @@ -518,7 +525,7 @@ func TestTypeRoundtrip(t *testing.T) { // here, we test without the QueryModeInferredArgTypes - tx, err := db.BeginOuterTx(ctx) + tx, err := db.BeginPreparedTx(ctx) require.NoError(t, err) defer tx.Rollback(ctx) // always rollback @@ -557,6 +564,15 @@ func mustParseUUID(s string) *types.UUID { return u } +// mustUint256 panics if the string cannot be converted to a Uint256. 
+func mustUint256(s string) *types.Uint256 { + u, err := types.Uint256FromString(s) + if err != nil { + panic(err) + } + return u +} + func Test_DelayedTx(t *testing.T) { ctx := context.Background() @@ -577,3 +593,217 @@ func Test_DelayedTx(t *testing.T) { err = tx2.Commit(ctx) require.NoError(t, err) } + +// This test tests changesets, and that they are properly encoded+decoded +func Test_Changesets(t *testing.T) { + for i, tc := range []interface { + run(t *testing.T) + }{ + &changesetTestcase[string, []string]{ // basic string test + datatype: "text", + val: "hello", + arrayVal: []string{"a", "b", "c"}, + val2: "world", + arrayVal2: []string{"d", "e", "f"}, + }, + &changesetTestcase[string, []string]{ // test with special characters and escaping + datatype: "text", + val: "heldcsklk;le''\"';", + arrayVal: []string{"hel,dcsklk;le','\",';", `";\\sdsw,"''"\',\""`}, + val2: "world", + arrayVal2: []string{"'\"", "heldcsklk;le''\"';"}, + }, + &changesetTestcase[int64, []int64]{ + datatype: "int8", + val: 1, + arrayVal: []int64{1, 2, 3987654}, + val2: 2, + arrayVal2: []int64{3, 4, 5}, + }, + &changesetTestcase[bool, []bool]{ + datatype: "bool", + val: true, + arrayVal: []bool{true, false, true}, + val2: false, + arrayVal2: []bool{false, true, false}, + }, + &changesetTestcase[[]byte, [][]byte]{ + datatype: "bytea", + val: []byte("hello"), + arrayVal: [][]byte{[]byte("a"), []byte("b"), []byte("c")}, + val2: []byte("world"), + arrayVal2: [][]byte{[]byte("d"), []byte("e"), []byte("f")}, + }, + &changesetTestcase[*decimal.Decimal, decimal.DecimalArray]{ + datatype: "decimal(6,3)", + val: mustDecimal("123.456"), + arrayVal: decimal.DecimalArray{mustDecimal("123.456"), mustDecimal("123.456"), mustDecimal("123.456")}, + val2: mustDecimal("123.457"), + arrayVal2: decimal.DecimalArray{mustDecimal("123.457"), mustDecimal("123.457"), mustDecimal("123.457")}, + }, + &changesetTestcase[*types.UUID, types.UUIDArray]{ + datatype: "uuid", + val: mustParseUUID("3146857c-8671-4f4e-99bd-fcc621f9d3d1"), + arrayVal: types.UUIDArray{mustParseUUID("3146857c-8671-4f4e-99bd-fcc621f9d3d1"), mustParseUUID("3146857c-8671-4f4e-99bd-fcc621f9d3d1")}, + val2: mustParseUUID("3146857c-8671-4f4e-99bd-fcc621f9d3d2"), + arrayVal2: types.UUIDArray{mustParseUUID("3146857c-8671-4f4e-99bd-fcc621f9d3d2"), mustParseUUID("3146857c-8671-4f4e-99bd-fcc621f9d3d2")}, + }, + &changesetTestcase[*types.Uint256, types.Uint256Array]{ + datatype: "uint256", + val: mustUint256("18446744073709551615000000"), + arrayVal: types.Uint256Array{mustUint256("184467440737095516150000002"), mustUint256("184467440737095516150000001")}, + val2: mustUint256("18446744073709551615000001"), + arrayVal2: types.Uint256Array{mustUint256("184467440737095516150000012"), mustUint256("1844674407370955161500000123")}, + }, + } { + t.Run(fmt.Sprint(i), tc.run) + } +} + +// this is a hack to use generics in the test +type changesetTestcase[T any, T2 any] struct { + datatype string // the postgres datatype to test + // the first vals will be inserted. 
+ // val will be the primary key + val T // the value to test + arrayVal T2 // the array value to test + // the second vals will update the first vals + val2 T // the second value to test + arrayVal2 T2 // the second array value to test +} + +func (c *changesetTestcase[T, T2]) run(t *testing.T) { + ctx := context.Background() + + db, err := NewDB(ctx, cfg) + require.NoError(t, err) + defer db.Close() + + cleanup := func() { + db.AutoCommit(true) + _, err = db.Execute(ctx, "drop table if exists ds_test.test", QueryModeExec) + require.NoError(t, err) + _, err = db.Execute(ctx, "drop schema if exists ds_test", QueryModeExec) + db.AutoCommit(false) + } + // attempt to clean up any old failed tests + cleanup() + defer cleanup() + + regularTx, err := db.BeginPreparedTx(ctx) + require.NoError(t, err) + defer regularTx.Rollback(ctx) + + _, err = regularTx.Execute(ctx, "create schema ds_test", QueryModeExec) + require.NoError(t, err) + + _, err = regularTx.Execute(ctx, "create table ds_test.test (val "+c.datatype+" primary key, array_val "+c.datatype+"[])", QueryModeExec) + require.NoError(t, err) + + err = regularTx.Commit(ctx) + require.NoError(t, err) + + /* + Block 1: Insert + */ + + writer := new(bytes.Buffer) + + tx, err := db.BeginPreparedTx(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Execute(ctx, "insert into ds_test.test (val, array_val) values ($1, $2)", QueryModeExec, c.val, c.arrayVal) + require.NoError(t, err) + + // get the changeset + _, err = tx.Precommit(ctx, writer) + require.NoError(t, err) + + cs, err := DeserializeChangeset(writer.Bytes()) + require.NoError(t, err) + + require.Len(t, cs.Changesets, 1) + require.Len(t, cs.Changesets[0].Inserts, 1) + + insertVals, err := cs.Changesets[0].DecodeTuple(cs.Changesets[0].Inserts[0]) + require.NoError(t, err) + + // verify the insert vals are equal to the first vals + require.EqualValues(t, c.val, insertVals[0]) + require.EqualValues(t, c.arrayVal, insertVals[1]) + + err = tx.Commit(ctx) + require.NoError(t, err) + + /* + Block 2: Update + */ + + writer = new(bytes.Buffer) + + tx, err = db.BeginPreparedTx(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Execute(ctx, "update ds_test.test set val = $1, array_val = $2", QueryModeExec, c.val2, c.arrayVal2) + require.NoError(t, err) + + _, err = tx.Precommit(ctx, writer) + require.NoError(t, err) + + cs, err = DeserializeChangeset(writer.Bytes()) + require.NoError(t, err) + + require.Len(t, cs.Changesets, 1) + require.Len(t, cs.Changesets[0].Updates, 1) + + oldVals, err := cs.Changesets[0].DecodeTuple(cs.Changesets[0].Updates[0][0]) + require.NoError(t, err) + + newVals, err := cs.Changesets[0].DecodeTuple(cs.Changesets[0].Updates[0][1]) + require.NoError(t, err) + + // verify the old vals are equal to the first vals + require.EqualValues(t, c.val, oldVals[0]) + require.EqualValues(t, c.arrayVal, oldVals[1]) + + // verify the new vals are equal to the second vals + require.EqualValues(t, c.val2, newVals[0]) + require.EqualValues(t, c.arrayVal2, newVals[1]) + + err = tx.Commit(ctx) + require.NoError(t, err) + + /* + Block 3: Delete + */ + + writer = new(bytes.Buffer) + + tx, err = db.BeginPreparedTx(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Execute(ctx, "delete from ds_test.test", QueryModeExec) + require.NoError(t, err) + + _, err = tx.Precommit(ctx, writer) + require.NoError(t, err) + + cs, err = DeserializeChangeset(writer.Bytes()) + require.NoError(t, err) + + require.Len(t, cs.Changesets, 1) + 
require.Len(t, cs.Changesets[0].Deletes, 1) + + deleteVals, err := cs.Changesets[0].DecodeTuple(cs.Changesets[0].Deletes[0]) + require.NoError(t, err) + + // verify the delete vals are equal to the second vals + require.EqualValues(t, c.val2, deleteVals[0]) + require.EqualValues(t, c.arrayVal2, deleteVals[1]) + + err = tx.Commit(ctx) + require.NoError(t, err) +} diff --git a/internal/sql/pg/query.go b/internal/sql/pg/query.go index fdad36b29..eeb79a3b4 100644 --- a/internal/sql/pg/query.go +++ b/internal/sql/pg/query.go @@ -4,15 +4,10 @@ import ( "context" "errors" "fmt" - "math/big" - "reflect" "github.com/kwilteam/kwil-db/common/sql" - "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" ) func queryImpliedArgTypes(ctx context.Context, conn *pgx.Conn, stmt string, args ...any) (pgx.Rows, error) { @@ -42,7 +37,7 @@ optionLoop: } // convert all types to types registered in pgx's type map - args, oids, err := encodeToPGType(args...) + args, oids, err := encodeToPGType(conn.TypeMap(), args...) if err != nil { return nil, fmt.Errorf("encode to pg type failed: %w", err) } @@ -163,7 +158,7 @@ func (cq *cqWrapper) Query(ctx context.Context, sql string, args ...any) (pgx.Ro return cq.c.Query(ctx, sql, args...) } -func query(ctx context.Context, cq connQueryer, stmt string, args ...any) (*sql.ResultSet, error) { +func query(ctx context.Context, oidToDataType map[uint32]*datatype, cq connQueryer, stmt string, args ...any) (*sql.ResultSet, error) { q := cq.Query if mustInferArgs(args) { // return nil, errors.New("cannot use QueryModeInferredArgTypes with query") @@ -200,7 +195,11 @@ func query(ctx context.Context, cq connQueryer, stmt string, args ...any) (*sql. if err != nil { return nil, err } - return decodeFromPGType(pgxVals...) + oids := make([]uint32, len(pgxVals)) + for i, v := range row.FieldDescriptions() { + oids[i] = v.DataTypeOID + } + return decodeFromPG(pgxVals, oids, oidToDataType) }) if errors.Is(err, pgx.ErrNoRows) { return nil, sql.ErrNoRows @@ -215,223 +214,11 @@ func query(ctx context.Context, cq connQueryer, stmt string, args ...any) (*sql. return resSet, err } -// oidArrMap maps oids to their corresponding array oids. -// It only includes types that we care about in Kwil. -var oidArrMap = map[int]int{ - pgtype.BoolOID: pgtype.BoolArrayOID, - pgtype.ByteaOID: pgtype.ByteaArrayOID, - pgtype.Int8OID: pgtype.Int8ArrayOID, - pgtype.TextOID: pgtype.TextArrayOID, - pgtype.UUIDOID: pgtype.UUIDArrayOID, - pgtype.NumericOID: pgtype.NumericArrayOID, -} - -// encodeToPGType encodes several Go types to their corresponding pgx types. -// It is capable of detecting special Kwil types and encoding them to their -// corresponding pgx types. It is only used if using inferred argument types. -// If not using inferred argument types, pgx will rely on the Valuer interface -// to encode the Go types to their corresponding pgx types. -// It also returns the pgx type OIDs for each value. -func encodeToPGType(vals ...any) ([]any, []uint32, error) { - // encodeScalar is a helper function that converts a single value to a pgx. 
- encodeScalar := func(v any) (any, int, error) { - switch v := v.(type) { - case nil: - return nil, pgtype.TextOID, nil - case bool: - return v, pgtype.BoolOID, nil - case int, int8, int16, int32, int64: - return v, pgtype.Int8OID, nil - case string: - return v, pgtype.TextOID, nil - case []byte: - return v, pgtype.ByteaOID, nil - case *types.UUID: - return pgtype.UUID{Bytes: [16]byte(v.Bytes()), Valid: true}, pgtype.UUIDOID, nil - case types.UUID: - return pgtype.UUID{Bytes: [16]byte(v.Bytes()), Valid: true}, pgtype.UUIDOID, nil - case [16]byte: - return pgtype.UUID{Bytes: v, Valid: true}, pgtype.UUIDOID, nil - case decimal.Decimal: - return pgtype.Numeric{ - Int: v.BigInt(), - Exp: v.Exp(), - Valid: true, - }, pgtype.NumericOID, nil - case *decimal.Decimal: - return pgtype.Numeric{ - Int: v.BigInt(), - Exp: v.Exp(), - Valid: true, - }, pgtype.NumericOID, nil - case types.Uint256: - return pgtype.Numeric{ - Int: v.ToBig(), - Exp: 0, - Valid: true, - }, pgtype.NumericOID, nil - case *types.Uint256: - return pgtype.Numeric{ - Int: v.ToBig(), - Exp: 0, - Valid: true, - }, pgtype.NumericOID, nil - } - - return nil, 0, fmt.Errorf("unsupported type: %T", v) - } - - // we convert all types to postgres's type. If the underlying type is an - // array, we will set it as that so that pgx can handle it properly. - // The one exception is []byte, which is handled by pgx as a bytea. - pgxVals := make([]any, len(vals)) - oids := make([]uint32, len(vals)) - for i, val := range vals { - // if nil, we just set it to text. - if val == nil { - pgxVals[i] = nil - oids[i] = pgtype.TextOID - continue - } - - dt := reflect.TypeOf(vals[i]) - if (dt.Kind() == reflect.Slice || dt.Kind() == reflect.Array) && dt.Elem().Kind() != reflect.Uint8 { - valueOf := reflect.ValueOf(val) - arr := make([]any, valueOf.Len()) - var oid int - var err error - for j := 0; j < valueOf.Len(); j++ { - arr[j], oid, err = encodeScalar(valueOf.Index(j).Interface()) - if err != nil { - return nil, nil, err - } - } - pgxVals[i] = arr - - // the oid can be 0 if the array is empty. In that case, we just - // set it to text array, since we cannot infer it from an empty array. - if oid == 0 { - oids[i] = pgtype.TextArrayOID - } else { - oids[i] = uint32(oidArrMap[oid]) - } - } else { - var err error - var oid int - pgxVals[i], oid, err = encodeScalar(val) - if err != nil { - return nil, nil, err - } - oids[i] = uint32(oid) - } - } - - return pgxVals, oids, nil -} - -// decodeFromPGType decodes several pgx types to their corresponding Go types. -// It is capable of detecting special Kwil types and decoding them to their -// corresponding Go types. -func decodeFromPGType(vals ...any) ([]any, error) { - decodeScalar := func(v any) (any, error) { - switch v := v.(type) { - default: - return v, nil - - // we need to handle all ints as int64 since Kwil treats all - // ints as int64, but an integer literal in postgres can get - // returned as an int32 - case int: - return int64(v), nil - case int8: - return int64(v), nil - case int16: - return int64(v), nil - case int32: - return int64(v), nil - case int64: - return v, nil - case uint: - return int64(v), nil - case uint16: - return int64(v), nil - case uint32: - return int64(v), nil - case pgtype.UUID: - u := types.UUID(v.Bytes) - return &u, nil - case [16]byte: - u := types.UUID(v) - return &u, nil - case pgtype.Numeric: - if v.NaN { - return "NaN", nil - } - - // if we give postgres a number 5000, it will return it as 5 with exponent 3. 
- // Since kwil's decimal semantics do not allow negative scale, we need to multiply - // the number by 10^exp to get the correct value. - if v.Exp > 0 { - z := new(big.Int) - z.Exp(big.NewInt(10), big.NewInt(int64(v.Exp)), nil) - z.Mul(z, v.Int) - v.Int = z - v.Exp = 0 - } - - // there is a bit of an edge case here, where uint256 can be returned. - // since most results simply get returned to the user via JSON, it doesn't - // matter too much right now, so we'll leave it as-is. - return decimal.NewFromBigInt(v.Int, v.Exp) - } - } - - goVals := make([]any, len(vals)) - for i, val := range vals { - if val == nil { - goVals[i] = nil - continue - } - - dt := reflect.TypeOf(vals[i]) - if (dt.Kind() == reflect.Slice || dt.Kind() == reflect.Array) && dt.Elem().Kind() != reflect.Uint8 { - // we need to reflect the first type of the slice to determine what type the slice is. - // if empty, we return the slice as is. - valueOf := reflect.ValueOf(val) - - length := valueOf.Len() - if length == 0 { - goVals[i] = val - continue - } - - arr := make([]any, length) - for j := 0; j < length; j++ { - var err error - arr[j], err = decodeScalar(valueOf.Index(j).Interface()) - if err != nil { - return nil, err - } - } - - goVals[i] = arr - } else { - var err error - goVals[i], err = decodeScalar(val) - if err != nil { - return nil, err - } - } - } - - return goVals, nil -} - type txBeginner interface { BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) } -func queryTx(ctx context.Context, dbTx txBeginner, stmt string, args ...any) (*sql.ResultSet, error) { +func queryTx(ctx context.Context, oidToDataType map[uint32]*datatype, dbTx txBeginner, stmt string, args ...any) (*sql.ResultSet, error) { var resSet *sql.ResultSet err := pgx.BeginTxFunc(ctx, dbTx, pgx.TxOptions{ @@ -440,7 +227,7 @@ func queryTx(ctx context.Context, dbTx txBeginner, stmt string, args ...any) (*s }, func(tx pgx.Tx) error { var err error - resSet, err = query(ctx, tx, stmt, args...) + resSet, err = query(ctx, oidToDataType, tx, stmt, args...) return err }, ) diff --git a/internal/sql/pg/repl.go b/internal/sql/pg/repl.go index f9cb8e8f0..4e6b23b63 100644 --- a/internal/sql/pg/repl.go +++ b/internal/sql/pg/repl.go @@ -45,7 +45,7 @@ func replConn(ctx context.Context, host, port, user, pass, dbName string) (*pgco // the context only cancels creation of the connection. Use the quit function to // terminate the monitoring goroutine. func startRepl(ctx context.Context, conn *pgconn.PgConn, publicationName, slotName string, - schemaFilter func(string) bool) (chan []byte, chan error, context.CancelFunc, error) { + schemaFilter func(string) bool, writer *changesetIoWriter) (chan []byte, chan error, context.CancelFunc, error) { // Create the replication slot and start postgres sending WAL data. startLSN, err := createRepl(ctx, conn, publicationName, slotName) if err != nil { @@ -68,7 +68,7 @@ func startRepl(ctx context.Context, conn *pgconn.PgConn, publicationName, slotNa ctx2, cancel := context.WithCancel(context.Background()) go func() { defer close(commitHash) - done <- captureRepl(ctx2, conn, uint64(startLSN), commitHash, schemaFilter) + done <- captureRepl(ctx2, conn, uint64(startLSN), commitHash, schemaFilter, writer) }() return commitHash, done, cancel, nil @@ -132,9 +132,10 @@ var zeroHash, _ = hex.DecodeString("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b // captureRepl begins receiving and decoding messages. Consider the conn to be // hijacked after calling captureRepl, and do not use it or the stream can be -// broken. 
-func captureRepl(ctx context.Context, conn *pgconn.PgConn, startLSN uint64,
-	commitHash chan []byte, schemaFilter func(string) bool) error {
+// broken. If writer is non-nil, the decoded WAL data is also serialized to it
+// as a changeset, in addition to the commit hash.
+func captureRepl(ctx context.Context, conn *pgconn.PgConn, startLSN uint64, commitHash chan []byte,
+	schemaFilter func(string) bool, writer *changesetIoWriter) error {
 	if cap(commitHash) == 0 {
 		return errors.New("buffered commit hash channel required")
 	}
@@ -224,11 +225,14 @@ func captureRepl(ctx context.Context, conn *pgconn.PgConn, startLSN uint64,
 				return fmt.Errorf("ParseXLogData failed: %w", err)
 			}
 
-			commit, anySeq, err := decodeWALData(hasher, xld.WALData, relations, &inStream, stats, schemaFilter)
+			commit, anySeq, err := decodeWALData(hasher, xld.WALData, relations, &inStream, stats, schemaFilter, writer)
 			if err != nil {
 				return fmt.Errorf("decodeWALData failed: %w", err)
 			}
-			if anySeq != -1 {
+			if anySeq != -1 { // the seq update at the beginning of a transaction
+				if seq != -1 {
+					return fmt.Errorf("sequence already set")
+				}
 				seq = anySeq // the magic sentry table UPDATE that precedes commit
 			}
@@ -258,6 +262,7 @@ func captureRepl(ctx context.Context, conn *pgconn.PgConn, startLSN uint64,
 			cid := binary.BigEndian.AppendUint64(nil, uint64(seq))
 			cid = append(cid, cHash...)
+
 			select {
 			case commitHash <- cid:
 			default: // don't block if the receiver has choked
@@ -298,7 +303,7 @@ func (ws *walStats) reset() {
 // true if it was a commit message, or a non-negative seq value if it was a
 // special update message on the internal sentry table
 func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglogrepl.RelationMessageV2,
-	inStream *bool, stats *walStats, okSchema func(schema string) bool) (bool, int64, error) {
+	inStream *bool, stats *walStats, okSchema func(schema string) bool, changesetWriter *changesetIoWriter) (bool, int64, error) {
 	logicalMsg, err := parseV3(walData, *inStream)
 	if err != nil {
 		return false, 0, fmt.Errorf("parse logical replication message: %w", err)
@@ -320,9 +325,9 @@ func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglog
 	// from rolled back transactions.
case *pglogrepl.CommitMessage: - logger.Debugf(" [msg] Commit: Commit LSN %v (%d), End LSN %v (%d), seq = %d", + logger.Debugf(" [msg] Commit: Commit LSN %v (%d), End LSN %v (%d)", logicalMsg.CommitLSN, uint64(logicalMsg.CommitLSN), - logicalMsg.TransactionEndLSN, uint64(logicalMsg.TransactionEndLSN), seq) + logicalMsg.TransactionEndLSN, uint64(logicalMsg.TransactionEndLSN)) done = true @@ -338,6 +343,11 @@ func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglog break } + err = changesetWriter.decodeInsert(logicalMsg, rel) + if err != nil { + return false, 0, fmt.Errorf("decode insert: %w", err) + } + insertData := encodeInsertMsg(relName, &logicalMsg.InsertMessage) // logger.Debugf("insertData %x", insertData) hasher.Write(insertData) @@ -374,6 +384,11 @@ func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglog break } + err = changesetWriter.decodeUpdate(logicalMsg, rel) + if err != nil { + return false, 0, fmt.Errorf("decode update: %w", err) + } + updateData := encodeUpdateMsg(relName, &logicalMsg.UpdateMessage) // logger.Debugf("updateData %x", updateData) hasher.Write(updateData) @@ -399,6 +414,11 @@ func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglog break } + err = changesetWriter.decodeDelete(logicalMsg, rel) + if err != nil { + return false, 0, fmt.Errorf("decode delete: %w", err) + } + deleteData := encodeDeleteMsg(relName, &logicalMsg.DeleteMessage) // logger.Debugf("deleteData %x", deleteData) hasher.Write(deleteData) @@ -454,11 +474,18 @@ func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglog // * msgs: Commit Prepared (NO regular "Commit" message) done = true // there will be a commit or a rollback, but this is the end of the update stream + err = changesetWriter.commit() + if err != nil { + return false, 0, fmt.Errorf("changeset commit error: %w", err) + } + case *CommitPreparedMessageV3: logger.Debugf(" [msg] COMMIT PREPARED TRANSACTION (id %v): Commit LSN %v (%d), End LSN %v (%d) \n", logicalMsg.UserGID, logicalMsg.CommitLSN, uint64(logicalMsg.CommitLSN), logicalMsg.EndCommitLSN, uint64(logicalMsg.EndCommitLSN)) - // done = true + // With a prepared transaction, we're ready for the commit ID and + // changeset once a PREPARE TRANSACTION message is received. This case + // just indicates that the second stage of commit is done. case *RollbackPreparedMessageV3: logger.Debugf(" [msg] ROLLBACK PREPARED TRANSACTION (id %v): Rollback LSN %v (%d), End LSN %v (%d) \n", @@ -466,6 +493,7 @@ func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglog logicalMsg.EndLSN, uint64(logicalMsg.EndLSN)) hasher.Reset() + changesetWriter.fail() // discard changeset // v2 Stream control messages. 
Not expected for kwil case *pglogrepl.StreamStartMessageV2: diff --git a/internal/sql/pg/repl_changeset.go b/internal/sql/pg/repl_changeset.go new file mode 100644 index 000000000..263ced189 --- /dev/null +++ b/internal/sql/pg/repl_changeset.go @@ -0,0 +1,501 @@ +package pg + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pglogrepl" + "github.com/kwilteam/kwil-db/core/types" + "github.com/kwilteam/kwil-db/core/types/serialize" +) + +type changesetIoWriter struct { + metadata *changesetMetadata + oidToType map[uint32]*datatype + + writer io.Writer +} + +var ( + changesetInsertByte = byte(0x01) + changesetUpdateByte = byte(0x02) + changesetDeleteByte = byte(0x03) + changesetMetadataByte = byte(0x04) +) + +// registerMetadata registers a relation with the changeset metadata. +// it returns the index of the relation in the metadata. +func (c *changesetIoWriter) registerMetadata(relation *pglogrepl.RelationMessageV2) uint32 { + idx, ok := c.metadata.relationIdx[[2]string{relation.Namespace, relation.RelationName}] + if ok { + return uint32(idx) + } + + c.metadata.relationIdx[[2]string{relation.Namespace, relation.RelationName}] = len(c.metadata.Relations) + rel := &Relation{ + Schema: relation.Namespace, + Name: relation.RelationName, + Cols: make([]*Column, len(relation.Columns)), + } + + for i, col := range relation.Columns { + dt, ok := c.oidToType[col.DataType] + if !ok { + panic(fmt.Sprintf("unknown data type OID %d", col.DataType)) + } + + rel.Cols[i] = &Column{ + Name: col.Name, + Type: dt.KwilType, + } + } + + c.metadata.Relations = append(c.metadata.Relations, rel) + + return uint32(len(c.metadata.Relations) - 1) +} + +func (c *changesetIoWriter) decodeInsert(insert *pglogrepl.InsertMessageV2, relation *pglogrepl.RelationMessageV2) error { + if c == nil || c.writer == nil { // !c.writable.Load() + return nil + } + + idx := c.registerMetadata(relation) + + tup, err := convertPgxTuple(insert.Tuple, relation, c.oidToType) + if err != nil { + return err + } + tup.RelationIdx = idx + + bts, err := tup.serialize() + if err != nil { + return err + } + + _, err = c.writer.Write(append([]byte{changesetInsertByte}, bts...)) + // c.data = append(c.data, append([]byte{changesetInsertByte}, bts...)...) + return err +} + +func (c *changesetIoWriter) decodeUpdate(update *pglogrepl.UpdateMessageV2, relation *pglogrepl.RelationMessageV2) error { + if c == nil || c.writer == nil { + return nil + } + + idx := c.registerMetadata(relation) + + tup, err := convertPgxTuple(update.OldTuple, relation, c.oidToType) + if err != nil { + return err + } + tup.RelationIdx = idx + + bts, err := tup.serialize() + if err != nil { + return err + } + + tup, err = convertPgxTuple(update.NewTuple, relation, c.oidToType) + if err != nil { + return err + } + tup.RelationIdx = idx + + bts2, err := tup.serialize() + if err != nil { + return err + } + + _, err = c.writer.Write(append([]byte{changesetUpdateByte}, append(bts, bts2...)...)) + // c.data = append(c.data, append([]byte{changesetUpdateByte}, append(bts, bts2...)...)...) 
+	return err
+}
+
+func (c *changesetIoWriter) decodeDelete(delete *pglogrepl.DeleteMessageV2, relation *pglogrepl.RelationMessageV2) error {
+	if c == nil || c.writer == nil {
+		return nil
+	}
+
+	idx := c.registerMetadata(relation)
+
+	tup, err := convertPgxTuple(delete.OldTuple, relation, c.oidToType)
+	if err != nil {
+		return err
+	}
+	tup.RelationIdx = idx
+
+	bts, err := tup.serialize()
+	if err != nil {
+		return err
+	}
+
+	_, err = c.writer.Write(append([]byte{changesetDeleteByte}, bts...))
+	// c.data = append(c.data, append([]byte{changesetDeleteByte}, bts...)...)
+	return err
+}
+
+// commit is called when the changeset is complete.
+// It exports the metadata to the writer; the metadata record doubles as the
+// end-of-changeset marker. It then zeroes the metadata and clears the writer
+// so that the next changeset can be collected.
+func (c *changesetIoWriter) commit() error {
+	if c == nil || c.writer == nil {
+		return nil
+	}
+
+	bts, err := c.metadata.serialize()
+	if err != nil {
+		return err
+	}
+
+	_, err = c.writer.Write(append([]byte{changesetMetadataByte}, bts...))
+	// c.data = append(c.data, append([]byte{changesetMetadataByte}, bts...)...)
+
+	c.metadata = &changesetMetadata{
+		relationIdx: map[[2]string]int{},
+	}
+	c.writer = nil
+
+	return err
+}
+
+// fail is called when the changeset is incomplete.
+// It zeroes the metadata and writer, so that another changeset may be collected.
+func (c *changesetIoWriter) fail() {
+	// if !c.writable.Load() {
+	// 	return
+	// }
+
+	c.metadata = &changesetMetadata{
+		relationIdx: map[[2]string]int{},
+	}
+	c.writer = nil
+}
+
+// ChangesetGroup is a group of changesets.
+type ChangesetGroup struct {
+	// Changesets is a list of changesets, as they were
+	// encountered in the WAL stream.
+	// It is meant to be RLP encoded.
+	Changesets []*Changeset
+}
+
+// convertPgxTuple converts a pgx TupleData to a Tuple.
+func convertPgxTuple(pgxTuple *pglogrepl.TupleData, relation *pglogrepl.RelationMessageV2, oidToType map[uint32]*datatype) (*Tuple, error) {
+	tuple := &Tuple{
+		Columns: make([]*TupleColumn, len(pgxTuple.Columns)),
+	}
+
+	for i, col := range pgxTuple.Columns {
+		tupleCol := &TupleColumn{}
+
+		dataType, ok := oidToType[relation.Columns[i].DataType]
+		if !ok {
+			return nil, fmt.Errorf("unknown data type OID %d", relation.Columns[i].DataType)
+		}
+
+		switch col.DataType {
+		case pglogrepl.TupleDataTypeText:
+			tupleCol.ValueType = SerializedValue
+			encoded, err := dataType.SerializeChangeset(string(col.Data))
+			if err != nil {
+				return nil, err
+			}
+
+			tupleCol.Data = encoded
+		case pglogrepl.TupleDataTypeBinary:
+			panic("per pgx docs, we should never actually get this type")
+		case pglogrepl.TupleDataTypeNull:
+			tupleCol.ValueType = NullValue
+		case pglogrepl.TupleDataTypeToast:
+			tupleCol.ValueType = ToastValue
+		default:
+			panic(fmt.Sprintf("unknown tuple data type %d", col.DataType))
+		}
+
+		tuple.Columns[i] = tupleCol
+	}
+
+	return tuple, nil
+}
+
+// MarshalBinary implements the encoding.BinaryMarshaler interface.
+// It serializes the ChangesetGroup using RLP. We could probably make
+// a custom encoding format that is faster and more compact, but for this
+// initial implementation, we will use RLP.
+func (c *ChangesetGroup) MarshalBinary() ([]byte, error) {
+	return serialize.Encode(c.Changesets)
+}
+
+// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
+func (c *ChangesetGroup) UnmarshalBinary(data []byte) error {
+	err := serialize.Decode(data, &c.Changesets)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// DeserializeChangeset deserializes a changeset from a serialized changeset stream.
+func DeserializeChangeset(data []byte) (*ChangesetGroup, error) { // todo: convert to io.Reader
+	var inserts []*Tuple
+	var updates [][2]*Tuple
+	var deletes []*Tuple
+	metadata := &changesetMetadata{}
+	var err error
+	for {
+		switch data[0] {
+		case changesetInsertByte:
+			tup := &Tuple{}
+			data, err = tup.deserialize(data[1:])
+			if err != nil {
+				return nil, err
+			}
+			inserts = append(inserts, tup)
+		case changesetUpdateByte:
+			tup1 := &Tuple{}
+			data, err = tup1.deserialize(data[1:])
+			if err != nil {
+				return nil, err
+			}
+
+			tup2 := &Tuple{}
+			data, err = tup2.deserialize(data)
+			if err != nil {
+				return nil, err
+			}
+
+			updates = append(updates, [2]*Tuple{tup1, tup2})
+		case changesetDeleteByte:
+			tup := &Tuple{}
+			data, err = tup.deserialize(data[1:])
+			if err != nil {
+				return nil, err
+			}
+			deletes = append(deletes, tup)
+		case changesetMetadataByte:
+			data, err = metadata.deserialize(data[1:])
+			if err != nil {
+				return nil, err
+			}
+
+			// this is the end of the changeset
+			if len(data) != 0 {
+				return nil, fmt.Errorf("unexpected data after metadata: %v", data)
+			}
+		default:
+			return nil, fmt.Errorf("unknown changeset byte %d", data[0])
+		}
+
+		if len(data) == 0 {
+			break
+		}
+	}
+
+	group := &ChangesetGroup{
+		Changesets: make([]*Changeset, len(metadata.Relations)),
+	}
+
+	for i, rel := range metadata.Relations {
+		group.Changesets[i] = &Changeset{
+			Schema:  rel.Schema,
+			Table:   rel.Name,
+			Columns: rel.Cols,
+		}
+	}
+
+	for _, tup := range inserts {
+		group.Changesets[tup.RelationIdx].Inserts = append(group.Changesets[tup.RelationIdx].Inserts, tup)
+	}
+
+	for _, tup := range updates {
+		group.Changesets[tup[0].RelationIdx].Updates = append(group.Changesets[tup[0].RelationIdx].Updates, tup)
+	}
+
+	for _, tup := range deletes {
+		group.Changesets[tup.RelationIdx].Deletes = append(group.Changesets[tup.RelationIdx].Deletes, tup)
+	}
+
+	return group, nil
+}
+
+// Changeset is a set of changes to a table.
+// It is meant to be RLP encoded, to be compact and easy to send over the wire,
+// while also being deterministic. It is meant to translate a lot of the internal
+// implementation details into changesets that are understood by higher-level
+// Kwil components.
+type Changeset struct {
+	// Schema is the PostgreSQL schema name.
+	Schema string
+	// Table is the name of the table.
+	Table string
+	// Columns is a list of column names and their types.
+	Columns []*Column
+	// Inserts is a list of tuples to insert.
+	Inserts []*Tuple
+	// Updates is a list of tuple pairs to update.
+	// The first tuple is the old tuple, the second is the new tuple.
+	Updates [][2]*Tuple
+	// Deletes is a list of tuples to delete.
+	// It is the values of each tuple before it was deleted.
+	Deletes []*Tuple
+}
+
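Reading a stream back is symmetric. A short sketch, not part of the patch, of consuming the deserialized group inside this package, where raw is whatever was written to the io.Writer passed to Precommit:

```go
// summarizeChangesets decodes a changeset stream and prints per-table counts
// plus the native values of each inserted tuple.
func summarizeChangesets(raw []byte) error {
	group, err := DeserializeChangeset(raw)
	if err != nil {
		return err
	}
	for _, cs := range group.Changesets {
		fmt.Printf("%s.%s: %d inserts, %d updates, %d deletes\n",
			cs.Schema, cs.Table, len(cs.Inserts), len(cs.Updates), len(cs.Deletes))
		for _, tup := range cs.Inserts {
			vals, err := cs.DecodeTuple(tup) // nil entries for NULL or TOAST columns
			if err != nil {
				return err
			}
			fmt.Println(vals...)
		}
	}
	return nil
}
```

+// DecodeTuple decodes serialized tuple column values into their native types.
+// Any value may be nil, depending on the ValueType.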
+func (c *Changeset) DecodeTuple(tuple *Tuple) ([]any, error) {
+	values := make([]any, len(tuple.Columns))
+	for i, col := range tuple.Columns {
+		switch col.ValueType {
+		case NullValue:
+			values[i] = nil
+		case ToastValue:
+			values[i] = nil
+		case SerializedValue:
+			dt, ok := kwilTypeToDataType[*c.Columns[i].Type]
+			if !ok {
+				return nil, fmt.Errorf("unknown data type %s", c.Columns[i].Type)
+			}
+			val, err := dt.DeserializeChangeset(col.Data)
+			if err != nil {
+				return nil, err
+			}
+
+			values[i] = val
+		}
+	}
+
+	return values, nil
+}
+
+// changesetMetadata contains metadata about a changeset.
+type changesetMetadata struct {
+	// Relations lists the schema and table name of each relation in the
+	// changeset. A relation's position in this list is the index that
+	// tuples use to refer back to it.
+	Relations []*Relation
+	// relationIdx is a map of relations, indexed by the schema and table name.
+	// It points to the index of the relation in the Relations list.
+	relationIdx map[[2]string]int
+}
+
+// serialize serializes the metadata with the length of the serialized data
+// as a 4-byte prefix.
+func (m *changesetMetadata) serialize() ([]byte, error) {
+	bts, err := serialize.Encode(m.Relations)
+	if err != nil {
+		return nil, err
+	}
+
+	buf := make([]byte, 4, 4+len(bts))
+	binary.LittleEndian.PutUint32(buf, uint32(len(bts)))
+
+	return append(buf, bts...), nil
+}
+
+// deserialize deserializes the metadata.
+// It returns the remaining data after the metadata.
+func (m *changesetMetadata) deserialize(data []byte) ([]byte, error) {
+	if len(data) < 4 {
+		return nil, fmt.Errorf("data too short")
+	}
+
+	size := binary.LittleEndian.Uint32(data[:4])
+	if len(data) < int(size)+4 {
+		return nil, fmt.Errorf("data too short")
+	}
+
+	err := serialize.Decode(data[4:4+size], &m.Relations)
+	if err != nil {
+		return nil, err
+	}
+
+	m.relationIdx = make(map[[2]string]int)
+	for i, rel := range m.Relations {
+		m.relationIdx[[2]string{rel.Schema, rel.Name}] = i
+	}
+
+	return data[4+size:], nil
+}
+
+// Relation is a table in a schema.
+type Relation struct {
+	Schema string
+	Name   string
+	Cols   []*Column
+}
+
+// Column is a column name and type.
+type Column struct {
+	Name string
+	Type *types.DataType
+}
+
+// Tuple is a tuple of values.
+type Tuple struct {
+	// RelationIdx is the index of the relation in the changeset metadata struct.
+	RelationIdx uint32
+	// Columns is a list of columns and their values.
+	Columns []*TupleColumn
+}
+
+// serialize serializes the tuple with the length of the serialized data
+// as a 4-byte prefix.
+func (t *Tuple) serialize() ([]byte, error) {
+	bts, err := serialize.Encode(t)
+	if err != nil {
+		return nil, err
+	}
+
+	buf := make([]byte, 4, 4+(len(bts)))
+	binary.LittleEndian.PutUint32(buf, uint32(len(bts)))
+
+	return append(buf, bts...), nil
+}
+
+// deserialize deserializes the tuple.
+// It returns the remaining data after the tuple.
+func (t *Tuple) deserialize(data []byte) ([]byte, error) {
+	if len(data) < 4 {
+		return nil, fmt.Errorf("data too short")
+	}
+
+	size := binary.LittleEndian.Uint32(data[:4])
+	if len(data) < int(size)+4 {
+		return nil, fmt.Errorf("data too short")
+	}
+
+	err := serialize.Decode(data[4:4+size], &t)
+	if err != nil {
+		return nil, err
+	}
+
+	return data[4+size:], nil
+}
+
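Both tuples and metadata travel in the same frame: a one-byte record tag followed by a payload that carries its own 4-byte little-endian length prefix. A tiny sketch of the tag step; the helper name is hypothetical, since the writers above inline this:

```go
// frameRecord prepends the record tag to an already length-prefixed payload,
// mirroring what decodeInsert/decodeUpdate/decodeDelete and commit write.
func frameRecord(tag byte, lengthPrefixedPayload []byte) []byte {
	out := make([]byte, 0, 1+len(lengthPrefixedPayload))
	out = append(out, tag) // 0x01 insert, 0x02 update, 0x03 delete, 0x04 metadata
	return append(out, lengthPrefixedPayload...)
}
```

+// TupleColumn is a column within a tuple.
+type TupleColumn struct {
+	// ValueType gives information on the type of data in the column.
+	// If the type is of type Null or Toast, the Data field will be nil.
+	ValueType ValueType
+	// Data is the actual data in the column.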
+	Data []byte
+}
+
+// ValueType gives information on the type of data in a tuple column.
+type ValueType uint8
+
+const (
+	// NullValue indicates a NULL value
+	// (as opposed to something like an empty string).
+	NullValue ValueType = iota
+	// ToastValue indicates a column is a TOAST pointer,
+	// and that the actual value is stored elsewhere and
+	// was unchanged.
+	ToastValue
+	// SerializedValue indicates a column is a non-nil value
+	// and can be deserialized.
+	SerializedValue
+)
diff --git a/internal/sql/pg/repl_test.go b/internal/sql/pg/repl_test.go
index 7e52fe91b..d5650f79e 100644
--- a/internal/sql/pg/repl_test.go
+++ b/internal/sql/pg/repl_test.go
@@ -57,7 +57,7 @@ func Test_repl(t *testing.T) {
 	const publicationName = "kwild_repl"
 	var slotName = publicationName + random.String(8)
 
-	commitChan, errChan, quit, err := startRepl(ctx, conn, publicationName, slotName, schemaFilter)
+	commitChan, errChan, quit, err := startRepl(ctx, conn, publicationName, slotName, schemaFilter, &changesetIoWriter{})
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -74,7 +74,7 @@ func Test_repl(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	wantCommitHash, _ := hex.DecodeString("cb390afbf808256307ee0927999805ee3d5af193772e2c9b71823fbc1fe8867f")
+	wantCommitHash, _ := hex.DecodeString("1fd01e9d38e322285723643f27123762c3d7fd22761f7ab4de57e0a947f8127b")
 
 	var wg sync.WaitGroup
 	wg.Add(1)
diff --git a/internal/sql/pg/replmon.go b/internal/sql/pg/replmon.go
index 4aba403bd..33450ebde 100644
--- a/internal/sql/pg/replmon.go
+++ b/internal/sql/pg/replmon.go
@@ -12,6 +12,7 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
+	"io"
 	"sync"
 
 	"github.com/jackc/pgx/v5/pgconn"
@@ -46,22 +47,39 @@ type replMon struct {
 	err error // specific error, safe to read after done is closed
 
 	mtx sync.Mutex
-	results map[int64][]byte // results should generally be unused as pg.DB will request a promise before commit
 	promises map[int64]chan []byte
+	// the map above was kept to support multiple concurrent write txns, but
+	// that never happens: one replMon is only used by one pg.DB, and pg.DB
+	// disallows multiple outer write transactions. consider just making this
+	// a chan field
+
+	// changesetWriters map[int64]io.Writer // maps the sequence number to the changeset writer
+	changesetWriter *changesetIoWriter
 }
 
 // newReplMon creates a new connection and logical replication data monitor, and
 // immediately starts receiving messages from the host. A consumer should
 // request a commit ID promise using the recvID method prior to committing a
 // transaction.
-func newReplMon(ctx context.Context, host, port, user, pass, dbName string, schemaFilter func(string) bool) (*replMon, error) {
+func newReplMon(ctx context.Context, host, port, user, pass, dbName string, schemaFilter func(string) bool,
+	oidToTypes map[uint32]*datatype) (*replMon, error) {
 	conn, err := replConn(ctx, host, port, user, pass, dbName)
 	if err != nil {
 		return nil, err
 	}
 
+	// we set the changeset io.Writer to nil, as the changesetIoWriter will skip all writes
+	// until enabled by setting the atomic.Bool to true.
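+	// The writer itself is installed later, when recvID hands the
+	// changesetIoWriter the io.Writer for the next commit.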
+	cs := &changesetIoWriter{
+		metadata: &changesetMetadata{
+			relationIdx: map[[2]string]int{},
+		},
+		oidToType: oidToTypes,
+		// writer is nil here; the caller sets it prior to preparing txns, and writes are ignored if it is left nil
+	}
+
 	var slotName = publicationName + random.String(8) // arbitrary, so just avoid collisions
 
-	commitChan, errChan, quit, err := startRepl(ctx, conn, publicationName, slotName, schemaFilter)
+	commitChan, errChan, quit, err := startRepl(ctx, conn, publicationName, slotName, schemaFilter, cs)
 	if err != nil {
 		quit()
 		conn.Close(ctx)
@@ -69,11 +87,11 @@ func newReplMon(ctx context.Context, host, port, user, pass, dbName string, sche
 	}
 
 	rm := &replMon{
-		conn:     conn,
-		quit:     quit,
-		done:     make(chan struct{}),
-		results:  make(map[int64][]byte),
-		promises: make(map[int64]chan []byte),
+		conn:            conn,
+		quit:            quit,
+		done:            make(chan struct{}),
+		promises:        make(map[int64]chan []byte),
+		changesetWriter: cs,
 	}
 
 	go func() {
@@ -97,7 +115,6 @@ func newReplMon(ctx context.Context, host, port, user, pass, dbName string, sche
 				// This is unexpected since pg.DB will call recvID first. If we are
 				// in this `else`, it is to be discarded, from another connection.
 				logger.Warnf("Received commit ID for seq %d BEFORE recvID", seq)
 			}
 			rm.mtx.Unlock()
 		}
@@ -113,7 +130,7 @@ func newReplMon(ctx context.Context, host, port, user, pass, dbName string, sche
 
 // this channel-based approach is so that the commit ID is guaranteed to pertain
 // to the requested sequence number.
-func (rm *replMon) recvID(seq int64) (chan []byte, bool) {
+func (rm *replMon) recvID(seq int64, w io.Writer) (chan []byte, bool) {
 	// Ensure a commit ID can be promised before we give one.
 	select {
 	case <-rm.done:
@@ -123,22 +140,14 @@
 
 	c := make(chan []byte, 1)
 
-	// first check if the results is already in the map, otherwise make the
-	// promise and store it
 	rm.mtx.Lock()
 	defer rm.mtx.Unlock()
 
-	if cHash, ok := rm.results[seq]; ok {
-		// The intended use is to do recvID BEFORE
-		logger.Warnf("recvID with EXISTING result for sequence %d", seq)
-		delete(rm.results, seq)
-		c <- cHash
-		return c, true
-	}
-
 	if _, have := rm.promises[seq]; have {
-		logger.Errorf("Commit ID promise for sequence %d ALREADY EXISTS", seq)
+		panic(fmt.Sprintf("Commit ID promise for sequence %d ALREADY EXISTS", seq))
 	}
-	rm.promises[seq] = c // maybe panic if one already exists, indicating program logic error
+	rm.promises[seq] = c
+
+	rm.changesetWriter.writer = w // could be a map write, just starting simple
 
 	return c, true
 }
diff --git a/internal/sql/pg/sql.go b/internal/sql/pg/sql.go
index eccfbf543..84fc1bb27 100644
--- a/internal/sql/pg/sql.go
+++ b/internal/sql/pg/sql.go
@@ -2,7 +2,6 @@ package pg
 
 import (
 	"context"
-	_ "embed"
 	"errors"
 	"fmt"
 	"time"
@@ -16,32 +15,17 @@ var (
 	// table's "replication identity" is explicitly set to "full". We ensure
 	// that is the case by creating an event trigger to perform the ALTER TABLE
 	// command whenever a DDL command with the "CREATE TABLE" tag is processed
-	// for a table with neither a primary key or unique index. These are the
-	// embedded plpgsql functions below.
+	// for a table with neither a primary key nor a unique index. We now do
+	// this even for tables that do have a primary key or unique index, so that
+	// changesets capture the old values of updated and deleted rows, not just
+	// their keys.
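+	// For example (a hypothetical session, with an illustrative table name),
+	// once the event trigger is installed:
+	//
+	//   CREATE TABLE ds_test.foo (id INT8 PRIMARY KEY, val TEXT);
+	//   -- the trigger then executes, roughly:
+	//   --   ALTER TABLE ds_test.foo REPLICA IDENTITY FULL;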
-	//go:embed trigger_repl1.sql
-	sqlFuncReplIfNeeded string
-
-	//go:embed trigger_repl2.sql
-	sqlFuncReplIfNeeded2 string //nolint:unused
-	// (I'm still deciding which to use)
-
-	//nolint:unused
-	sqlCreateFuncReplIdentExists = `SELECT EXISTS (
-		SELECT 1 FROM pg_proc
-		WHERE proname = 'set_replica_identity_full'
-	);`
-	// (replace might be brute; this checks if the repl trigger created in sqlFuncReplIfNeeded exists)
-
-	sqlCreateEvtTriggerReplIdentExists = `SELECT EXISTS (
-		SELECT 1 FROM pg_event_trigger
-		WHERE evtname = 'trg_set_replica_identity_full'
-	);`
-
-	sqlCreateEvtTriggerReplIdent = `CREATE EVENT TRIGGER trg_set_replica_identity_full ON ddl_command_end
+	sqlCreateEvtTriggerReplIdent = `CREATE EVENT TRIGGER set_replica_identity_on_create
+	ON ddl_command_end
 	WHEN TAG IN ('CREATE TABLE')
-	EXECUTE FUNCTION set_replica_identity_full();`
-	// TIP for node reset/cleanup: DROP EVENT TRIGGER IF EXISTS trg_set_replica_identity_full;
+	EXECUTE FUNCTION set_replica_identity();`
+
+	sqlDropEvtTriggerReplIdent = `DROP EVENT TRIGGER IF EXISTS set_replica_identity_on_create;`
 
 	sqlCreatePublicationINE = `DO $$
 BEGIN
@@ -77,6 +61,21 @@ END$$;`
 EXCEPTION WHEN duplicate_object THEN null;
 END $$;`
+
+	sqlCreateOrReplaceReplicaIdentity = `CREATE OR REPLACE FUNCTION set_replica_identity()
+RETURNS event_trigger
+LANGUAGE plpgsql
+AS $$
+DECLARE
+    obj record;
+BEGIN
+    FOR obj IN
+        SELECT * FROM pg_event_trigger_ddl_commands() WHERE command_tag = 'CREATE TABLE'
+    LOOP
+        EXECUTE 'ALTER TABLE ' || obj.object_identity || ' REPLICA IDENTITY FULL';
+    END LOOP;
+END;
+$$;`
 )
 
 func checkSuperuser(ctx context.Context, conn *pgx.Conn) error {
@@ -188,22 +187,22 @@ func tableExists(ctx context.Context, schema, table string, conn *pgx.Conn) (boo
 	return pgx.CollectExactlyOneRow(rows, pgx.RowTo[bool])
 }
 
-func ensureTriggerReplIdentity(ctx context.Context, conn *pgx.Conn) error {
-	// First create the function if needed.
-	_, err := conn.Exec(ctx, sqlFuncReplIfNeeded)
+// ensureFullReplicaIdentityTrigger creates an event trigger to set the replica
+// identity to "full" for all tables that are created.
+func ensureFullReplicaIdentityTrigger(ctx context.Context, conn *pgx.Conn) error {
+	// Create the function for the event trigger.
+	_, err := conn.Exec(ctx, sqlCreateOrReplaceReplicaIdentity)
 	if err != nil {
 		return err
 	}
 
-	// Create the trigger for the function if needed.
-	rows, _ := conn.Query(ctx, sqlCreateEvtTriggerReplIdentExists)
-	triggerExists, err := pgx.CollectExactlyOneRow(rows, pgx.RowTo[bool])
+	// Create the event trigger that calls the function.
+	// Always drop it first in case we update the logic; new nodes will then
+	// automatically get the new version.
+	_, err = conn.Exec(ctx, sqlDropEvtTriggerReplIdent)
 	if err != nil {
 		return err
 	}
 
-	if triggerExists {
-		return nil
-	}
+	_, err = conn.Exec(ctx, sqlCreateEvtTriggerReplIdent)
 
 	return err
 }
diff --git a/internal/sql/pg/trigger_repl1.sql b/internal/sql/pg/trigger_repl1.sql
deleted file mode 100644
index 367809e05..000000000
--- a/internal/sql/pg/trigger_repl1.sql
+++ /dev/null
@@ -1,34 +0,0 @@
-CREATE OR REPLACE FUNCTION set_replica_identity_full()
-RETURNS event_trigger AS $$
-DECLARE
-    obj record;
-    has_primary_key boolean;
-    has_unique_index boolean;
-BEGIN
-    FOR obj IN SELECT * FROM pg_event_trigger_ddl_commands()
-               WHERE command_tag = 'CREATE TABLE' AND object_type = 'table'
-    LOOP
-        SELECT EXISTS (
-            SELECT 1
-            FROM information_schema.table_constraints
-            WHERE table_schema || '.' 
|| table_name = obj.object_identity - AND constraint_type = 'PRIMARY KEY' - ) INTO has_primary_key; - - SELECT EXISTS ( - SELECT 1 - FROM pg_indexes - WHERE schemaname || '.' || tablename = obj.object_identity - AND indexdef LIKE '%UNIQUE%' - ) INTO has_unique_index; - - -- alter table only if there is (no primary key) and (no unique index) - IF NOT has_primary_key AND NOT has_unique_index THEN - EXECUTE 'ALTER TABLE ' || obj.object_identity || ' REPLICA IDENTITY FULL'; -- note that object_identity is schema qualified - RAISE NOTICE 'Altered table: % to set REPLICA IDENTITY FULL.', obj.object_identity; - ELSE - RAISE NOTICE 'Table: % already has a primary key or unique index.', obj.object_identity; - END IF; - END LOOP; -END; -$$ LANGUAGE plpgsql; diff --git a/internal/sql/pg/trigger_repl2.sql b/internal/sql/pg/trigger_repl2.sql deleted file mode 100644 index 65a50e93c..000000000 --- a/internal/sql/pg/trigger_repl2.sql +++ /dev/null @@ -1,28 +0,0 @@ -CREATE OR REPLACE FUNCTION set_replica_identity_full() -RETURNS event_trigger AS $$ -DECLARE - obj record; - has_key_or_index boolean; -BEGIN - FOR obj IN SELECT * FROM pg_event_trigger_ddl_commands() - WHERE command_tag = 'CREATE TABLE' AND object_type = 'table' - LOOP - SELECT EXISTS ( - SELECT 1 - FROM pg_class c - JOIN pg_namespace n ON c.relnamespace = n.oid - LEFT JOIN pg_index i ON c.oid = i.indrelid - WHERE n.nspname || '.' || c.relname = obj.object_identity - AND (i.indisprimary OR i.indisunique) - ) INTO has_key_or_index; - - -- alter table only if there is (no primary key) and (no unique index) - IF NOT has_key_or_index THEN - EXECUTE 'ALTER TABLE ' || obj.object_identity || ' REPLICA IDENTITY FULL'; - RAISE NOTICE 'Altered table: % to set REPLICA IDENTITY FULL.', obj.object_identity; - ELSE - RAISE NOTICE 'Table: % already has a primary key or unique index.', obj.object_identity; - END IF; - END LOOP; -END; -$$ LANGUAGE plpgsql; diff --git a/internal/sql/pg/tx.go b/internal/sql/pg/tx.go index aa50c279d..63a6ba601 100644 --- a/internal/sql/pg/tx.go +++ b/internal/sql/pg/tx.go @@ -4,6 +4,7 @@ package pg import ( "context" + "io" "github.com/jackc/pgx/v5" common "github.com/kwilteam/kwil-db/common/sql" @@ -19,6 +20,7 @@ type releaser interface { type nestedTx struct { pgx.Tx accessMode common.AccessMode + oidTypes map[uint32]*datatype } var _ common.Tx = (*nestedTx)(nil) @@ -36,17 +38,18 @@ func (tx *nestedTx) BeginTx(ctx context.Context) (common.Tx, error) { return &nestedTx{ Tx: pgtx, accessMode: tx.accessMode, + oidTypes: tx.oidTypes, }, nil } func (tx *nestedTx) Query(ctx context.Context, stmt string, args ...any) (*common.ResultSet, error) { - return query(ctx, tx.Tx, stmt, args...) + return query(ctx, tx.oidTypes, tx.Tx, stmt, args...) } // Execute is now literally identical to Query in both semantics and syntax. We // might remove one or the other in this context (transaction methods). func (tx *nestedTx) Execute(ctx context.Context, stmt string, args ...any) (*common.ResultSet, error) { - return query(ctx, tx.Tx, stmt, args...) + return query(ctx, tx.oidTypes, tx.Tx, stmt, args...) } // AccessMode returns the access mode of the transaction. @@ -66,9 +69,10 @@ type dbTx struct { // Precommit creates a prepared transaction for a two-phase commit. An ID // derived from the updates is return. This must be called before Commit. Either -// Commit or Rollback must follow. -func (tx *dbTx) Precommit(ctx context.Context) ([]byte, error) { - return tx.db.precommit(ctx) +// Commit or Rollback must follow. 
It takes a writer to write the full changeset to.
+// If the writer is nil, the changeset will not be written.
+func (tx *dbTx) Precommit(ctx context.Context, writer io.Writer) ([]byte, error) {
+	return tx.db.precommit(ctx, writer)
 }
 
 // Commit commits the transaction. This partly satisfies sql.Tx.
diff --git a/internal/sql/pg/types.go b/internal/sql/pg/types.go
new file mode 100644
index 000000000..40aea1b73
--- /dev/null
+++ b/internal/sql/pg/types.go
@@ -0,0 +1,936 @@
+package pg
+
+import (
+	"encoding/binary"
+	"encoding/hex"
+	"fmt"
+	"math/big"
+	"reflect"
+	"strconv"
+	"strings"
+
+	"github.com/jackc/pgx/v5/pgtype"
+	"github.com/kwilteam/kwil-db/core/types"
+	"github.com/kwilteam/kwil-db/core/types/decimal"
+)
+
+func init() {
+	registerDatatype(textType, textArrayType)
+	registerDatatype(intType, intArrayType)
+	registerDatatype(boolType, boolArrayType)
+	registerDatatype(blobType, blobArrayType)
+	registerDatatype(uuidType, uuidArrayType)
+	registerDatatype(decimalType, decimalArrayType)
+	registerDatatype(uint256Type, uint256ArrayType)
+}
+
+var (
+	dataTypesByMatch   = map[reflect.Type]*datatype{}
+	scalarToArray      = map[*datatype]*datatype{} // maps the scalar type to the array type
+	datatypes          = map[*datatype]struct{}{}  // a set of all data types (used for iteration)
+	kwilTypeToDataType = map[types.DataType]*datatype{}
+)
+
+// registerDatatype registers a scalar data type and its corresponding array
+// type for use in Postgres.
+func registerDatatype(scalar *datatype, array *datatype) {
+	for _, match := range scalar.Matches {
+		_, ok := dataTypesByMatch[match]
+		if ok {
+			panic(fmt.Sprintf("data type %T already registered", match))
+		}
+
+		dataTypesByMatch[match] = scalar
+		datatypes[scalar] = struct{}{}
+	}
+
+	for _, match := range array.Matches {
+		_, ok := dataTypesByMatch[match]
+		if ok {
+			panic(fmt.Sprintf("data type %T already registered", match))
+		}
+
+		dataTypesByMatch[match] = array
+		datatypes[array] = struct{}{}
+	}
+
+	_, ok := kwilTypeToDataType[*scalar.KwilType]
+	if ok {
+		panic(fmt.Sprintf("Kwil type %s already registered", scalar.KwilType.String()))
+	}
+
+	kwilTypeToDataType[*scalar.KwilType] = scalar
+
+	_, ok = kwilTypeToDataType[*array.KwilType]
+	if ok {
+		panic(fmt.Sprintf("Kwil type %s already registered", array.KwilType.String()))
+	}
+
+	kwilTypeToDataType[*array.KwilType] = array
+
+	scalarToArray[scalar] = array
+}
+
+// datatype allows us to easily register new data types.
+// It is used to define how to encode and decode data types in Postgres.
+// While all of the implementations for this are stored in the pg package,
+// the primary reason for factoring this out as its own type is to make it
+// clear what needs to be implemented to support new data types in the future.
+type datatype struct {
+	// KwilType is the Kwil-native data type that is tied to this data type.
+	// There must be exactly one. It will ignore all metadata (e.g. for decimal, any
+	// precision/scale is ignored).
+	KwilType *types.DataType
+	// Matches is the list of all data types that this type matches.
+	// These will be stored in a map, and thus each match type can only be
+	// used once across all data types.
+	Matches []reflect.Type
+	// OID returns the OID of the data type in Postgres.
+	// It will be given to Postgres when encoding the data type
+	// with QueryModeInferredArgTypes, and will also be used to identify
+	// how values should be decoded.
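+	// For built-in types this is a fixed pgtype OID (e.g. pgtype.TextOID for
+	// text), while domain types such as uint256 look their OID up in the
+	// provided type map.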
+	OID func(*pgtype.Map) uint32
+	// ExtraOIDs returns any additional OIDs which the data type can be decoded from.
+	// This is useful for int types, which can be decoded from int2, int4, and int8.
+	// These will be used in addition to the OID returned by OID().
+	// This can be nil if there are no additional OIDs.
+	ExtraOIDs []uint32
+	// EncodeInferred encodes a value into a byte slice, given the type of the value.
+	// The passed value will always be of a type that matches one of the Matches types.
+	// It must return the serialized data.
+	// This is used when operating in QueryModeInferredArgTypes, to infer the postgres
+	// data type from the native go type.
+	// If not using QueryModeInferredArgTypes, it will be encoded using a driver.Valuer,
+	// or as a native go type.
+	EncodeInferred func(any) (any, error)
+	// Decode decodes a data type received from Postgres. The input will either be a data type
+	// native to Go, a type defined in pgx, or a type in a custom pgx Codec (which we currently
+	// don't use).
+	Decode func(any) (any, error)
+	// SerializeChangeset converts a value received from Postgres as a string into its
+	// changeset serialization. pgx only returns replication data as strings, so this is
+	// used to encode replication data. It will never be called with null values, but it
+	// may be called with empty strings / 0 values.
+	// https://github.com/jackc/pglogrepl/blob/828fbfe908e97cfeb409a17e4ec339dede1f1a17/message.go#L379
+	SerializeChangeset func(value string) ([]byte, error)
+	// DeserializeChangeset decodes a value from a changeset back into its native Go/Kwil type.
+	// This can then be used to execute an incoming changeset against a database.
+	DeserializeChangeset func([]byte) (any, error)
+}
+
+var (
+	textType = &datatype{
+		KwilType:       types.TextType,
+		Matches:        []reflect.Type{reflect.TypeOf("")},
+		OID:            func(*pgtype.Map) uint32 { return pgtype.TextOID },
+		EncodeInferred: defaultEncodeDecode,
+		Decode:         defaultEncodeDecode,
+		SerializeChangeset: func(value string) ([]byte, error) {
+			return []byte(value), nil
+		},
+		DeserializeChangeset: func(b []byte) (any, error) {
+			return string(b), nil
+		},
+	}
+
+	textArrayType = &datatype{
+		KwilType:       types.TextArrayType,
+		Matches:        []reflect.Type{reflect.TypeOf([]string{})},
+		OID:            func(*pgtype.Map) uint32 { return pgtype.TextArrayOID },
+		EncodeInferred: defaultEncodeDecode,
+		Decode:         decodeArray[string](textType.Decode),
+		SerializeChangeset: func(value string) ([]byte, error) {
+			// text arrays are delimited by commas, so we need to split on commas.
+			// We also need to ensure that commas inside quoted strings are not
+			// treated as element delimiters.
+			var ok bool
+			value, ok = trimCurlys(value)
+			if !ok {
+				return nil, fmt.Errorf("invalid text array: %s", value)
+			}
+
+			// each string is now wrapped in double quotes in the text literal,
+			// e.g. 
"aaa","bbb","c\"cc" + // we need to split on "," but not on "\",\"" + inQuote := false + var strs []string + currentStr := strings.Builder{} + i := 0 + for i < len(value) { + v := value[i] + switch v { + case '\\': + if len(value) <= i+1 { + return nil, fmt.Errorf("invalid text array: %s", value) + } + // add the next character to the string + currentStr.WriteByte(value[i+1]) + i++ + case '"': + // toggle inQuote + inQuote = !inQuote + case ',': + if inQuote { + currentStr.WriteByte(v) + } else { + strs = append(strs, currentStr.String()) + currentStr.Reset() + } + default: + currentStr.WriteByte(v) + } + i++ + } + + // add the last string + strs = append(strs, currentStr.String()) + + return serializeArray(strs, 4, textType.SerializeChangeset) + }, + DeserializeChangeset: deserializeArrayFn[string](4, textType.DeserializeChangeset), + } + + // we intentionally ignore uint8, since we don't want to cause issues with []byte. + intType = &datatype{ + KwilType: types.IntType, + Matches: []reflect.Type{reflect.TypeOf(int(0)), reflect.TypeOf(int8(0)), reflect.TypeOf(int16(0)), reflect.TypeOf(int32(0)), reflect.TypeOf(int64(0)), reflect.TypeOf(uint(0)), reflect.TypeOf(uint16(0)), reflect.TypeOf(uint32(0)), reflect.TypeOf(uint64(0))}, + OID: func(*pgtype.Map) uint32 { return pgtype.Int8OID }, + ExtraOIDs: []uint32{pgtype.Int2OID, pgtype.Int4OID}, + EncodeInferred: defaultEncodeDecode, + Decode: func(a any) (any, error) { + switch v := a.(type) { + case int: + return int64(v), nil + case int8: + return int64(v), nil + case int16: + return int64(v), nil + case int32: + return int64(v), nil + case int64: + return v, nil + case uint: + return int64(v), nil + case uint16: + return int64(v), nil + case uint32: + return int64(v), nil + case uint64: + return int64(v), nil + default: + return nil, fmt.Errorf("unexpected type %T", a) + } + }, + SerializeChangeset: func(value string) ([]byte, error) { + intVal, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return nil, err + } + + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(intVal)) + return buf, nil + }, + DeserializeChangeset: func(b []byte) (any, error) { + return int64(binary.LittleEndian.Uint64(b)), nil + }, + } + + intArrayType = &datatype{ + KwilType: types.IntArrayType, + Matches: []reflect.Type{reflect.TypeOf([]int{}), reflect.TypeOf([]int8{}), reflect.TypeOf([]int16{}), reflect.TypeOf([]int32{}), reflect.TypeOf([]int64{}), reflect.TypeOf([]uint{}), reflect.TypeOf([]uint16{}), reflect.TypeOf([]uint32{}), reflect.TypeOf([]uint64{})}, + OID: func(*pgtype.Map) uint32 { return pgtype.Int8ArrayOID }, + ExtraOIDs: []uint32{pgtype.Int2ArrayOID, pgtype.Int4ArrayOID}, + EncodeInferred: defaultEncodeDecode, + Decode: decodeArray[int64](intType.Decode), + SerializeChangeset: arrayFromChildFunc(1, intType.SerializeChangeset), + DeserializeChangeset: deserializeArrayFn[int64](1, intType.DeserializeChangeset), + } + + boolType = &datatype{ + KwilType: types.BoolType, + Matches: []reflect.Type{reflect.TypeOf(true)}, + OID: func(*pgtype.Map) uint32 { return pgtype.BoolOID }, + EncodeInferred: defaultEncodeDecode, + Decode: defaultEncodeDecode, + SerializeChangeset: func(value string) ([]byte, error) { + if strings.EqualFold(value, "true") || strings.EqualFold(value, "t") { + return []byte{1}, nil + } + if strings.EqualFold(value, "false") || strings.EqualFold(value, "f") { + return []byte{0}, nil + } + return nil, fmt.Errorf("invalid boolean value: %s", value) + }, + DeserializeChangeset: func(b []byte) (any, error) { + return 
b[0] == 1, nil + }, + } + + boolArrayType = &datatype{ + KwilType: types.BoolArrayType, + Matches: []reflect.Type{reflect.TypeOf([]bool{})}, + OID: func(*pgtype.Map) uint32 { return pgtype.BoolArrayOID }, + EncodeInferred: defaultEncodeDecode, + Decode: decodeArray[bool](boolType.Decode), + SerializeChangeset: arrayFromChildFunc(1, boolType.SerializeChangeset), + DeserializeChangeset: deserializeArrayFn[bool](1, boolType.DeserializeChangeset), + } + + blobType = &datatype{ + KwilType: types.BlobType, + Matches: []reflect.Type{reflect.TypeOf([]byte{})}, + OID: func(*pgtype.Map) uint32 { return pgtype.ByteaOID }, + EncodeInferred: defaultEncodeDecode, + Decode: defaultEncodeDecode, + SerializeChangeset: func(value string) ([]byte, error) { + // postgres returns all blobs as hex, prefixed with \x + // we need to remove the \x and decode the hex + if len(value) < 2 { + return nil, fmt.Errorf("invalid blob value: %s", value) + } + + if value[0] != '\\' || value[1] != 'x' { + return nil, fmt.Errorf("invalid blob value: %s", value) + } + + return hex.DecodeString(value[2:]) + }, + DeserializeChangeset: func(b []byte) (any, error) { + return b, nil + }, + } + + blobArrayType = &datatype{ + KwilType: types.BlobArrayType, + Matches: []reflect.Type{reflect.TypeOf([][]byte{})}, + OID: func(*pgtype.Map) uint32 { return pgtype.ByteaArrayOID }, + EncodeInferred: defaultEncodeDecode, + Decode: decodeArray[[]byte](blobType.Decode), + SerializeChangeset: func(value string) ([]byte, error) { + // postgres wraps each hex encoded blob in double quotes, so we need to remove them + var ok bool + value, ok = trimCurlys(value) + if !ok { + return nil, fmt.Errorf("invalid blob array: %s", value) + } + + // each blob is now wrapped in double quotes in the text literal, + vals := strings.Split(value, ",") + + bts := make([][]byte, len(vals)) + for i, v := range vals { + if !strings.HasPrefix(v, `"`) || !strings.HasSuffix(v, `"`) { + return nil, fmt.Errorf("invalid blob array: %s", value) + } + + vals[i] = v[1 : len(v)-1] + + // for some reason, postgres adds an additional escape character to the hex in an array + // that is not present in a single value. We need to remove it. 
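+			// e.g. a lone bytea value arrives in text form as \x0102, while the
+			// same value inside an array arrives quoted as "\\x0102" (an observed
+			// pgx/Postgres behavior, not a documented guarantee).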
+ // This irregularity is tested in db_live_test.go + if len(vals[i]) == 0 { + return nil, fmt.Errorf("invalid blob array, expected some value: %s", value) + } + + if vals[i][0] != '\\' { + return nil, fmt.Errorf("invalid blob array, expected \\: %s", value) + } + + // decode the hex + b, err := blobType.SerializeChangeset(vals[i][1:]) + if err != nil { + return nil, err + } + + bts[i] = b + } + + return serializeArray(bts, 4, func(b []byte) ([]byte, error) { + return b, nil + }) + }, + DeserializeChangeset: deserializeArrayFn[[]byte](4, blobType.DeserializeChangeset), + } + + uuidType = &datatype{ + KwilType: types.UUIDType, + Matches: []reflect.Type{reflect.TypeOf(types.NewUUIDV5([]byte{})), reflect.TypeOf(*types.NewUUIDV5([]byte{}))}, + OID: func(*pgtype.Map) uint32 { return pgtype.UUIDOID }, + EncodeInferred: func(v any) (any, error) { + var val *types.UUID + switch v := v.(type) { + case types.UUID: + val = &v + case *types.UUID: + val = v + default: + panic("unreachable") + } + + return pgtype.UUID{ + Bytes: [16]byte(val.Bytes()), + Valid: true, + }, nil + }, + Decode: func(v any) (any, error) { + var u types.UUID + switch v := v.(type) { + case pgtype.UUID: + u = types.UUID(v.Bytes) + case [16]byte: + u = types.UUID(v) + default: + return nil, fmt.Errorf("unexpected type decoding uuid %T", v) + } + return &u, nil + }, + SerializeChangeset: func(value string) ([]byte, error) { + u, err := types.ParseUUID(value) + if err != nil { + return nil, err + } + return u.Bytes(), nil + }, + DeserializeChangeset: func(b []byte) (any, error) { + u := types.UUID(b) + return &u, nil + }, + } + + uuidArrayType = &datatype{ + KwilType: types.UUIDArrayType, + Matches: []reflect.Type{reflect.TypeOf(types.UUIDArray{})}, + OID: func(*pgtype.Map) uint32 { return pgtype.UUIDArrayOID }, + EncodeInferred: func(v any) (any, error) { + val, ok := v.(types.UUIDArray) + if !ok { + return nil, fmt.Errorf("expected UUIDArray, got %T", v) + } + + var arr []any + for _, u := range val { + v2, err := uuidType.EncodeInferred(u) + if err != nil { + return nil, err + } + arr = append(arr, v2) + } + + return arr, nil + }, + Decode: func(a any) (any, error) { + arr, ok := a.([]any) // pgx always returns arrays as []any + if !ok { + return nil, fmt.Errorf("expected []any, got %T", a) + } + + vals := make(types.UUIDArray, len(arr)) + for i, v := range arr { + val, err := uuidType.Decode(v) + if err != nil { + return nil, err + } + vals[i] = val.(*types.UUID) + } + + return vals, nil + }, + SerializeChangeset: arrayFromChildFunc(1, uuidType.SerializeChangeset), + DeserializeChangeset: deserializeArrayFn[*types.UUID](1, uuidType.DeserializeChangeset), + } + + decimalType = &datatype{ + KwilType: types.DecimalType, + Matches: []reflect.Type{reflect.TypeOf(decimal.Decimal{}), reflect.TypeOf(&decimal.Decimal{})}, + OID: func(*pgtype.Map) uint32 { return pgtype.NumericOID }, + EncodeInferred: func(v any) (any, error) { + var dec *decimal.Decimal + switch v := v.(type) { + case decimal.Decimal: + dec = &v + case *decimal.Decimal: + dec = v + default: + return nil, fmt.Errorf("unexpected type encoding decimal %T", v) + } + + return pgtype.Numeric{ + Int: dec.BigInt(), + Exp: dec.Exp(), + Valid: true, + }, nil + }, + Decode: func(a any) (any, error) { + pgType, ok := a.(pgtype.Numeric) + if !ok { + return nil, fmt.Errorf("expected pgtype.Numeric, got %T", a) + } + + if pgType.NaN { + return "NaN", nil + } + + // if we give postgres a number such as 5000, it will return it as 5 with exponent 3. 
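+			// (that is, pgtype.Numeric{Int: big.NewInt(5), Exp: 3}).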
+ // Since kwil's decimal semantics do not allow negative scale, we need to multiply + // the number by 10^exp to get the correct value. + if pgType.Exp > 0 { + z := new(big.Int) + z.Exp(big.NewInt(10), big.NewInt(int64(pgType.Exp)), nil) + z.Mul(z, pgType.Int) + pgType.Int = z + pgType.Exp = 0 + } + + // there is a bit of an edge case here, where uint256 can be returned. + // since most results simply get returned to the user via JSON, it doesn't + // matter too much right now, so we'll leave it as-is. + return decimal.NewFromBigInt(pgType.Int, pgType.Exp) + }, + SerializeChangeset: func(value string) ([]byte, error) { + // parse to ensure it is a valid decimal, then re-encode it to ensure it is in the correct format. + dec, err := decimal.NewFromString(value) + if err != nil { + return nil, err + } + + return []byte(dec.String()), nil + }, + DeserializeChangeset: func(b []byte) (any, error) { + return decimal.NewFromString(string(b)) + }, + } + + decimalArrayType = &datatype{ + KwilType: types.DecimalArrayType, + Matches: []reflect.Type{reflect.TypeOf(decimal.DecimalArray{})}, + OID: func(*pgtype.Map) uint32 { return pgtype.NumericArrayOID }, + EncodeInferred: func(v any) (any, error) { + val, ok := v.(decimal.DecimalArray) + if !ok { + return nil, fmt.Errorf("expected DecimalArray, got %T", v) + } + + var arr []pgtype.Numeric + for _, d := range val { + v2, err := decimalType.EncodeInferred(d) + if err != nil { + return nil, err + } + arr = append(arr, v2.(pgtype.Numeric)) + } + + return arr, nil + }, + Decode: func(a any) (any, error) { + arr, ok := a.([]any) // pgx always returns arrays as []any + if !ok { + return nil, fmt.Errorf("expected []any, got %T", a) + } + + vals := make(decimal.DecimalArray, len(arr)) + for i, v := range arr { + val, err := decimalType.Decode(v) + if err != nil { + return nil, err + } + vals[i] = val.(*decimal.Decimal) + } + + return vals, nil + }, + SerializeChangeset: arrayFromChildFunc(2, decimalType.SerializeChangeset), + DeserializeChangeset: deserializeArrayFn[*decimal.Decimal](2, decimalType.DeserializeChangeset), + } + + uint256Type = &datatype{ + KwilType: types.Uint256Type, + Matches: []reflect.Type{reflect.TypeOf(types.Uint256{}), reflect.TypeOf(&types.Uint256{})}, + // OID is a custom OID, since Postgres doesn't have a built-in type for uint256, + // so Kwil uses a Postgres Domain. + OID: func(m *pgtype.Map) uint32 { + pgt, ok := m.TypeForName("uint256") + if !ok { + // if this happens, it is an internal bug where we are not registering the type + panic("uint256 domain not found") + } + + return pgt.OID + }, + // Under the hood, Kwil's uint256 is a Domain built on a numeric type. + EncodeInferred: func(a any) (any, error) { + var val *types.Uint256 + switch v := a.(type) { + case types.Uint256: + val = &v + case *types.Uint256: + val = v + default: + panic("unreachable") + } + + return pgtype.Numeric{ + Int: val.ToBig(), + Exp: 0, + Valid: true, + }, nil + }, + Decode: func(a any) (any, error) { + pgType, ok := a.(pgtype.Numeric) + if !ok { + return nil, fmt.Errorf("expected pgtype.Numeric, got %T", a) + } + + // if the number ends in 0s, it will have an exponent, so we need to multiply + // the number by 10^exp to get the correct value. 
+ if pgType.Exp > 0 { + z := new(big.Int) + z.Exp(big.NewInt(10), big.NewInt(int64(pgType.Exp)), nil) + z.Mul(z, pgType.Int) + pgType.Int = z + pgType.Exp = 0 + } + + return types.Uint256FromBig(pgType.Int) + }, + SerializeChangeset: func(value string) ([]byte, error) { + // parse to ensure it is a valid uint256, then re-encode it to ensure it is in the correct format. + u, err := types.Uint256FromString(value) + if err != nil { + return nil, err + } + + return u.Bytes(), nil + }, + DeserializeChangeset: func(b []byte) (any, error) { + return types.Uint256FromBytes(b) + }, + } + + uint256ArrayType = &datatype{ + KwilType: types.Uint256ArrayType, + Matches: []reflect.Type{reflect.TypeOf(types.Uint256Array{})}, + // OID is a custom OID, since Postgres doesn't have a built-in type for uint256, + // See the comment on uint256Type for more information. + OID: func(m *pgtype.Map) uint32 { + pgt, ok := m.TypeForName("uint256[]") + if !ok { + // if this happens, it is an internal bug where we are not registering the type + panic("uint256[] domain not found") + } + + return pgt.OID + }, + EncodeInferred: func(a any) (any, error) { + val, ok := a.(types.Uint256Array) + if !ok { + return nil, fmt.Errorf("expected Uint256Array, got %T", a) + } + + vals := make([]pgtype.Numeric, len(val)) + for i, u := range val { + v2, err := uint256Type.EncodeInferred(u) + if err != nil { + return nil, err + } + vals[i] = v2.(pgtype.Numeric) + } + + return vals, nil + }, + Decode: func(a any) (any, error) { + arr, ok := a.([]any) // pgx always returns arrays as []any + if !ok { + return nil, fmt.Errorf("expected []any, got %T", a) + } + + vals := make(types.Uint256Array, len(arr)) + for i, v := range arr { + val, err := uint256Type.Decode(v) + if err != nil { + return nil, err + } + vals[i] = val.(*types.Uint256) + } + + return vals, nil + }, + SerializeChangeset: arrayFromChildFunc(2, uint256Type.SerializeChangeset), + DeserializeChangeset: deserializeArrayFn[*types.Uint256](2, uint256Type.DeserializeChangeset), + } +) + +// defaultEncodeDecode is the default Encode and Decode function for data types. +// It simply returns the value as is, without any modifications. +func defaultEncodeDecode(v any) (any, error) { return v, nil } + +// decodeArrayFn creates a function that decodes an array of a given type. +// it takes a generic for the target scalar type, as well as a decode function +// for the scalar type. +func decodeArray[T any](decode func(any) (any, error)) func(any) (any, error) { + return func(a any) (any, error) { + arr, ok := a.([]any) // pgx always returns arrays as []any + if !ok { + return nil, fmt.Errorf("expected []any, got %T", a) + } + + vals := make([]T, len(arr)) + for i, v := range arr { + val, err := decode(v) + if err != nil { + return nil, err + } + + if val == nil { + continue // leaving it as nil / zero value + } + + vals[i] = val.(T) + } + + return vals, nil + } +} + +// encodeToPGType encodes several Go types to their corresponding pgx types. +// It is capable of detecting special Kwil types and encoding them to their +// corresponding pgx types. It is only used if using inferred argument types. +// If not using inferred argument types, pgx will rely on the Valuer interface +// to encode the Go types to their corresponding pgx types. +// It also returns the pgx type OIDs for each value. 
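+// A sketch of the expected behavior (the typeMap variable is illustrative):
+//
+//	vals, oids, err := encodeToPGType(typeMap, int64(1), "a", []any{int64(2)})
+//	// oids would be [pgtype.Int8OID, pgtype.TextOID, pgtype.Int8ArrayOID]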
+func encodeToPGType(oids *pgtype.Map, values ...any) ([]any, []uint32, error) {
+	if len(values) == 0 {
+		return nil, nil, nil
+	}
+
+	encoded := make([]any, len(values))
+	oidsArr := make([]uint32, len(values))
+	for i, v := range values {
+		if v == nil {
+			encoded[i] = nil
+			oidsArr[i] = pgtype.TextOID
+			continue
+		}
+
+		// special case, if []any, we need to encode each element
+		if arr, ok := v.([]any); ok {
+			if len(arr) == 0 {
+				encoded[i] = nil
+				oidsArr[i] = pgtype.TextOID
+				continue
+			}
+
+			encodedArr, oidsArrArr, err := encodeToPGType(oids, arr...)
+			if err != nil {
+				return nil, nil, err
+			}
+
+			encoded[i] = encodedArr
+
+			// check that all OIDs are the same
+			oid := oidsArrArr[0]
+			for _, oid2 := range oidsArrArr {
+				if oid != oid2 {
+					return nil, nil, fmt.Errorf("all elements in an array must have the same data type")
+				}
+			}
+
+			dt, ok := dataTypesByMatch[reflect.TypeOf(arr[0])]
+			if !ok {
+				return nil, nil, fmt.Errorf("unsupported type %T", arr[0])
+			}
+
+			arrDt, ok := scalarToArray[dt]
+			if !ok {
+				return nil, nil, fmt.Errorf("no array type for %T", arr[0])
+			}
+
+			oidsArr[i] = arrDt.OID(oids)
+
+			continue
+		}
+
+		dt, ok := dataTypesByMatch[reflect.TypeOf(v)]
+		if !ok {
+			return nil, nil, fmt.Errorf("unsupported type %T", v)
+		}
+
+		encodedVal, err := dt.EncodeInferred(v)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		encoded[i] = encodedVal
+		oidsArr[i] = dt.OID(oids)
+	}
+
+	return encoded, oidsArr, nil
+}
+
+// for functions that return void, it will actually return
+// a nil value with the void OID.
+var voidOID = uint32(2278)
+
+// decodeFromPG decodes several pgx types to their corresponding Go types.
+// It is capable of detecting special Kwil types and decoding them to their
+// corresponding Go types.
+func decodeFromPG(vals []any, oids []uint32, oidToDataType map[uint32]*datatype) ([]any, error) {
+	var results []any
+	for i, oid := range oids {
+		if oid == voidOID {
+			continue
+		}
+
+		if vals[i] == nil {
+			results = append(results, nil)
+			continue
+		}
+
+		dt, ok := oidToDataType[oid]
+		if !ok {
+			return nil, fmt.Errorf("unsupported oid %d", oid)
+		}
+
+		decoded, err := dt.Decode(vals[i])
+		if err != nil {
+			return nil, err
+		}
+
+		results = append(results, decoded)
+	}
+
+	return results, nil
+}
+
+// oidTypesMap builds a map from Postgres OIDs to the Kwil type definitions.
+// It needs to be called after all types have been registered with registerDatatype.
+func oidTypesMap(conn *pgtype.Map) map[uint32]*datatype {
+	m := make(map[uint32]*datatype)
+	for dt := range datatypes {
+		oid := dt.OID(conn)
+		_, ok := m[oid]
+		if ok {
+			panic("duplicate oid for type. OID:" + fmt.Sprint(oid))
+		}
+		m[oid] = dt
+
+		for _, extraOID := range dt.ExtraOIDs {
+			_, ok := m[extraOID]
+			if ok {
+				panic("duplicate oid for type. OID:" + fmt.Sprint(extraOID))
+			}
+			m[extraOID] = dt
+		}
+	}
+	return m
+}
+
+// trimCurlys parses curly brackets on the outside of a string.
+// It returns the string without the curly brackets, and a boolean
+// indicating whether the string had curly brackets. It is useful
+// for parsing stringified Postgres arrays.
+func trimCurlys(s string) (string, bool) {
+	if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") {
+		return s[1 : len(s)-1], true
+	}
+
+	return s, false
+}
+
+// serializeArray serializes an array of some type to []byte.
+// It takes a function that serializes the scalar values to []byte.
+// lengthSize is the byte size of the length of each element, which allows
+// us to more efficiently serialize arrays of fixed-size elements (int, bool, etc).
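+// For example, with lengthSize 1 and an identity serializer, []string{"a", "bb"}
+// serializes to 0x01 'a' 0x02 'b' 'b', i.e. each element prefixed by its length.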
+// lengthSize must be 1, 2, or 4, corresponding to 8-bit, 16-bit, and 32-bit lengths.
+func serializeArray[T any](arr []T, lengthSize uint8, serialize func(T) ([]byte, error)) ([]byte, error) {
+	encodeLength := func(length int) []byte {
+		switch lengthSize {
+		case 1:
+			return []byte{byte(length)}
+		case 2:
+			buf := make([]byte, 2)
+			binary.BigEndian.PutUint16(buf, uint16(length))
+			return buf
+		case 4:
+			buf := make([]byte, 4)
+			binary.BigEndian.PutUint32(buf, uint32(length))
+			return buf
+		default:
+			panic("invalid length size")
+		}
+	}
+
+	var buf []byte
+	for _, v := range arr {
+		encoded, err := serialize(v)
+		if err != nil {
+			return nil, err
+		}
+
+		buf = append(buf, encodeLength(len(encoded))...)
+		buf = append(buf, encoded...)
+	}
+
+	return buf, nil
+}
+
+// deserializeArray deserializes an array of some type from []byte.
+// It takes a function that deserializes the scalar values from []byte.
+// it is the inverse of serializeArray. lengthSize must be 1, 2, or 4,
+// corresponding to 8-bit, 16-bit, and 32-bit lengths.
+func deserializeArray[T any](buf []byte, lengthSize uint8, deserialize func([]byte) (any, error)) ([]T, error) {
+	// the lengthSize thing might be a bit overkill, but it is very encapsulated so
+	// I'll keep it for now, since it can help decrease the size of the changeset that
+	// a network has to process.
+	determineLength := func(buf []byte) (int, []byte) {
+		switch lengthSize {
+		case 1:
+			return int(buf[0]), buf[1:]
+		case 2:
+			return int(binary.BigEndian.Uint16(buf[:2])), buf[2:]
+		case 4:
+			return int(binary.BigEndian.Uint32(buf[:4])), buf[4:]
+		default:
+			panic("invalid length size")
+		}
+	}
+
+	var arr []T
+	for len(buf) > 0 {
+		length, rest := determineLength(buf)
+
+		v, err := deserialize(rest[:length])
+		if err != nil {
+			return nil, err
+		}
+
+		arr = append(arr, v.(T))
+		buf = rest[length:]
+	}
+
+	return arr, nil
+}
+
+// arrayFromChildFunc splits a stringified array into its elements, and uses
+// the callback function to serialize each element. It is meant to be used with
+// array data types that do not have special parsing rules. It returns a function
+// that can be used to serialize changesets from their Postgres text form.
+func arrayFromChildFunc(size uint8, serialize func(string) ([]byte, error)) func(string) ([]byte, error) {
+	return func(s string) ([]byte, error) {
+		s, ok := trimCurlys(s)
+		if !ok {
+			return nil, fmt.Errorf("invalid array: %s", s)
+		}
+
+		strs := strings.Split(s, ",")
+		return serializeArray(strs, size, serialize)
+	}
+}
+
+// deserializeArrayFn returns a function that deserializes an array of some type from a serialized array.
+// It is the logical inverse of arrayFromChildFunc.
+func deserializeArrayFn[T any](size uint8, deserialize func([]byte) (any, error)) func([]byte) (any, error) { + return func(b []byte) (any, error) { + return deserializeArray[T](b, size, deserialize) + } +} diff --git a/internal/sql/pg/types_test.go b/internal/sql/pg/types_test.go new file mode 100644 index 000000000..b4c0142d6 --- /dev/null +++ b/internal/sql/pg/types_test.go @@ -0,0 +1,38 @@ +package pg + +import ( + "encoding/binary" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_ArrayEncodeDecode(t *testing.T) { + arr := []string{"a", "b", "c"} + res, err := serializeArray(arr, 4, func(s string) ([]byte, error) { + return []byte(s), nil + }) + require.NoError(t, err) + + res2, err := deserializeArray[string](res, 4, func(b []byte) (any, error) { + return string(b), nil + }) + require.NoError(t, err) + + require.EqualValues(t, arr, res2) + + arr2 := []int64{1, 2, 3} + res, err = serializeArray(arr2, 1, func(i int64) ([]byte, error) { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(i)) + return buf, nil + }) + require.NoError(t, err) + + res3, err := deserializeArray[int64](res, 1, func(b []byte) (any, error) { + return int64(binary.LittleEndian.Uint64(b)), nil + }) + require.NoError(t, err) + + require.EqualValues(t, arr2, res3) +} diff --git a/internal/txapp/interfaces.go b/internal/txapp/interfaces.go index 0f18162c3..de1b930aa 100644 --- a/internal/txapp/interfaces.go +++ b/internal/txapp/interfaces.go @@ -19,7 +19,7 @@ type Rebroadcaster interface { // from within a transaction. A DB can create read transactions or the special // two-phase outer write transaction. type DB interface { - sql.OuterTxMaker + sql.PreparedTxMaker sql.ReadTxMaker sql.SnapshotTxMaker // BeginReservedReadTx creates a read-only transaction on a reserved diff --git a/internal/txapp/mempool.go b/internal/txapp/mempool.go index b625c093b..ca86b670a 100644 --- a/internal/txapp/mempool.go +++ b/internal/txapp/mempool.go @@ -9,6 +9,7 @@ import ( "math/big" "sync" + "github.com/kwilteam/kwil-db/common" sql "github.com/kwilteam/kwil-db/common/sql" "github.com/kwilteam/kwil-db/core/types" "github.com/kwilteam/kwil-db/core/types/transactions" @@ -19,10 +20,6 @@ type mempool struct { accounts map[string]*types.Account acctsMtx sync.Mutex // protects accounts - // consensus parameters - gasEnabled bool - maxVotesPerTx int64 - nodeAddr []byte } @@ -53,14 +50,40 @@ func (m *mempool) accountInfoSafe(ctx context.Context, tx sql.Executor, acctID [ } // applyTransaction validates account specific info and applies valid transactions to the mempool state. -func (m *mempool) applyTransaction(ctx context.Context, tx *transactions.Transaction, dbTx sql.Executor, rebroadcaster Rebroadcaster) error { +func (m *mempool) applyTransaction(ctx *common.TxContext, tx *transactions.Transaction, dbTx sql.Executor, rebroadcaster Rebroadcaster) error { m.acctsMtx.Lock() defer m.acctsMtx.Unlock() + // if the network is in a migration, there are numerous + // transaction types we must disallow. 
+ // see [internal/migrations/migrations.go] for more info + if ctx.BlockContext.ChainContext.NetworkParameters.InMigration { + switch tx.Body.PayloadType { + case transactions.PayloadTypeValidatorJoin: + return fmt.Errorf("validator joins are not allowed during migration") + case transactions.PayloadTypeValidatorLeave: + return fmt.Errorf("validator leaves are not allowed during migration") + case transactions.PayloadTypeValidatorApprove: + return fmt.Errorf("validator approvals are not allowed during migration") + case transactions.PayloadTypeValidatorRemove: + return fmt.Errorf("validator removals are not allowed during migration") + case transactions.PayloadTypeValidatorVoteIDs: + return fmt.Errorf("validator vote ids are not allowed during migration") + case transactions.PayloadTypeValidatorVoteBodies: + return fmt.Errorf("validator vote bodies are not allowed during migration") + case transactions.PayloadTypeDeploySchema: + return fmt.Errorf("deploy schema transactions are not allowed during migration") + case transactions.PayloadTypeDropSchema: + return fmt.Errorf("drop schema transactions are not allowed during migration") + case transactions.PayloadTypeTransfer: + return fmt.Errorf("transfer transactions are not allowed during migration") + } + } + // seems like maybe this should go in the switch statement below, // but I put it here to avoid extra db call for account info if tx.Body.PayloadType == transactions.PayloadTypeValidatorVoteIDs { - power, err := voting.GetValidatorPower(ctx, dbTx, tx.Sender) + power, err := voting.GetValidatorPower(ctx.Ctx, dbTx, tx.Sender) if err != nil { return err } @@ -75,8 +98,8 @@ func (m *mempool) applyTransaction(ctx context.Context, tx *transactions.Transac if err != nil { return err } - if (int64)(len(voteID.ResolutionIDs)) > m.maxVotesPerTx { - return fmt.Errorf("number of voteIDs exceeds the limit of %d", m.maxVotesPerTx) + if maxVotes := ctx.BlockContext.ChainContext.NetworkParameters.MaxVotesPerTx; (int64)(len(voteID.ResolutionIDs)) > maxVotes { + return fmt.Errorf("number of voteIDs exceeds the limit of %d", maxVotes) } } @@ -86,13 +109,13 @@ func (m *mempool) applyTransaction(ctx context.Context, tx *transactions.Transac } // get account info from mempool state or account store - acct, err := m.accountInfo(ctx, dbTx, tx.Sender) + acct, err := m.accountInfo(ctx.Ctx, dbTx, tx.Sender) if err != nil { return err } // reject the transactions from unfunded user accounts in gasEnabled mode - if m.gasEnabled && acct.Nonce == 0 && acct.Balance.Sign() == 0 { + if !ctx.BlockContext.ChainContext.NetworkParameters.DisabledGasCosts && acct.Nonce == 0 && acct.Balance.Sign() == 0 { delete(m.accounts, string(tx.Sender)) return transactions.ErrInsufficientBalance } @@ -116,7 +139,7 @@ func (m *mempool) applyTransaction(ctx context.Context, tx *transactions.Transac return err } - err = rebroadcaster.MarkRebroadcast(ctx, voteID.ResolutionIDs) + err = rebroadcaster.MarkRebroadcast(ctx.Ctx, voteID.ResolutionIDs) if err != nil { return err } diff --git a/internal/txapp/mempool_test.go b/internal/txapp/mempool_test.go index 733765eae..a54aff5fd 100644 --- a/internal/txapp/mempool_test.go +++ b/internal/txapp/mempool_test.go @@ -5,6 +5,7 @@ import ( "math/big" "testing" + "github.com/kwilteam/kwil-db/common" sql "github.com/kwilteam/kwil-db/common/sql" "github.com/kwilteam/kwil-db/core/crypto/auth" "github.com/kwilteam/kwil-db/core/types" @@ -21,53 +22,73 @@ func Test_MempoolWithoutGas(t *testing.T) { db := &mockDb{} rebroadcast := &mockRebroadcast{} + txCtx := 
&common.TxContext{ + Ctx: ctx, + BlockContext: &common.BlockContext{ + ChainContext: &common.ChainContext{ + NetworkParameters: &common.NetworkParameters{ + DisabledGasCosts: true, + }, + }, + }, + } + // Successful transaction A: 1 - err := m.applyTransaction(ctx, newTx(t, 1, "A"), db, rebroadcast) + err := m.applyTransaction(txCtx, newTx(t, 1, "A"), db, rebroadcast) assert.NoError(t, err) assert.EqualValues(t, m.accounts["A"].Nonce, 1) // Successful transaction A: 2 - err = m.applyTransaction(ctx, newTx(t, 2, "A"), db, rebroadcast) + err = m.applyTransaction(txCtx, newTx(t, 2, "A"), db, rebroadcast) assert.NoError(t, err) assert.EqualValues(t, m.accounts["A"].Nonce, 2) // Duplicate nonce failure - err = m.applyTransaction(ctx, newTx(t, 2, "A"), db, rebroadcast) + err = m.applyTransaction(txCtx, newTx(t, 2, "A"), db, rebroadcast) assert.Error(t, err) assert.EqualValues(t, m.accounts["A"].Nonce, 2) // Invalid order - err = m.applyTransaction(ctx, newTx(t, 4, "A"), db, rebroadcast) + err = m.applyTransaction(txCtx, newTx(t, 4, "A"), db, rebroadcast) assert.Error(t, err) assert.EqualValues(t, m.accounts["A"].Nonce, 2) - err = m.applyTransaction(ctx, newTx(t, 3, "A"), db, rebroadcast) + err = m.applyTransaction(txCtx, newTx(t, 3, "A"), db, rebroadcast) assert.NoError(t, err) assert.EqualValues(t, m.accounts["A"].Nonce, 3) // Recheck nonce 4 transaction - err = m.applyTransaction(ctx, newTx(t, 4, "A"), db, rebroadcast) + err = m.applyTransaction(txCtx, newTx(t, 4, "A"), db, rebroadcast) assert.NoError(t, err) assert.EqualValues(t, m.accounts["A"].Nonce, 4) } func Test_MempoolWithGas(t *testing.T) { m := &mempool{ - accounts: make(map[string]*types.Account), - gasEnabled: true, + accounts: make(map[string]*types.Account), + } + + txCtx := &common.TxContext{ + Ctx: context.Background(), + BlockContext: &common.BlockContext{ + ChainContext: &common.ChainContext{ + NetworkParameters: &common.NetworkParameters{ + DisabledGasCosts: false, + }, + }, + }, } - ctx := context.Background() db := &mockDb{} rebroadcast := &mockRebroadcast{} // Transaction from Unknown sender should fail tx := newTx(t, 1, "A") - err := m.applyTransaction(ctx, tx, db, rebroadcast) + err := m.applyTransaction(txCtx, tx, db, rebroadcast) assert.Error(t, err) // Resubmitting the same transaction should fail - err = m.applyTransaction(ctx, tx, db, rebroadcast) + err = m.applyTransaction(txCtx, tx, db, rebroadcast) assert.Error(t, err) // Credit account A @@ -78,7 +99,7 @@ func Test_MempoolWithGas(t *testing.T) { } // Successful transaction A: 1 - err = m.applyTransaction(ctx, tx, db, rebroadcast) + err = m.applyTransaction(txCtx, tx, db, rebroadcast) assert.NoError(t, err) } diff --git a/internal/txapp/routes.go b/internal/txapp/routes.go index 9cb9e2fa5..c9a788700 100644 --- a/internal/txapp/routes.go +++ b/internal/txapp/routes.go @@ -242,6 +242,10 @@ func (d *deployDatasetRoute) Price(ctx context.Context, app *common.App, tx *tra } func (d *deployDatasetRoute) PreTx(ctx common.TxContext, svc *common.Service, tx *transactions.Transaction) (transactions.TxCode, error) { + if ctx.BlockContext.ChainContext.NetworkParameters.InMigration { + return transactions.CodeNetworkInMigration, fmt.Errorf("cannot deploy dataset during migration") + } + schemaPayload := &transactions.Schema{} err := schemaPayload.UnmarshalBinary(tx.Body.Payload) if err != nil { @@ -289,6 +293,10 @@ func (d *dropDatasetRoute) Price(ctx context.Context, app *common.App, tx *trans } func (d *dropDatasetRoute) PreTx(ctx common.TxContext, svc *common.Service, tx 
*transactions.Transaction) (transactions.TxCode, error) { + if ctx.BlockContext.ChainContext.NetworkParameters.InMigration { + return transactions.CodeNetworkInMigration, fmt.Errorf("cannot drop dataset during migration") + } + drop := &transactions.DropSchema{} err := drop.UnmarshalBinary(tx.Body.Payload) if err != nil { @@ -414,6 +422,10 @@ func (d *transferRoute) PreTx(ctx common.TxContext, svc *common.Service, tx *tra return transactions.CodeEncodingError, err } + if ctx.BlockContext.ChainContext.NetworkParameters.InMigration { + return transactions.CodeNetworkInMigration, fmt.Errorf("cannot transfer during migration") + } + bigAmt, ok := new(big.Int).SetString(transferBody.Amount, 10) if !ok { return transactions.CodeInvalidAmount, fmt.Errorf("failed to parse amount: %s", transferBody.Amount) @@ -459,6 +471,10 @@ func (d *validatorJoinRoute) Price(ctx context.Context, app *common.App, tx *tra } func (d *validatorJoinRoute) PreTx(ctx common.TxContext, svc *common.Service, tx *transactions.Transaction) (transactions.TxCode, error) { + if ctx.BlockContext.ChainContext.NetworkParameters.InMigration { + return transactions.CodeNetworkInMigration, fmt.Errorf("cannot join validator during migration") + } + join := &transactions.ValidatorJoin{} err := join.UnmarshalBinary(tx.Body.Payload) if err != nil { @@ -527,6 +543,10 @@ func (d *validatorApproveRoute) Price(ctx context.Context, app *common.App, tx * } func (d *validatorApproveRoute) PreTx(ctx common.TxContext, svc *common.Service, tx *transactions.Transaction) (transactions.TxCode, error) { + if ctx.BlockContext.ChainContext.NetworkParameters.InMigration { + return transactions.CodeNetworkInMigration, fmt.Errorf("cannot approve validator join during migration") + } + approve := &transactions.ValidatorApprove{} err := approve.UnmarshalBinary(tx.Body.Payload) if err != nil { @@ -588,6 +608,10 @@ func (d *validatorRemoveRoute) Price(ctx context.Context, app *common.App, tx *t } func (d *validatorRemoveRoute) PreTx(ctx common.TxContext, svc *common.Service, tx *transactions.Transaction) (transactions.TxCode, error) { + if ctx.BlockContext.ChainContext.NetworkParameters.InMigration { + return transactions.CodeNetworkInMigration, fmt.Errorf("cannot remove validator during migration") + } + remove := &transactions.ValidatorRemove{} err := remove.UnmarshalBinary(tx.Body.Payload) if err != nil { @@ -655,6 +679,9 @@ func (d *validatorLeaveRoute) Price(ctx context.Context, app *common.App, tx *tr } func (d *validatorLeaveRoute) PreTx(ctx common.TxContext, svc *common.Service, tx *transactions.Transaction) (transactions.TxCode, error) { + if ctx.BlockContext.ChainContext.NetworkParameters.InMigration { + return transactions.CodeNetworkInMigration, fmt.Errorf("cannot leave validator during migration") + } return 0, nil // no payload to decode or validate for this route } @@ -696,6 +723,9 @@ func (d *validatorVoteIDsRoute) Price(ctx context.Context, app *common.App, tx * } func (d *validatorVoteIDsRoute) PreTx(ctx common.TxContext, svc *common.Service, tx *transactions.Transaction) (transactions.TxCode, error) { + if ctx.BlockContext.ChainContext.NetworkParameters.InMigration { + return transactions.CodeNetworkInMigration, fmt.Errorf("cannot vote during migration") + } return 0, nil } @@ -773,6 +803,11 @@ func (d *validatorVoteBodiesRoute) Price(ctx context.Context, _ *common.App, tx } func (d *validatorVoteBodiesRoute) PreTx(ctx common.TxContext, _ *common.Service, tx *transactions.Transaction) (transactions.TxCode, error) { + if 
ctx.BlockContext.ChainContext.NetworkParameters.InMigration { + return transactions.CodeNetworkInMigration, fmt.Errorf("cannot vote during migration") + + } + // Only proposer can issue a VoteBody transaction. if !bytes.Equal(tx.Sender, ctx.BlockContext.Proposer) { return transactions.CodeInvalidSender, ErrCallerNotProposer diff --git a/internal/txapp/routes_test.go b/internal/txapp/routes_test.go index bef19c48f..9e0825863 100644 --- a/internal/txapp/routes_test.go +++ b/internal/txapp/routes_test.go @@ -56,12 +56,12 @@ func Test_Routes(t *testing.T) { // we can have scoped data in our mock implementations type testcase struct { name string - fn func(t *testing.T, callback func(*TxApp)) // required, uses callback to allow for scoped data - payload transactions.Payload // required - fee int64 // optional, if nil, will automatically use 0 - ctx TxContext // optional, if nil, will automatically create a mock - from auth.Signer // optional, if nil, will automatically use default validatorSigner1 - err error // if not nil, expect this error + fn func(t *testing.T, callback func()) // required, uses callback to control when the test is run + payload transactions.Payload // required + fee int64 // optional, if nil, will automatically use 0 + ctx TxContext // optional, if nil, will automatically create a mock + from auth.Signer // optional, if nil, will automatically use default validatorSigner1 + err error // if not nil, expect this error } // due to the relative simplicity of routes and pricing, I have only tested a few complex ones. @@ -73,7 +73,7 @@ func Test_Routes(t *testing.T) { // we expect that it will approve and then attempt to delete the event name: "validator_vote_id, as local validator", fee: voting.ValidatorVoteIDPrice, - fn: func(t *testing.T, callback func(*TxApp)) { + fn: func(t *testing.T, callback func()) { approveCount := 0 deleteCount := 0 @@ -93,9 +93,7 @@ func Test_Routes(t *testing.T) { return 1, nil } - callback(&TxApp{ - GasEnabled: true, - }) + callback() assert.Equal(t, 1, approveCount) assert.Equal(t, 1, deleteCount) @@ -111,7 +109,7 @@ func Test_Routes(t *testing.T) { // we expect that it will approve and not attempt to delete the event name: "validator_vote_id, as non-local validator", fee: voting.ValidatorVoteIDPrice, - fn: func(t *testing.T, callback func(*TxApp)) { + fn: func(t *testing.T, callback func()) { approveCount := 0 deleteCount := 0 @@ -131,9 +129,7 @@ func Test_Routes(t *testing.T) { return 1, nil } - callback(&TxApp{ - GasEnabled: true, - }) + callback() assert.Equal(t, 1, approveCount) assert.Equal(t, 0, deleteCount) @@ -150,14 +146,12 @@ func Test_Routes(t *testing.T) { // we expect that it will fail name: "validator_vote_id, as non-validator", fee: voting.ValidatorVoteIDPrice, - fn: func(t *testing.T, callback func(*TxApp)) { + fn: func(t *testing.T, callback func()) { getVoterPower = func(ctx context.Context, db sql.Executor, identifier []byte) (int64, error) { return 0, nil } - callback(&TxApp{ - GasEnabled: true, - }) + callback() }, payload: &transactions.ValidatorVoteIDs{ ResolutionIDs: []*types.UUID{ @@ -170,7 +164,7 @@ func Test_Routes(t *testing.T) { // testing validator_vote_bodies, as the proposer name: "validator_vote_bodies, as proposer", fee: voting.ValidatorVoteIDPrice, - fn: func(t *testing.T, callback func(*TxApp)) { + fn: func(t *testing.T, callback func()) { deleteCount := 0 // override the functions with mocks @@ -186,9 +180,7 @@ func Test_Routes(t *testing.T) { return 1, nil } - callback(&TxApp{ - GasEnabled: true, - }) + 
diff --git a/internal/txapp/txapp.go b/internal/txapp/txapp.go
index ce7f494ee..004b4ea3c 100644
--- a/internal/txapp/txapp.go
+++ b/internal/txapp/txapp.go
@@ -44,16 +44,11 @@ func NewTxApp(ctx context.Context, db sql.Executor, engine common.Engine, signer
 		Engine: engine,
 		events: events,
 		log:    log,
-		mempool: &mempool{
-			accounts:      make(map[string]*types.Account),
-			gasEnabled:    !chainParams.ConsensusParams.WithoutGasCosts,
-			maxVotesPerTx: chainParams.ConsensusParams.Votes.MaxVotesPerTx,
-			nodeAddr:      signer.Identity(),
+		mempool: &mempool{accounts: make(map[string]*types.Account),
+			nodeAddr: signer.Identity(),
 		},
 		signer:  signer,
 		chainID: chainParams.ChainID,
-		GasEnabled:    !chainParams.ConsensusParams.WithoutGasCosts,
-		maxVotesPerTx: chainParams.ConsensusParams.Votes.MaxVotesPerTx,
 		extensionConfigs:    extensionConfigs,
 		emptyVoteBodyTxSize: voteBodyTxSize,
 		resTypes:            resTypes,
@@ -73,10 +68,6 @@ type TxApp struct {
 	forks forks.Forks
 
-	// Genesis config
-	GasEnabled    bool
-	maxVotesPerTx int64
-
 	events  Rebroadcaster
 	chainID string
@@ -89,9 +80,6 @@ type TxApp struct {
 	valMtx   sync.RWMutex // protects validators access
 	valChans []chan []*types.Validator
 
-	// transaction that exists between Begin and Commit
-	// currentTx sql.OuterTx
-
 	extensionConfigs map[string]map[string]string
 
 	// precomputed variables
@@ -518,7 +506,7 @@ func (r *TxApp) processVotes(ctx context.Context, db sql.DB, block *common.Block
 
 	// now we will apply credits if gas is enabled.
 	// Since it is a map, we need to order it for deterministic results.
-	if r.GasEnabled {
+	if !block.ChainContext.NetworkParameters.DisabledGasCosts {
 		for _, kv := range order.OrderMap(credits) {
 			err = credit(ctx, db, []byte(kv.Key), kv.Value)
 			if err != nil {
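The hunks above replace the cached GasEnabled field with reads of the per-block network parameters. If the double negative gets repetitive, a tiny predicate could wrap it; a sketch only, the helper name is hypothetical and not part of this patch:

    // gasEnabled reports whether gas costs are charged for this block: gas is
    // on unless the network parameters explicitly disable it.
    func gasEnabled(block *common.BlockContext) bool {
        return !block.ChainContext.NetworkParameters.DisabledGasCosts
    }

processVotes (and checkAndSpend further down) would then read `if gasEnabled(block)`.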
@@ -576,7 +564,7 @@ func (r *TxApp) Commit(ctx context.Context) {
 
 // ApplyMempool applies the transactions in the mempool.
 // If it returns an error, then the transaction is invalid.
-func (r *TxApp) ApplyMempool(ctx context.Context, db sql.DB, tx *transactions.Transaction) error {
+func (r *TxApp) ApplyMempool(ctx *common.TxContext, db sql.DB, tx *transactions.Transaction) error {
 	// check that payload type is valid
 	if getRoute(tx.Body.PayloadType.String()) == nil {
 		return fmt.Errorf("unknown payload type: %s", tx.Body.PayloadType.String())
@@ -615,8 +603,8 @@ func (r *TxApp) ProposerTxs(ctx context.Context, db sql.DB, txNonce uint64, maxT
 	}
 	bal, nonce := acct.Balance, acct.Nonce
 
-	if r.GasEnabled && nonce == 0 && bal.Sign() == 0 {
-		r.log.Debug("proposer account has no balance, not allowed to propose any new transactions")
+	if !block.ChainContext.NetworkParameters.DisabledGasCosts && nonce == 0 && bal.Sign() == 0 {
+		r.log.Debug("proposer account has no balance, not allowed to propose any new transactions", log.Int("height", block.Height))
 		return nil, nil
 	}
 
@@ -630,6 +618,7 @@ func (r *TxApp) ProposerTxs(ctx context.Context, db sql.DB, txNonce uint64, maxT
 		return nil, err
 	}
 	if len(events) == 0 {
+		r.log.Debug("no events to propose", log.Int("height", block.Height))
 		return nil, nil
 	}
 
@@ -641,8 +630,8 @@ func (r *TxApp) ProposerTxs(ctx context.Context, db sql.DB, txNonce uint64, maxT
 
 	// Is there any reason to check for notProcessed events here? Because the event store will never have events that are already processed.
 	// Limit to at most MaxVotesPerTx VoteBodies per block
-	if len(ids) > int(r.maxVotesPerTx) {
-		ids = ids[:r.maxVotesPerTx]
+	if len(ids) > int(block.ChainContext.NetworkParameters.MaxVotesPerTx) {
+		ids = ids[:block.ChainContext.NetworkParameters.MaxVotesPerTx]
 	}
 
 	eventMap := make(map[types.UUID]*types.VotableEvent)
@@ -659,6 +648,7 @@ func (r *TxApp) ProposerTxs(ctx context.Context, db sql.DB, txNonce uint64, maxT
 		evtSz := int64(len(event.Type)) + int64(len(event.Body)) + eventRLPSize
 
 		if evtSz > maxTxsSize {
+			r.log.Debug("reached maximum proposer tx size", log.Int("height", block.Height))
 			break
 		}
 		maxTxsSize -= evtSz
@@ -669,6 +659,13 @@ func (r *TxApp) ProposerTxs(ctx context.Context, db sql.DB, txNonce uint64, maxT
 	}
 
 	if len(finalEvents) == 0 {
+		r.log.Debug("found proposer events to propose, but cannot fit them in a block",
+			log.Int("height", block.Height),
+			log.Int("maxTxsSize", maxTxsSize),
+			log.Int("emptyVoteBodyTxSize", r.emptyVoteBodyTxSize),
+			log.Int("foundEvents", len(events)),
+			log.Int("maxVotesPerTx", block.ChainContext.NetworkParameters.MaxVotesPerTx),
+		)
 		return nil, nil
 	}
 
@@ -682,7 +679,7 @@ func (r *TxApp) ProposerTxs(ctx context.Context, db sql.DB, txNonce uint64, maxT
 	}
 
 	// Fee Estimate
-	amt, err := r.Price(ctx, db, tx)
+	amt, err := r.Price(ctx, db, tx, block.ChainContext)
 	if err != nil {
 		return nil, err
 	}
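The ProposerTxs hunks above implement a two-stage cap: truncate the candidate IDs to MaxVotesPerTx, then admit events only while their encoded size still fits the remaining block budget. The same rule as a standalone sketch (the function and parameter names are illustrative, not part of the diff):

    // packEvents applies the proposer's vote-body limits: at most maxVotes
    // events, and only while each event's encoded size fits in budget.
    func packEvents(events []*types.VotableEvent, maxVotes, budget, eventRLPSize int64) []*types.VotableEvent {
        if int64(len(events)) > maxVotes {
            events = events[:maxVotes]
        }
        var packed []*types.VotableEvent
        for _, ev := range events {
            sz := int64(len(ev.Type)) + int64(len(ev.Body)) + eventRLPSize
            if sz > budget {
                break
            }
            budget -= sz
            packed = append(packed, ev)
        }
        return packed
    }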
@@ -741,8 +738,8 @@ type TxResponse struct {
 
 // Price estimates the price of a transaction.
 // It returns the estimated price in tokens.
-func (r *TxApp) Price(ctx context.Context, dbTx sql.DB, tx *transactions.Transaction) (*big.Int, error) {
-	if !r.GasEnabled {
+func (r *TxApp) Price(ctx context.Context, dbTx sql.DB, tx *transactions.Transaction, chainContext *common.ChainContext) (*big.Int, error) {
+	if chainContext.NetworkParameters.DisabledGasCosts {
 		return big.NewInt(0), nil
 	}
 
@@ -770,7 +767,7 @@ func (r *TxApp) checkAndSpend(ctx TxContext, tx *transactions.Transaction, price
 	amt := big.NewInt(0)
 	var err error
 
-	if r.GasEnabled {
+	if !ctx.BlockContext.ChainContext.NetworkParameters.DisabledGasCosts {
 		amt, err = pricer.Price(ctx.Ctx, r, dbTx, tx)
 		if err != nil {
 			return nil, transactions.CodeUnknownError, err
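With this change, Price depends only on the transaction and the supplied chain context, not on TxApp construction-time flags. A caller-side sketch; the surrounding variables (db, tx, block) are assumed from context:

    // Fee estimation now threads the chain context through, so pricing sees
    // the current network parameters (e.g. DisabledGasCosts) at call time.
    amt, err := r.Price(ctx, db, tx, block.ChainContext)
    if err != nil {
        return nil, err
    }
    // amt is big.NewInt(0) whenever DisabledGasCosts is true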
diff --git a/internal/voting/broadcast/broadcast.go b/internal/voting/broadcast/broadcast.go
index 4c42bcc58..a9f3c4ca3 100644
--- a/internal/voting/broadcast/broadcast.go
+++ b/internal/voting/broadcast/broadcast.go
@@ -21,10 +21,12 @@ import (
 	"context"
 	"math/big"
 
+	"github.com/kwilteam/kwil-db/common"
 	"github.com/kwilteam/kwil-db/common/sql"
 
 	cmtCoreTypes "github.com/cometbft/cometbft/rpc/core/types"
 	"github.com/kwilteam/kwil-db/core/crypto/auth"
+	"github.com/kwilteam/kwil-db/core/log"
 	"github.com/kwilteam/kwil-db/core/types"
 	"github.com/kwilteam/kwil-db/core/types/transactions"
 )
@@ -50,11 +52,11 @@ type TxApp interface {
 	// AccountInfo gets uncommitted information about an account.
 	AccountInfo(ctx context.Context, db sql.DB, acctID []byte, getUncommitted bool) (balance *big.Int, nonce int64, err error)
 	// Price gets the estimated fee for a transaction.
-	Price(ctx context.Context, db sql.DB, tx *transactions.Transaction) (*big.Int, error)
+	Price(ctx context.Context, db sql.DB, tx *transactions.Transaction, chain *common.ChainContext) (*big.Int, error)
 	GetValidators(ctx context.Context, db sql.DB) ([]*types.Validator, error)
 }
 
-func NewEventBroadcaster(store EventStore, broadcaster Broadcaster, app TxApp, signer *auth.Ed25519Signer, chainID string, voteLimit int64) *EventBroadcaster {
+func NewEventBroadcaster(store EventStore, broadcaster Broadcaster, app TxApp, signer *auth.Ed25519Signer, chainID string, voteLimit int64, logger log.Logger) *EventBroadcaster {
 	return &EventBroadcaster{
 		store:       store,
 		broadcaster: broadcaster,
@@ -62,6 +64,7 @@ func NewEventBroadcaster(store EventStore, broadcaster Broadcaster, app TxApp, s
 		chainID:         chainID,
 		app:             app,
 		maxVoteIDsPerTx: voteLimit,
+		logger:          logger,
 	}
 }
 
@@ -77,13 +80,15 @@ type EventBroadcaster struct {
 	// This is to limit the long external roundtrips to the postgres database
 	// 10k voteIDs in a block takes around 30s to process, which is too long.
 	maxVoteIDsPerTx int64
+
+	logger log.Logger
 }
 
 // RunBroadcast tells the EventBroadcaster to broadcast any events it wishes.
 // It implements Kwil's abci.CommitHook function signature.
 // If the node is not a validator, it will do nothing.
 // It broadcasts votes for the existing resolutions.
-func (e *EventBroadcaster) RunBroadcast(ctx context.Context, db sql.DB, proposer []byte) error {
+func (e *EventBroadcaster) RunBroadcast(ctx context.Context, db sql.DB, block *common.BlockContext) error {
 	readTx, err := db.BeginTx(ctx)
 	if err != nil {
 		return err
@@ -105,6 +110,7 @@ func (e *EventBroadcaster) RunBroadcast(ctx context.Context, db sql.DB, proposer
 	}
 
 	if !isCurrent {
+		e.logger.Debug("local node is not a validator, skipping voteID broadcast")
 		return nil
 	}
 
@@ -116,7 +122,8 @@ func (e *EventBroadcaster) RunBroadcast(ctx context.Context, db sql.DB, proposer
 	// in the on-going block. This is probably a temporary restriction until
 	// we figure out a better way to track both
 	// mempool (uncommitted), committed, and proposer-introduced txns.
-	if bytes.Equal(proposer, e.signer.Identity()) {
+	if bytes.Equal(block.Proposer, e.signer.Identity()) {
+		e.logger.Debug("local node is current block proposer, skipping voteID broadcast")
 		return nil
 	}
 
@@ -128,6 +135,7 @@ func (e *EventBroadcaster) RunBroadcast(ctx context.Context, db sql.DB, proposer
 	}
 
 	if len(ids) == 0 {
+		e.logger.Debug("no voteIDs to broadcast")
 		return nil
 	}
 
@@ -147,7 +155,7 @@ func (e *EventBroadcaster) RunBroadcast(ctx context.Context, db sql.DB, proposer
 	}
 
 	// Get the fee estimate
-	fee, err := e.app.Price(ctx, readTx, tx)
+	fee, err := e.app.Price(ctx, readTx, tx, block.ChainContext)
 	if err != nil {
 		return err
 	}
@@ -156,6 +164,7 @@ func (e *EventBroadcaster) RunBroadcast(ctx context.Context, db sql.DB, proposer
 
 	if bal.Cmp(fee) < 0 {
 		// Not enough balance to pay for the tx fee
+		e.logger.Warnf("skipping voteID broadcast: not enough balance to pay for the tx fee, balance: %s, fee: %s", bal.String(), fee.String())
 		return nil
 	}
 
@@ -174,6 +183,8 @@ func (e *EventBroadcaster) RunBroadcast(ctx context.Context, db sql.DB, proposer
 		return err
 	}
 
+	e.logger.Infof("broadcasted %d voteIDs", len(ids))
+
 	// mark these events as broadcasted
 	return e.store.MarkBroadcasted(ctx, ids)
 }
diff --git a/internal/voting/broadcast/broadcast_test.go b/internal/voting/broadcast/broadcast_test.go
index c105dd2f1..dcd620a0d 100644
--- a/internal/voting/broadcast/broadcast_test.go
+++ b/internal/voting/broadcast/broadcast_test.go
@@ -9,9 +9,11 @@ import (
 	cmtCoreTypes "github.com/cometbft/cometbft/rpc/core/types"
 	"github.com/stretchr/testify/require"
 
+	"github.com/kwilteam/kwil-db/common"
 	"github.com/kwilteam/kwil-db/common/sql"
 	"github.com/kwilteam/kwil-db/core/crypto"
 	"github.com/kwilteam/kwil-db/core/crypto/auth"
+	"github.com/kwilteam/kwil-db/core/log"
 	"github.com/kwilteam/kwil-db/core/types"
 	"github.com/kwilteam/kwil-db/core/types/transactions"
 	dbtest "github.com/kwilteam/kwil-db/internal/sql/pg/test"
@@ -131,7 +133,7 @@ func Test_Broadcaster(t *testing.T) {
 				}
 			}
 
-			bc := broadcast.NewEventBroadcaster(e, b, txapp, validatorSigner(), "test-chain", maxVoteIDsPerTx)
+			bc := broadcast.NewEventBroadcaster(e, b, txapp, validatorSigner(), "test-chain", maxVoteIDsPerTx, log.NewNoOp())
 
 			// create resolutions for the events
 			for _, event := range e.events {
@@ -139,7 +141,9 @@ func Test_Broadcaster(t *testing.T) {
 				require.NoError(t, err)
 			}
 
-			err = bc.RunBroadcast(ctx, &mockDB{}, []byte("proposer"))
+			err = bc.RunBroadcast(ctx, &mockDB{}, &common.BlockContext{
+				Proposer: []byte("proposer"),
+			})
 			if tc.err != nil {
 				require.Equal(t, tc.err, err)
 				return
@@ -189,7 +193,7 @@ func (m *mockTxApp) AccountInfo(ctx context.Context, db sql.DB, acctID []byte, g
 	return m.balance, m.nonce, nil
 }
 
-func (m *mockTxApp) Price(ctx context.Context, db sql.DB, tx *transactions.Transaction) (*big.Int, error) {
+func (m *mockTxApp) Price(ctx context.Context, db sql.DB, tx *transactions.Transaction, c *common.ChainContext) (*big.Int, error) {
 	if m.price == nil {
 		return big.NewInt(0), nil
 	}
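The commit-hook side (outside this diff) now has to hand RunBroadcast the whole block context rather than just the proposer bytes, mirroring the test update above. A wiring sketch; the variable names are hypothetical:

    // NewEventBroadcaster now also takes a logger; log.NewNoOp() works for tests.
    bc := broadcast.NewEventBroadcaster(store, bcaster, txApp, signer, chainID, maxVoteIDsPerTx, logger)

    // The commit hook passes the block context through so RunBroadcast can
    // compare block.Proposer and price the vote tx via block.ChainContext.
    commitHook := func(ctx context.Context, db sql.DB, block *common.BlockContext) error {
        return bc.RunBroadcast(ctx, db, block)
    }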
diff --git a/internal/voting/events.go b/internal/voting/events.go
index 0e315a3e0..9067456fd 100644
--- a/internal/voting/events.go
+++ b/internal/voting/events.go
@@ -72,7 +72,9 @@ const (
 	LEFT JOIN ` + votingSchemaName + `.resolutions AS r ON e.id = r.id
 	WHERE r.id IS NULL;`
 
-	// eventsToBroadcast returns the list of the resolutionIDs observed by the validator that are not previously broadcasted
+	// eventsToBroadcast returns the list of the resolutionIDs observed by the validator that are not previously broadcasted.
+	// It will only search for votes for which resolutions exist (it achieves this by inner joining against the existing resolutions,
+	// effectively filtering out events that do not have resolutions yet).
 	eventsToBroadcast = `SELECT e.id
 	FROM ` + votingSchemaName + `.resolutions AS r
 	INNER JOIN ` + schemaName + `.events AS e ON r.id = e.id
diff --git a/internal/voting/voting.go b/internal/voting/voting.go
index ee79d1a20..e2042e015 100644
--- a/internal/voting/voting.go
+++ b/internal/voting/voting.go
@@ -186,8 +186,8 @@ func fromRow(row []any) (*resolutions.Resolution, error) {
 		}
 	}
 
-	var voters []any
-	voters, ok = row[5].([]any)
+	var voters [][]byte
+	voters, ok = row[5].([][]byte)
 	if !ok {
 		return nil, fmt.Errorf("invalid type for voters (%T)", row[5])
 	}
@@ -199,26 +199,21 @@ func fromRow(row []any) (*resolutions.Resolution, error) {
 			continue // pgx returns nil aggregates as length one []interface{} with a nil element
 		}
 
-		voterBts, ok := voter.([]byte)
-		if !ok {
-			return nil, fmt.Errorf("invalid type for voter (%T)", voter)
-		}
-
 		// the first 8 bytes are the power
-		if len(voterBts) < 8 {
+		if len(voter) < 8 {
 			// this should never happen, just for safety
-			return nil, fmt.Errorf("invalid length for voter (%d)", len(voterBts))
+			return nil, fmt.Errorf("invalid length for voter (%d)", len(voter))
 		}
 
 		var num uint64
-		err := binary.Read(bytes.NewReader(voterBts[:8]), binary.BigEndian, &num)
+		err := binary.Read(bytes.NewReader(voter[:8]), binary.BigEndian, &num)
 		if err != nil {
 			return nil, fmt.Errorf("failed to read bigendian int64 from voter: %w", err)
 		}
 
 		v.Voters = append(v.Voters, &types.Validator{
 			Power:  int64(num),
-			PubKey: slices.Clone(voterBts[8:]),
+			PubKey: slices.Clone(voter[8:]),
 		})
 	}
diff --git a/testing/testing.go b/testing/testing.go
index d9c3c453c..255e8eef9 100644
--- a/testing/testing.go
+++ b/testing/testing.go
@@ -154,7 +154,7 @@ func (tc SchemaTest) Run(ctx context.Context, opts *Options) error {
 		logger.Logf(`running test "%s"`, testFnIdentifiers[i])
 
 		// setup a tx and execution engine
-		outerTx, err := d.BeginOuterTx(ctx)
+		outerTx, err := d.BeginPreparedTx(ctx)
 		if err != nil {
 			return err
 		}
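The voting.go hunk above relies on the aggregate's wire shape: each voter blob is 8 bytes of big-endian power followed by the public key. A self-contained sketch of that decoding (binary.BigEndian.Uint64 is equivalent to the binary.Read used in the diff; the blob in main is a made-up example):

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    // decodeVoter splits a voter blob: the first 8 bytes are the big-endian
    // voting power, the remainder is the voter's public key.
    func decodeVoter(b []byte) (power int64, pubKey []byte, err error) {
        if len(b) < 8 {
            return 0, nil, fmt.Errorf("invalid length for voter (%d)", len(b))
        }
        return int64(binary.BigEndian.Uint64(b[:8])), b[8:], nil
    }

    func main() {
        blob := make([]byte, 8)
        binary.BigEndian.PutUint64(blob, 100) // power = 100
        blob = append(blob, 0xAB, 0xCD)       // hypothetical 2-byte "pubkey"
        power, pub, err := decodeVoter(blob)
        if err != nil {
            panic(err)
        }
        fmt.Println(power, pub) // 100 [171 205]
    }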