Skip to content

Commit

Permalink
Height var (#764)
Browse files Browse the repository at this point in the history
* Adds @height and contextual variables tests.

This commit adds an @height variable at the request of the Truflation team.
It also adds an integration test for testing contextual variables.

Finally, it fixes two bugs. The first one caused SQL generated for actions
to not use the `current_setting` postgres function, which caused it to
incorrectly read contextual variables. The second bug caused postgres to
incorrectly decode the transaction signer, essentially yielding an incorrect
signer for the @signer variable.

fix lint

* made json util more flexible

* fixed ci
  • Loading branch information
brennanjl authored May 24, 2024
1 parent b646ff5 commit f2a2dc3
Show file tree
Hide file tree
Showing 21 changed files with 314 additions and 81 deletions.
3 changes: 3 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ type TransactionData struct {

// TxID is the transaction ID of the incoming transaction.
TxID string

// Height is the block height of the incoming transaction.
Height int64
}

// ExecutionOptions is contextual data that is passed to a procedure
Expand Down
33 changes: 3 additions & 30 deletions core/rpc/client/user/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package http
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
Expand All @@ -18,6 +17,7 @@ import (
httpTx "github.com/kwilteam/kwil-db/core/rpc/http/tx"
"github.com/kwilteam/kwil-db/core/types"
"github.com/kwilteam/kwil-db/core/types/transactions"
jsonUtil "github.com/kwilteam/kwil-db/core/utils/json"

"github.com/antihax/optional"
)
Expand Down Expand Up @@ -113,7 +113,7 @@ func (c *Client) Call(ctx context.Context, msg *transactions.CallMessage, opts .
return nil, err
}

return unmarshalMapResults(decodedResult)
return jsonUtil.UnmarshalMapWithoutFloat(decodedResult)
}

func (c *Client) ChainInfo(ctx context.Context) (*types.ChainInfo, error) {
Expand Down Expand Up @@ -249,7 +249,7 @@ func (c *Client) Query(ctx context.Context, dbid string, query string) ([]map[st
return nil, err
}

return unmarshalMapResults(decodedResult)
return jsonUtil.UnmarshalMapWithoutFloat(decodedResult)
}

func (c *Client) TxQuery(ctx context.Context, txHash []byte) (*transactions.TcTxQueryResponse, error) {
Expand Down Expand Up @@ -288,30 +288,3 @@ func (c *Client) TxQuery(ctx context.Context, txHash []byte) (*transactions.TcTx
TxResult: *convertedTxResult,
}, nil
}

func unmarshalMapResults(b []byte) ([]map[string]any, error) {
d := json.NewDecoder(strings.NewReader(string(b)))
d.UseNumber()

// unmashal result
var result []map[string]any
err := d.Decode(&result)
if err != nil {
return nil, err
}

// convert numbers to int64
for _, record := range result {
for k, v := range record {
if num, ok := v.(json.Number); ok {
i, err := num.Int64()
if err != nil {
return nil, err
}
record[k] = i
}
}
}

return result, nil
}
45 changes: 3 additions & 42 deletions core/rpc/client/user/jsonrpc/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
"fmt"
"math/big"
"net/url"
"strings"

rpcclient "github.com/kwilteam/kwil-db/core/rpc/client"
"github.com/kwilteam/kwil-db/core/rpc/client/user"
jsonrpc "github.com/kwilteam/kwil-db/core/rpc/json"
"github.com/kwilteam/kwil-db/core/types"
"github.com/kwilteam/kwil-db/core/types/transactions"
jsonUtil "github.com/kwilteam/kwil-db/core/utils/json"
)

// Client is a JSON-RPC client for the Kwil user service. It use the JSONRPCClient
Expand Down Expand Up @@ -79,45 +79,6 @@ func (cl *Client) Broadcast(ctx context.Context, tx *transactions.Transaction, s
return res.TxHash, nil
}

