Skip to content

Commit

Permalink
added crypto functions, encode/decode (#772)
Browse files Browse the repository at this point in the history
  • Loading branch information
brennanjl authored May 28, 2024
1 parent 8e0c11b commit ae11433
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 0 deletions.
14 changes: 14 additions & 0 deletions internal/engine/integration/procedure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ package integration_test

import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"strings"
"testing"

"github.com/kwilteam/kwil-db/common"
"github.com/kwilteam/kwil-db/core/crypto"
"github.com/kwilteam/kwil-db/parse"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -189,6 +192,17 @@ func Test_Procedures(t *testing.T) {
}`,
outputs: [][]any{{int64(0)}},
},
{
name: "encode, decode, and digest functions",
procedure: `procedure encode_decode_digest($hex text) public view returns (encoded text, decoded blob, digest blob) {
$decoded := decode($hex, 'hex');
$encoded := encode($decoded, 'base64');
$digest := digest($decoded, 'sha256');
return $encoded, $decoded, $digest;
}`,
inputs: []any{hex.EncodeToString([]byte("hello"))},
outputs: [][]any{{base64.StdEncoding.EncodeToString([]byte("hello")), []byte("hello"), crypto.Sha256([]byte("hello"))}},
},
}

for _, test := range tests {
Expand Down
4 changes: 4 additions & 0 deletions internal/sql/pg/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ func NewDB(ctx context.Context, cfg *DBConfig) (*DB, error) {
return nil, fmt.Errorf("failed to create UUID extension: %w", err)
}

if err = ensurePgCryptoExtension(ctx, conn); err != nil {
return nil, fmt.Errorf("failed to create pgcrypto extension: %w", err)
}

okSchema := cfg.SchemaFilter
if okSchema == nil {
okSchema = defaultSchemaFilter
Expand Down
5 changes: 5 additions & 0 deletions internal/sql/pg/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ func ensureUUIDExtension(ctx context.Context, conn *pgx.Conn) error {
return err
}

func ensurePgCryptoExtension(ctx context.Context, conn *pgx.Conn) error {
_, err := conn.Exec(ctx, `CREATE EXTENSION IF NOT EXISTS pgcrypto;`)
return err
}

func ensureUint256Domain(ctx context.Context, conn *pgx.Conn) error {
_, err := conn.Exec(ctx, sqlCreateUint256Domain)
return err
Expand Down
84 changes: 84 additions & 0 deletions parse/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,90 @@ var (
return fmt.Sprintf("uuid_generate_v5(%s)", strings.Join(inputs, ", ")), nil
},
},
"encode": {
ValidateArgs: func(args []*types.DataType) (*types.DataType, error) {
// first must be blob, second must be text
if len(args) != 2 {
return nil, wrapErrArgumentNumber(2, len(args))
}

if !args[0].EqualsStrict(types.BlobType) {
return nil, wrapErrArgumentType(types.BlobType, args[0])
}

if !args[1].EqualsStrict(types.TextType) {
return nil, wrapErrArgumentType(types.TextType, args[1])
}

return types.TextType, nil
},
PGFormat: func(inputs []string, distinct bool, star bool) (string, error) {
if star {
return "", errStar("encode")
}
if distinct {
return "", errDistinct("encode")
}

return fmt.Sprintf("encode(%s)", strings.Join(inputs, ", ")), nil
},
},
"decode": {
ValidateArgs: func(args []*types.DataType) (*types.DataType, error) {
// first must be text, second must be text
if len(args) != 2 {
return nil, wrapErrArgumentNumber(2, len(args))
}

if !args[0].EqualsStrict(types.TextType) {
return nil, wrapErrArgumentType(types.TextType, args[0])
}

if !args[1].EqualsStrict(types.TextType) {
return nil, wrapErrArgumentType(types.TextType, args[1])
}

return types.BlobType, nil
},
PGFormat: func(inputs []string, distinct bool, star bool) (string, error) {
if star {
return "", errStar("decode")
}
if distinct {
return "", errDistinct("decode")
}

return fmt.Sprintf("decode(%s)", strings.Join(inputs, ", ")), nil
},
},
"digest": {
ValidateArgs: func(args []*types.DataType) (*types.DataType, error) {
// first must be either text or blob, second must be text
if len(args) != 2 {
return nil, wrapErrArgumentNumber(2, len(args))
}

if !args[0].EqualsStrict(types.TextType) && !args[0].EqualsStrict(types.BlobType) {
return nil, fmt.Errorf("expected first argument to be text or blob, got %s", args[0].String())
}

if !args[1].EqualsStrict(types.TextType) {
return nil, wrapErrArgumentType(types.TextType, args[1])
}

return types.BlobType, nil
},
PGFormat: func(inputs []string, distinct bool, star bool) (string, error) {
if star {
return "", errStar("digest")
}
if distinct {
return "", errDistinct("digest")
}

return fmt.Sprintf("digest(%s)", strings.Join(inputs, ", ")), nil
},
},
// array functions
"array_append": {
ValidateArgs: func(args []*types.DataType) (*types.DataType, error) {
Expand Down

0 comments on commit ae11433

Please sign in to comment.