Skip to content

Commit

Permalink
added min and max (#779)
Browse files Browse the repository at this point in the history
  • Loading branch information
brennanjl authored May 30, 2024
1 parent f8a7dd5 commit 1661f0f
Show file tree
Hide file tree
Showing 10 changed files with 2,026 additions and 2,489 deletions.
43 changes: 35 additions & 8 deletions internal/engine/integration/procedure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ func Test_Procedures(t *testing.T) {
tests := []testcase{
{
name: "basic test",
procedure: `procedure create_user2($name text) public {
INSERT INTO users (id, name, wallet_address)
procedure: `procedure create_user2($name text, $usernum int) public {
INSERT INTO users (id, name, wallet_address, user_num)
VALUES (uuid_generate_v5('985b93a4-2045-44d6-bde4-442a4e498bc6'::uuid, @txid),
$name,
@caller
@caller,
$usernum
);
}`,
inputs: []any{"test_user"},
inputs: []any{"test_user", 4},
},
{
name: "for loop",
Expand Down Expand Up @@ -240,6 +241,21 @@ func Test_Procedures(t *testing.T) {
}
}`,
},
{
name: "min/max",
procedure: `procedure min_max() public view returns (min int, max int) {
$max := 0;
for $row in select max(user_num) as m from users {
$max := $row.m;
}
$min := 0;
for $row2 in select min(user_num) as m from users {
$min := $row2.m;
}
return $min, $max;
}`,
outputs: [][]any{{int64(1), int64(3)}},
},
}

for _, test := range tests {
Expand Down Expand Up @@ -486,7 +502,8 @@ database ecclesia;
table users {
id uuid primary key,
name text not null maxlen(100) minlen(4) unique,
wallet_address text not null
wallet_address text not null,
user_num int unique notnull // this could be the primary key, but it's more for testing than to be useful
}
table posts {
Expand All @@ -497,10 +514,20 @@ table posts {
}
procedure create_user($name text) public {
INSERT INTO users (id, name, wallet_address)
$max int;
for $row in select max(user_num) as m from users {
$max := $row.m;
}
if $max is null {
$max := 0;
}
INSERT INTO users (id, name, wallet_address, user_num)
VALUES (uuid_generate_v5('985b93a4-2045-44d6-bde4-442a4e498bc6'::uuid, @txid),
$name,
@caller
@caller,
$max + 1
);
}
Expand Down Expand Up @@ -540,7 +567,7 @@ procedure delete_users() public {
}
procedure get_users() public returns table(id uuid, name text, wallet_address text) {
return SELECT * FROM users;
return SELECT id, name, wallet_address FROM users;
}
`

Expand Down
175 changes: 117 additions & 58 deletions parse/antlr.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,18 +448,130 @@ func (s *schemaVisitor) VisitTable_declaration(ctx *gen.Table_declarationContext

func (s *schemaVisitor) VisitColumn_def(ctx *gen.Column_defContext) any {
col := &types.Column{
Name: s.getIdent(ctx.IDENTIFIER()),
Type: ctx.Type_().Accept(s).(*types.DataType),
Attributes: arr[*types.Attribute](len(ctx.AllConstraint())),
Name: s.getIdent(ctx.IDENTIFIER()),
Type: ctx.Type_().Accept(s).(*types.DataType),
}

// due to unfortunate lexing edge cases to support min/max, we
// have to parse the constraints here. Each constraint is a text, and should be
// one of:
// MIN/MAX/MINLEN/MAXLEN/MIN_LENGTH/MAX_LENGTH/NOTNULL/NOT/NULL/PRIMARY/KEY/PRIMARY_KEY/PK/DEFAULT/UNIQUE
// If NOT is present, it needs to be followed by NULL; similarly, if NULL is present, it needs to be preceded by NOT.
// If PRIMARY is present, it can be followed by key, but does not have to be. key must be preceded by primary.
// MIN, MAX, MINLEN, MAXLEN, MIN_LENGTH, MAX_LENGTH, and DEFAULT must also have a literal following them.
type constraint struct {
ident string
lit *string
}
constraints := make([]constraint, len(ctx.AllConstraint()))
for i, c := range ctx.AllConstraint() {
con := constraint{}
switch {
case c.IDENTIFIER() != nil:
con.ident = c.IDENTIFIER().GetText()
case c.PRIMARY() != nil:
con.ident = "primary_key"
case c.NOT() != nil:
con.ident = "notnull"
case c.DEFAULT() != nil:
con.ident = "default"
case c.UNIQUE() != nil:
con.ident = "unique"
default:
panic("unknown constraint")
}

if c.Literal() != nil {
l := strings.ToLower(c.Literal().Accept(s).(*ExpressionLiteral).String())
con.lit = &l
}
constraints[i] = con
}

for i, a := range ctx.AllConstraint() {
col.Attributes[i] = a.Accept(s).(*types.Attribute)
for i := 0; i < len(constraints); i++ {
switch constraints[i].ident {
case "min":
if constraints[i].lit == nil {
s.errs.RuleErr(ctx, ErrSyntax, "missing literal for min constraint")
return col
}
col.Attributes = append(col.Attributes, &types.Attribute{
Type: types.MIN,
Value: *constraints[i].lit,
})
case "max":
if constraints[i].lit == nil {
s.errs.RuleErr(ctx, ErrSyntax, "missing literal for max constraint")
return col
}
col.Attributes = append(col.Attributes, &types.Attribute{
Type: types.MAX,
Value: *constraints[i].lit,
})
case "minlen", "min_length":
if constraints[i].lit == nil {
s.errs.RuleErr(ctx, ErrSyntax, "missing literal for min length constraint")
return col
}
col.Attributes = append(col.Attributes, &types.Attribute{
Type: types.MIN_LENGTH,
Value: *constraints[i].lit,
})
case "maxlen", "max_length":
if constraints[i].lit == nil {
s.errs.RuleErr(ctx, ErrSyntax, "missing literal for max length constraint")
return col
}
col.Attributes = append(col.Attributes, &types.Attribute{
Type: types.MAX_LENGTH,
Value: *constraints[i].lit,
})
case "notnull":
if constraints[i].lit != nil {
s.errs.RuleErr(ctx, ErrSyntax, "unexpected literal for not null constraint")
return col
}
col.Attributes = append(col.Attributes, &types.Attribute{
Type: types.NOT_NULL,
})
case "primary_key", "pk":
if constraints[i].lit != nil {
s.errs.RuleErr(ctx, ErrSyntax, "unexpected literal for primary key constraint")
return col
}
col.Attributes = append(col.Attributes, &types.Attribute{
Type: types.PRIMARY_KEY,
})
case "default":
if constraints[i].lit == nil {
s.errs.RuleErr(ctx, ErrSyntax, "missing literal for default constraint")
return col
}
col.Attributes = append(col.Attributes, &types.Attribute{
Type: types.DEFAULT,
Value: *constraints[i].lit,
})
case "unique":
if constraints[i].lit != nil {
s.errs.RuleErr(ctx, ErrSyntax, "unexpected literal for unique constraint")
return col
}
col.Attributes = append(col.Attributes, &types.Attribute{
Type: types.UNIQUE,
})
default:
s.errs.RuleErr(ctx, ErrSyntax, "unknown constraint: %s", constraints[i].ident)
return col
}
}

return col
}

func (s *schemaVisitor) VisitConstraint(ctx *gen.ConstraintContext) any {
panic("VisitConstraint should not be called, as the logic should be implemented in VisitColumn_def")
}

func (s *schemaVisitor) VisitIndex_def(ctx *gen.Index_defContext) any {
name := ctx.HASH_IDENTIFIER().GetText()
name = strings.TrimLeft(name, "#")
Expand Down Expand Up @@ -562,59 +674,6 @@ func (s *schemaVisitor) VisitTyped_variable_list(ctx *gen.Typed_variable_listCon
return vars
}

func (s *schemaVisitor) VisitMin_constraint(ctx *gen.Min_constraintContext) any {
return &types.Attribute{
Type: types.MIN,
Value: ctx.Literal().Accept(s).(*ExpressionLiteral).String(),
}
}

func (s *schemaVisitor) VisitMax_constraint(ctx *gen.Max_constraintContext) any {
return &types.Attribute{
Type: types.MAX,
Value: ctx.Literal().Accept(s).(*ExpressionLiteral).String(),
}
}

func (s *schemaVisitor) VisitMin_len_constraint(ctx *gen.Min_len_constraintContext) any {
return &types.Attribute{
Type: types.MIN_LENGTH,
Value: ctx.Literal().Accept(s).(*ExpressionLiteral).String(),
}
}

func (s *schemaVisitor) VisitMax_len_constraint(ctx *gen.Max_len_constraintContext) any {
return &types.Attribute{
Type: types.MAX_LENGTH,
Value: ctx.Literal().Accept(s).(*ExpressionLiteral).String(),
}
}

func (s *schemaVisitor) VisitNot_null_constraint(ctx *gen.Not_null_constraintContext) any {
return &types.Attribute{
Type: types.NOT_NULL,
}
}

func (s *schemaVisitor) VisitPrimary_key_constraint(ctx *gen.Primary_key_constraintContext) any {
return &types.Attribute{
Type: types.PRIMARY_KEY,
}
}

func (s *schemaVisitor) VisitDefault_constraint(ctx *gen.Default_constraintContext) any {
return &types.Attribute{
Type: types.DEFAULT,
Value: ctx.Literal().Accept(s).(*ExpressionLiteral).String(),
}
}

func (s *schemaVisitor) VisitUnique_constraint(ctx *gen.Unique_constraintContext) any {
return &types.Attribute{
Type: types.UNIQUE,
}
}

func (s *schemaVisitor) VisitAccess_modifier(ctx *gen.Access_modifierContext) any {
// we will have to parse this at a later stage, since this is either public/private,
// or a types.Modifier
Expand Down
50 changes: 50 additions & 0 deletions parse/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,56 @@ var (
},
StarArgReturn: types.IntType,
},
"min": {
ValidateArgs: func(args []*types.DataType) (*types.DataType, error) {
// as per postgres docs, min can take any numeric or string type: https://www.postgresql.org/docs/8.0/functions-aggregate.html
if len(args) != 1 {
return nil, wrapErrArgumentNumber(1, len(args))
}

if !args[0].IsNumeric() && !args[0].EqualsStrict(types.TextType) {
return nil, fmt.Errorf("expected argument to be numeric or text, got %s", args[0].String())
}

return args[0], nil
},
IsAggregate: true,
PGFormat: func(inputs []string, distinct bool, star bool) (string, error) {
if star {
return "", errStar("min")
}
if distinct {
return "min(DISTINCT %s)", nil
}

return fmt.Sprintf("min(%s)", inputs[0]), nil
},
},
"max": {
ValidateArgs: func(args []*types.DataType) (*types.DataType, error) {
// as per postgres docs, max can take any numeric or string type: https://www.postgresql.org/docs/8.0/functions-aggregate.html
if len(args) != 1 {
return nil, wrapErrArgumentNumber(1, len(args))
}

if !args[0].IsNumeric() && !args[0].EqualsStrict(types.TextType) {
return nil, fmt.Errorf("expected argument to be numeric or text, got %s", args[0].String())
}

return args[0], nil
},
IsAggregate: true,
PGFormat: func(inputs []string, distinct bool, star bool) (string, error) {
if star {
return "", errStar("max")
}
if distinct {
return "max(DISTINCT %s)", nil
}

return fmt.Sprintf("max(%s)", inputs[0]), nil
},
},
}
)

Expand Down
Loading

0 comments on commit 1661f0f

Please sign in to comment.