func unmarshalMapResults(b []byte) ([]map[string]any, error) {
d := json.NewDecoder(strings.NewReader(string(b)))
d.UseNumber()

// unmashal result
var result []map[string]any
err := d.Decode(&result)
if err != nil {
return nil, err
}

// convert numbers to int64
for _, record := range result {
for k, v := range record {
if num, ok := v.(json.Number); ok {
i, err := num.Int64()
if err != nil {
record[k] = num.String()
} else {
record[k] = i
}
} else if num, ok := v.([]any); ok {
for j, n := range num {
if n, ok := n.(json.Number); ok {
i, err := n.Int64()
if err != nil {
num[j] = n.String()
} else {
num[j] = i
}
}
}
}
}
}

return result, nil
}

func (cl *Client) Call(ctx context.Context, msg *transactions.CallMessage, opts ...rpcclient.ActionCallOption) ([]map[string]any, error) {
cmd := &jsonrpc.CallRequest{
Body: msg.Body,
Expand All @@ -129,7 +90,7 @@ func (cl *Client) Call(ctx context.Context, msg *transactions.CallMessage, opts
if err != nil {
return nil, err
}
return unmarshalMapResults(res.Result)
return jsonUtil.UnmarshalMapWithoutFloat(res.Result)
}

func (cl *Client) ChainInfo(ctx context.Context) (*types.ChainInfo, error) {
Expand Down Expand Up @@ -230,7 +191,7 @@ func (cl *Client) Query(ctx context.Context, dbid, query string) ([]map[string]a
if err != nil {
return nil, err
}
return unmarshalMapResults(res.Result)
return jsonUtil.UnmarshalMapWithoutFloat(res.Result)
}

func (cl *Client) TxQuery(ctx context.Context, txHash []byte) (*transactions.TcTxQueryResponse, error) {
Expand Down
75 changes: 75 additions & 0 deletions core/utils/json/json.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// package json includes JSON utilities commonly used in Kwil.
package json

import (
"encoding/json"
"reflect"
"strings"
)

// UnmarshalMapWithoutFloat unmarshals a JSON byte slice into a slice of maps.
// It will try to convert all return values into ints, but will keep them as strings if it fails.
// It ensures they aren't returned as floats, which is important for maintaining consistency
// with Kwil's decimal types. All returned types will be string or int64.
func UnmarshalMapWithoutFloat(b []byte) ([]map[string]any, error) {
d := json.NewDecoder(strings.NewReader(string(b)))
d.UseNumber()

// unmashal result
var result []map[string]any
err := d.Decode(&result)
if err != nil {
return nil, err
}

// convert numbers to int64
result = convertJsonNumbers(result).([]map[string]any)

return result, nil
}

// convertJsonNumbers recursively converts json.Number to int64.
// It traverses through the map and array and converts all json.Number to int64.
func convertJsonNumbers(val any) any {
if val == nil {
return nil
}
switch val := val.(type) {
case map[string]any:
for k, v := range val {
val[k] = convertJsonNumbers(v)
}
return val
case []map[string]any:
for i, v := range val {
for j, n := range v {
v[j] = convertJsonNumbers(n)
}
val[i] = v
}
return val
case []any:
for i, v := range val {
val[i] = convertJsonNumbers(v)
}
return val
case json.Number:
i, err := val.Int64()
if err != nil {
return val.String()
}
return i
default:
// in case we are unmarshalling something crazy like a double nested slice,
// we reflect on the value and recursively call convertJsonNumbers if it's a slice.
typeOf := reflect.TypeOf(val)
if typeOf.Kind() == reflect.Slice {
s := reflect.ValueOf(val)
for i := 0; i < s.Len(); i++ {
s.Index(i).Set(reflect.ValueOf(convertJsonNumbers(s.Index(i).Interface())))
}
return s.Interface()
}
return val
}
}
64 changes: 64 additions & 0 deletions core/utils/json/json_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package json

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/require"
)

func Test_convertJsonNumbers(t *testing.T) {

tests := []struct {
name string
val any
want any
}{
{
name: "number",
val: json.Number("123"),
want: int64(123),
},
{
name: "object",
val: map[string]any{
"key": json.Number("123"),
"val": []map[string]any{
{
"key": json.Number("123"),
},
{
"key": json.Number("456"),
"val": []map[string]any{
{
"key": json.Number("789"),
},
},
},
},
},
want: map[string]any{
"key": int64(123),
"val": []map[string]any{
{
"key": int64(123),
},
{
"key": int64(456),
"val": []map[string]any{
{
"key": int64(789),
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := convertJsonNumbers(tt.val)
require.EqualValues(t, tt.want, a)
})
}
}
4 changes: 4 additions & 0 deletions extensions/precompiles/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ type ProcedureContext struct {
Procedure string
// Result is the result of the most recent SQL query.
Result *sql.ResultSet
// Height is the block height of the current execution.
Height int64

// StackDepth tracks the current depth of the procedure call stack. It is
// incremented each time a procedure calls another procedure.
Expand Down Expand Up @@ -104,6 +106,7 @@ func (p *ProcedureContext) Values() map[string]any {
values["@caller"] = p.Caller
values["@txid"] = p.TxID
values["@signer"] = p.Signer
values["@height"] = p.Height

return values
}
Expand All @@ -122,6 +125,7 @@ func (p *ProcedureContext) NewScope() *ProcedureContext {
Procedure: p.Procedure,
StackDepth: p.StackDepth,
UsedGas: p.UsedGas,
Height: p.Height,
}
}

Expand Down
1 change: 1 addition & 0 deletions internal/engine/execution/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ func (g *GlobalContext) Procedure(ctx context.Context, tx sql.DB, options *commo
DBID: options.Dataset,
Procedure: options.Procedure,
TxID: options.TxID,
Height: options.Height,
// starting with stack depth 0, increment in each action call
}

Expand Down
9 changes: 7 additions & 2 deletions internal/engine/execution/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package execution

import (
"context"
"encoding/hex"
"encoding/base64"
"encoding/json"
"fmt"

Expand Down Expand Up @@ -334,7 +334,12 @@ func setContextualVars(ctx context.Context, db sql.DB, data *common.ExecutionDat
return err
}

_, err = db.Execute(ctx, fmt.Sprintf(`SET LOCAL %s.%s = '%s';`, generate.PgSessionPrefix, parse.SignerVar, hex.EncodeToString(data.Signer)))
_, err = db.Execute(ctx, fmt.Sprintf(`SET LOCAL %s.%s = '%s';`, generate.PgSessionPrefix, parse.SignerVar, base64.StdEncoding.EncodeToString(data.Signer)))
if err != nil {
return err
}

_, err = db.Execute(ctx, fmt.Sprintf(`SET LOCAL %s.%s = %d;`, generate.PgSessionPrefix, parse.HeightVar, data.Height))
if err != nil {
return err
}
Expand Down
6 changes: 4 additions & 2 deletions internal/engine/generate/plpgsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ func (s *sqlGenerator) VisitExpressionForeignCall(p0 *parse.ExpressionForeignCal
}

func (s *sqlGenerator) VisitExpressionVariable(p0 *parse.ExpressionVariable) any {
if s.numberParameters {
// if a user param $, then we need to number it.
// Vars using @ get set and accessed using postgres's current_setting function
if s.numberParameters && p0.Prefix == parse.VariablePrefixDollar {
str := p0.String()

// if it already exists, we write it as that index.
Expand Down Expand Up @@ -978,7 +980,7 @@ func formatContextualVariableName(name string) string {

switch dataType {
case types.BlobType:
return fmt.Sprintf("%s::bytea", str)
return fmt.Sprintf("decode(%s, 'base64')", str)
case types.IntType:
return fmt.Sprintf("%s::int8", str)
case types.BoolType:
Expand Down
1 change: 1 addition & 0 deletions internal/services/grpc/txsvc/v1/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func (s *Service) Call(ctx context.Context, req *txpb.CallRequest) (*txpb.CallRe
TransactionData: common.TransactionData{
Signer: signer,
Caller: caller,
Height: -1, // not available
},
})
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions internal/services/jsonrpc/usersvc/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ func (svc *Service) Call(ctx context.Context, req *jsonrpc.CallRequest) (*jsonrp
TransactionData: common.TransactionData{
Signer: signer,
Caller: caller,
Height: -1, // not available
},
})
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions internal/txapp/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ func (e *executeActionRoute) Execute(ctx TxContext, router *TxApp, tx *transacti
Signer: tx.Sender,
Caller: identifier,
TxID: hex.EncodeToString(ctx.TxID),
Height: ctx.BlockHeight,
},
})
if err != nil {
Expand Down
Loading

0 comments on commit f2a2dc3

Please sign in to comment.