From c1a1430e0d1d4f276a5c7c977f94d16254ad7884 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Mon, 20 May 2024 16:28:12 +0200 Subject: [PATCH] Support IN keyword for [][]byte in sql builder. Extract args from transaction. Generate valid tx for tx generator. Add tests for list handler --- api/grpcserver/v2alpha1/transaction.go | 79 +++++++++++++++-- api/grpcserver/v2alpha1/transaction_test.go | 93 +++++++++++++++++++++ common/fixture/transaction_results.go | 19 ++++- sql/builder/builder.go | 56 ++++++++++--- 4 files changed, 225 insertions(+), 22 deletions(-) diff --git a/api/grpcserver/v2alpha1/transaction.go b/api/grpcserver/v2alpha1/transaction.go index 35ce083ea4..0161781885 100644 --- a/api/grpcserver/v2alpha1/transaction.go +++ b/api/grpcserver/v2alpha1/transaction.go @@ -1,9 +1,14 @@ package v2alpha1 import ( + "bytes" "context" "errors" "fmt" + "github.com/spacemeshos/go-scale" + "github.com/spacemeshos/go-spacemesh/genvm/registry" + "github.com/spacemeshos/go-spacemesh/genvm/templates/vault" + "github.com/spacemeshos/go-spacemesh/genvm/templates/vesting" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" spacemeshv2alpha1 "github.com/spacemeshos/api/release/go/spacemesh/v2alpha1" @@ -245,12 +250,20 @@ func toTransactionOperations(filter *spacemeshv2alpha1.TransactionRequest) (buil return builder.Operations{}, err } ops.Filter = append(ops.Filter, builder.Op{ - Field: builder.Address, + Field: builder.Principal, Token: builder.Eq, Value: addr.Bytes(), }) } + if len(filter.Txid) > 0 { + ops.Filter = append(ops.Filter, builder.Op{ + Field: builder.Id, + Token: builder.In, + Value: filter.Txid, + }) + } + if filter.StartLayer != nil { ops.Filter = append(ops.Filter, builder.Op{ Field: builder.Layer, @@ -267,6 +280,10 @@ func toTransactionOperations(filter *spacemeshv2alpha1.TransactionRequest) (buil }) } + if len(ops.Filter) > 0 { + ops.StartWith = "and" + } + ops.Modifiers = append(ops.Modifiers, builder.Modifier{ Key: builder.OrderBy, Value: "layer asc, id", @@ -284,7 +301,6 @@ func toTransactionOperations(filter *spacemeshv2alpha1.TransactionRequest) (buil Value: int64(filter.Offset), }) } - return ops, nil } @@ -307,22 +323,20 @@ func (s *TransactionService) toTx(tx *types.MeshTransaction, result *types.Trans t.GasPrice = tx.GasPrice t.MaxSpend = tx.MaxSpend t.Contents = &spacemeshv2alpha1.TransactionContents{} - - req := s.conState.Validation(tx.GetRaw()) - _, _ = req.Parse() + txArgs, _ := decodeTxArgs(scale.NewDecoder(bytes.NewReader(tx.Raw))) switch tx.Method { case core.MethodSpawn: switch tx.TxHeader.TemplateAddress { case wallet.TemplateAddress: - args := req.Args().(*wallet.SpawnArguments) + args := txArgs.(*wallet.SpawnArguments) t.Contents.Contents = &spacemeshv2alpha1.TransactionContents_SingleSigSpawn{ SingleSigSpawn: &spacemeshv2alpha1.ContentsSingleSigSpawn{ Pubkey: args.PublicKey.String(), }, } case multisig.TemplateAddress: - args := req.Args().(*multisig.SpawnArguments) + args := txArgs.(*multisig.SpawnArguments) contents := &spacemeshv2alpha1.TransactionContents_MultiSigSpawn{ MultiSigSpawn: &spacemeshv2alpha1.ContentsMultiSigSpawn{ Required: uint32(args.Required), @@ -335,7 +349,7 @@ func (s *TransactionService) toTx(tx *types.MeshTransaction, result *types.Trans t.Contents.Contents = contents } case core.MethodSpend: - args := req.Args().(*wallet.SpendArguments) + args := txArgs.(*wallet.SpendArguments) t.Contents.Contents = &spacemeshv2alpha1.TransactionContents_Send{ Send: &spacemeshv2alpha1.ContentsSend{ Destination: args.Destination.String(), @@ -393,3 +407,52 @@ func convertTxState(tx *types.MeshTransaction) spacemeshv2alpha1.TransactionStat return spacemeshv2alpha1.TransactionState_TRANSACTION_STATE_UNSPECIFIED } } + +func decodeTxArgs(decoder *scale.Decoder) (scale.Encodable, error) { + reg := registry.New() + wallet.Register(reg) + multisig.Register(reg) + vesting.Register(reg) + vault.Register(reg) + + _, _, err := scale.DecodeCompact8(decoder) + if err != nil { + return nil, fmt.Errorf("%w: failed to decode version %w", core.ErrMalformed, err) + } + + var principal core.Address + if _, err := principal.DecodeScale(decoder); err != nil { + return nil, fmt.Errorf("%w failed to decode principal: %w", core.ErrMalformed, err) + } + + method, _, err := scale.DecodeCompact8(decoder) + if err != nil { + return nil, fmt.Errorf("%w: failed to decode method selector %w", core.ErrMalformed, err) + } + + var handler core.Handler + templateAddress := &core.Address{} + if _, err := templateAddress.DecodeScale(decoder); err != nil { + return nil, fmt.Errorf("%w failed to decode template address %w", core.ErrMalformed, err) + } + + handler = reg.Get(*templateAddress) + if handler == nil { + return nil, fmt.Errorf("%w: unknown template %s", core.ErrMalformed, *templateAddress) + } + + var p core.Payload + if _, err = p.DecodeScale(decoder); err != nil { + return nil, fmt.Errorf("%w: %w", core.ErrMalformed, err) + } + + args := handler.Args(method) + if args == nil { + return nil, fmt.Errorf("%w: unknown method %s %d", core.ErrMalformed, *templateAddress, method) + } + if _, err := args.DecodeScale(decoder); err != nil { + return nil, fmt.Errorf("%w failed to decode method arguments %w", core.ErrMalformed, err) + } + + return args, nil +} diff --git a/api/grpcserver/v2alpha1/transaction_test.go b/api/grpcserver/v2alpha1/transaction_test.go index 10ba4d2e5b..05145d16a6 100644 --- a/api/grpcserver/v2alpha1/transaction_test.go +++ b/api/grpcserver/v2alpha1/transaction_test.go @@ -2,8 +2,11 @@ package v2alpha1 import ( "context" + "github.com/spacemeshos/go-spacemesh/common/fixture" + "github.com/spacemeshos/go-spacemesh/sql/transactions" "math/rand" "testing" + "time" "github.com/oasisprotocol/curve25519-voi/primitives/ed25519" spacemeshv2alpha1 "github.com/spacemeshos/api/release/go/spacemesh/v2alpha1" @@ -20,6 +23,96 @@ import ( "github.com/spacemeshos/go-spacemesh/txs" ) +func TestTransactionService_List(t *testing.T) { + types.SetLayersPerEpoch(5) + db := sql.InMemory() + ctx := context.Background() + + gen := fixture.NewTransactionResultGenerator().WithAddresses(2) + txsList := make([]types.TransactionWithResult, 100) + require.NoError(t, db.WithTx(ctx, func(dtx *sql.Tx) error { + for i := range txsList { + tx := gen.Next() + + require.NoError(t, transactions.Add(dtx, &tx.Transaction, time.Time{})) + require.NoError(t, transactions.AddResult(dtx, tx.ID, &tx.TransactionResult)) + txsList[i] = *tx + } + return nil + })) + + svc := NewTransactionService(db, nil, nil, nil, nil) + cfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) + + conn := dialGrpc(ctx, t, cfg) + client := spacemeshv2alpha1.NewTransactionServiceClient(conn) + + t.Run("limit set too high", func(t *testing.T) { + _, err := client.List(ctx, &spacemeshv2alpha1.TransactionRequest{Limit: 200}) + require.Error(t, err) + + s, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.InvalidArgument, s.Code()) + require.Equal(t, "limit is capped at 100", s.Message()) + }) + + t.Run("no limit set", func(t *testing.T) { + _, err := client.List(ctx, &spacemeshv2alpha1.TransactionRequest{}) + require.Error(t, err) + + s, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.InvalidArgument, s.Code()) + require.Equal(t, "limit must be set to <= 100", s.Message()) + }) + + t.Run("limit and offset", func(t *testing.T) { + list, err := client.List(ctx, &spacemeshv2alpha1.TransactionRequest{Limit: 25, Offset: 50}) + require.NoError(t, err) + require.Len(t, list.Transactions, 25) + }) + + t.Run("all", func(t *testing.T) { + list, err := client.List(ctx, &spacemeshv2alpha1.TransactionRequest{Limit: 100}) + require.NoError(t, err) + require.Len(t, list.Transactions, len(txsList)) + }) + + t.Run("principal", func(t *testing.T) { + principal := txsList[0].TxHeader.Principal.String() + list, err := client.List(ctx, &spacemeshv2alpha1.TransactionRequest{ + Principal: &principal, + Limit: 1, + }) + require.NoError(t, err) + require.Len(t, list.Transactions, 1) + require.Equal(t, list.Transactions[0].Tx.GetV1().Principal, principal) + }) + + t.Run("tx id", func(t *testing.T) { + list, err := client.List(ctx, &spacemeshv2alpha1.TransactionRequest{ + Txid: [][]byte{txsList[0].ID[:]}, + Limit: 100, + }) + require.NoError(t, err) + require.Len(t, list.Transactions, 1) + require.Equal(t, list.Transactions[0].Tx.GetV1().Id, txsList[0].ID[:]) + }) + + t.Run("multiple tx ids", func(t *testing.T) { + list, err := client.List(ctx, &spacemeshv2alpha1.TransactionRequest{ + Txid: [][]byte{txsList[0].ID[:], txsList[1].ID[:]}, + Limit: 100, + }) + require.NoError(t, err) + require.Len(t, list.Transactions, 2) + require.Equal(t, list.Transactions[0].Tx.GetV1().Id, txsList[0].ID[:]) + require.Equal(t, list.Transactions[1].Tx.GetV1().Id, txsList[1].ID[:]) + }) +} + func TestTransactionService_EstimateGas(t *testing.T) { types.SetLayersPerEpoch(5) db := sql.InMemory() diff --git a/common/fixture/transaction_results.go b/common/fixture/transaction_results.go index 6e250f22ef..97dff4cc3b 100644 --- a/common/fixture/transaction_results.go +++ b/common/fixture/transaction_results.go @@ -2,6 +2,10 @@ package fixture import ( "encoding/binary" + "github.com/oasisprotocol/curve25519-voi/primitives/ed25519" + "github.com/spacemeshos/go-spacemesh/genvm/core" + wallet2 "github.com/spacemeshos/go-spacemesh/genvm/sdk/wallet" + "github.com/spacemeshos/go-spacemesh/genvm/templates/wallet" "math/rand" "time" @@ -66,10 +70,21 @@ func (g *TransactionResultGenerator) WithLayers(start, n int) *TransactionResult func (g *TransactionResultGenerator) Next() *types.TransactionWithResult { var tx types.TransactionWithResult g.rng.Read(tx.ID[:]) - tx.Raw = make([]byte, 10) - g.rng.Read(tx.Raw) + + _, priv, _ := ed25519.GenerateKey(g.rng) + spawnTx := wallet2.SelfSpawn(priv, types.Nonce(1)) + rawTx := types.NewRawTx(spawnTx) + tx.RawTx = rawTx + tx.Block = g.Blocks[g.rng.Intn(len(g.Blocks))] tx.Layer = g.Layers[g.rng.Intn(len(g.Layers))] + tx.TxHeader = &types.TxHeader{ + TemplateAddress: wallet.TemplateAddress, + Method: core.MethodSpawn, + Principal: g.Addrs[g.rng.Intn(len(g.Addrs))], + Nonce: types.Nonce(1), + } + if lth := g.rng.Intn(len(g.Addrs)); lth > 0 { tx.Addresses = make([]types.Address, lth%10+1) diff --git a/sql/builder/builder.go b/sql/builder/builder.go index 93ca8c4218..3ad0d9dc43 100644 --- a/sql/builder/builder.go +++ b/sql/builder/builder.go @@ -17,17 +17,19 @@ const ( Gte token = ">=" Lt token = "<" Lte token = "<=" + In token = "in" ) type field string const ( - Epoch field = "epoch" - Smesher field = "pubkey" - Coinbase field = "coinbase" - Id field = "id" - Layer field = "layer" - Address field = "address" + Epoch field = "epoch" + Smesher field = "pubkey" + Coinbase field = "coinbase" + Id field = "id" + Layer field = "layer" + Address field = "address" + Principal field = "principal" ) type modifier string @@ -58,6 +60,7 @@ type Modifier struct { type Operations struct { Filter []Op Modifiers []Modifier + StartWith string } func FilterEpochOnly(publish types.EpochID) Operations { @@ -71,13 +74,33 @@ func FilterEpochOnly(publish types.EpochID) Operations { func FilterFrom(operations Operations) string { var queryBuilder strings.Builder + bindIndex := 1 for i, op := range operations.Filter { if i == 0 { - queryBuilder.WriteString(" where") + if operations.StartWith != "" { + queryBuilder.WriteString(" " + operations.StartWith) + } else { + queryBuilder.WriteString(" where") + } } else { queryBuilder.WriteString(" and") } - fmt.Fprintf(&queryBuilder, " %s%s %s ?%d", op.Prefix, op.Field, op.Token, i+1) + + if op.Token == In { + values, ok := op.Value.([][]byte) + if !ok { + panic("value for 'In' token must be a slice of []byte") + } + placeholders := make([]string, len(values)) + for j := range values { + placeholders[j] = fmt.Sprintf("?%d", bindIndex) + bindIndex++ + } + fmt.Fprintf(&queryBuilder, " %s%s %s (%s)", op.Prefix, op.Field, op.Token, strings.Join(placeholders, ", ")) + } else { + fmt.Fprintf(&queryBuilder, " %s%s %s ?%d", op.Prefix, op.Field, op.Token, bindIndex) + bindIndex++ + } } for _, m := range operations.Modifiers { @@ -89,14 +112,23 @@ func FilterFrom(operations Operations) string { func BindingsFrom(operations Operations) sql.Encoder { return func(stmt *sql.Statement) { - for i, op := range operations.Filter { + bindIndex := 1 + for _, op := range operations.Filter { switch value := op.Value.(type) { case int64: - stmt.BindInt64(i+1, value) + stmt.BindInt64(bindIndex, value) + bindIndex++ case []byte: - stmt.BindBytes(i+1, value) + stmt.BindBytes(bindIndex, value) + bindIndex++ case types.EpochID: - stmt.BindInt64(i+1, int64(value)) + stmt.BindInt64(bindIndex, int64(value)) + bindIndex++ + case [][]byte: + for _, v := range value { + stmt.BindBytes(bindIndex, v) + bindIndex++ + } default: panic(fmt.Sprintf("unexpected type %T", value)) }