Skip to content

Commit

Permalink
Fix null comparison (#767)
Browse files Browse the repository at this point in the history
* fixed bug where null comparison did not work correctly

added and/or and is distinct to procedures

* fix lint

* fixed IS DISTINCT FROM generation bug
  • Loading branch information
brennanjl authored May 24, 2024
1 parent df53a6a commit e1a80b7
Show file tree
Hide file tree
Showing 9 changed files with 975 additions and 465 deletions.
6 changes: 6 additions & 0 deletions internal/engine/execution/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ func (d *baseDataset) Call(caller *precompiles.ProcedureContext, app *common.App
// return nil, err
// }

// this is not a strictly necessary check, as postgres will throw an error, but this gives a more
// helpful error message
if len(inputs) != len(proc.parameters) {
return nil, fmt.Errorf(`procedure "%s" expects %d argument(s), got %d`, method, len(proc.parameters), len(inputs))
}

res, err := app.DB.Execute(caller.Ctx, proc.callString(d.schema.DBID()), append([]any{pg.QueryModeExec}, inputs...)...)
if err != nil {
return nil, err
Expand Down
3 changes: 2 additions & 1 deletion internal/engine/generate/plpgsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ func (s *sqlGenerator) VisitExpressionComparison(p0 *parse.ExpressionComparison)
str.WriteString(p0.Left.Accept(s).(string))
str.WriteString(" ")
str.WriteString(string(p0.Operator))
str.WriteString(" ")
str.WriteString(p0.Right.Accept(s).(string))
// compare cannot be typecasted
return str.String()
Expand Down Expand Up @@ -263,7 +264,7 @@ func (s *sqlGenerator) VisitExpressionIs(p0 *parse.ExpressionIs) any {
str.WriteString("NOT ")
}
if p0.Distinct {
str.WriteString("DISTINCT ")
str.WriteString("DISTINCT FROM ")
}
str.WriteString(p0.Right.Accept(s).(string))
// cannot be typecasted
Expand Down
45 changes: 45 additions & 0 deletions internal/engine/integration/procedure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,51 @@ func Test_Procedures(t *testing.T) {
inputs: []any{[]int64{1, 2, 3}},
outputs: [][]any{{[]any{int64(2), int64(4), int64(6)}}}, // returns 1 row, 1 column, with an array of ints
},
{
name: "is (null)",
procedure: `procedure is_null($val text) public view returns (is_null bool, is_null2 bool, is_null3 bool, is_null4 bool) {
$val2 := 1;
return $val is not distinct from null, $val2 is not distinct from null, $val is distinct from null, $val2 is distinct from null;
}`,
inputs: []any{nil},
outputs: [][]any{{true, false, false, true}},
},
{
name: "is (concrete)",
procedure: `procedure is_equal() public view returns (is_equal bool, is_equal2 bool, is_equal3 bool, is_equal4 bool) {
$val := 'hello';
return $val is not distinct from 'hello', $val is not distinct from 'world', $val is distinct from 'hello', $val is distinct from 'world';
}`,
outputs: [][]any{{true, false, false, true}},
},
{
name: "equals",
procedure: `procedure equals($val text) public view returns (is_equal bool, is_equal2 bool, is_equal3 bool, is_equal4 bool) {
$val2 text;
return $val = 'hello', $val = 'world', $val != null, $val2 != null;
}`,
inputs: []any{"hello"},
outputs: [][]any{{true, false, nil, nil}}, // equals with null should return null
},
{
name: "and/or",
procedure: `procedure and_or() public view returns (count int) {
$count := 0;
if true and true {
$count := $count + 1;
}
if true and false {
$count := $count + 100;
}
if (true or false) or (true or true) {
$count := $count + 10;
}
return $count;
}`,
outputs: [][]any{{int64(11)}},
},
}

for _, test := range tests {
Expand Down
60 changes: 60 additions & 0 deletions parse/antlr.go
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,66 @@ func (s *schemaVisitor) VisitVariable_procedure_expr(ctx *gen.Variable_procedure
return e
}

func (s *schemaVisitor) VisitIs_procedure_expr(ctx *gen.Is_procedure_exprContext) any {
e := &ExpressionIs{
Left: ctx.Procedure_expr(0).Accept(s).(Expression),
}

if ctx.NOT() != nil {
e.Not = true
}

if ctx.DISTINCT() != nil {
e.Distinct = true
}

switch {
case ctx.NULL() != nil:
e.Right = &ExpressionLiteral{
Type: types.NullType,
}
e.Right.SetToken(ctx.NULL().GetSymbol())
case ctx.TRUE() != nil:
e.Right = &ExpressionLiteral{
Type: types.BoolType,
Value: true,
}
e.Right.SetToken(ctx.TRUE().GetSymbol())
case ctx.FALSE() != nil:
e.Right = &ExpressionLiteral{
Type: types.BoolType,
Value: false,
}
e.Right.SetToken(ctx.FALSE().GetSymbol())
case ctx.GetRight() != nil:
e.Right = ctx.GetRight().Accept(s).(Expression)
default:
panic("unknown right side of IS")
}

e.Set(ctx)
return e
}

func (s *schemaVisitor) VisitLogical_procedure_expr(ctx *gen.Logical_procedure_exprContext) any {
e := &ExpressionLogical{
Left: ctx.Procedure_expr(0).Accept(s).(Expression),
Right: ctx.Procedure_expr(1).Accept(s).(Expression),
}

switch {
case ctx.AND() != nil:
e.Operator = LogicalOperatorAnd
case ctx.OR() != nil:
e.Operator = LogicalOperatorOr
default:
panic("unknown logical operator")
}

e.Set(ctx)
return e
}

func (s *schemaVisitor) VisitProcedure_expr_arithmetic(ctx *gen.Procedure_expr_arithmeticContext) any {
e := &ExpressionArithmetic{
Left: ctx.Procedure_expr(0).Accept(s).(Expression),
Expand Down
Loading

0 comments on commit e1a80b7

Please sign in to comment.