From 327a34f12c767d3b9a83f57393bc2ee1e4aabff6 Mon Sep 17 00:00:00 2001 From: Brennan Lamey <66885902+brennanjl@users.noreply.github.com> Date: Fri, 24 May 2024 18:46:27 -0500 Subject: [PATCH] add missing return detection (#769) * add missing return detection fixed several bugs * fixed failing test * added sanity check test * prevent error when looping on null arrays * fix select count(*) ordering bug * add comment --- internal/engine/generate/plpgsql.go | 3 +- internal/engine/integration/procedure_test.go | 32 ++ parse/analyze.go | 231 ++++++-- parse/parse.go | 12 +- parse/parse_test.go | 523 +++++++++++++++--- 5 files changed, 675 insertions(+), 126 deletions(-) diff --git a/internal/engine/generate/plpgsql.go b/internal/engine/generate/plpgsql.go index 9956c4029..38a16ff9c 100644 --- a/internal/engine/generate/plpgsql.go +++ b/internal/engine/generate/plpgsql.go @@ -817,7 +817,8 @@ func (p *procedureGenerator) VisitLoopTermSQL(p0 *parse.LoopTermSQL) any { } func (p *procedureGenerator) VisitLoopTermVariable(p0 *parse.LoopTermVariable) any { - return fmt.Sprintf("ARRAY %s", p0.Variable.Accept(p).(string)) + // we use coalesce here so that we do not error when looping on null arrays + return fmt.Sprintf("ARRAY COALESCE(%s, '{}')", p0.Variable.Accept(p).(string)) } func (p *procedureGenerator) VisitProcedureStmtIf(p0 *parse.ProcedureStmtIf) any { diff --git a/internal/engine/integration/procedure_test.go b/internal/engine/integration/procedure_test.go index fcea643ed..09fe279f3 100644 --- a/internal/engine/integration/procedure_test.go +++ b/internal/engine/integration/procedure_test.go @@ -157,6 +157,38 @@ func Test_Procedures(t *testing.T) { }`, outputs: [][]any{{int64(11)}}, }, + { + name: "return next from a non-table", + procedure: `procedure return_next($vals int[]) public view returns table(val int) { + for $i in $vals { + return next $i*2; + } + }`, + inputs: []any{[]int64{1, 2, 3}}, + outputs: [][]any{{int64(2)}, {int64(4)}, {int64(6)}}, + }, + { + name: 
"table return with no hits doesn't return postgres no-return error", + procedure: `procedure return_next($vals int[]) public view returns table(val int) { + for $i in $vals { + error('unreachable'); + } + }`, + inputs: []any{[]int64{}}, + outputs: [][]any{}, + }, + { + name: "loop over null array", + procedure: `procedure loop_over_null() public view returns (count int) { + $vals int[]; + $count := 0; + for $i in $vals { + $count := $count + 1; + } + return $count; + }`, + outputs: [][]any{{int64(0)}}, + }, } for _, test := range tests { diff --git a/parse/analyze.go b/parse/analyze.go index b70775e81..eee4460fb 100644 --- a/parse/analyze.go +++ b/parse/analyze.go @@ -136,6 +136,9 @@ type sqlContext struct { // containsAggregate is true if the current expression contains an aggregate function. // it is set in ExpressionFunctionCall, and accessed/reset in SelectCore. _containsAggregate bool + // containsAggregateWithoutGroupBy is true if the current expression contains an aggregate function, + // but there is no GROUP BY clause. This is set in SelectCore, and accessed in SelectStatement. + _containsAggregateWithoutGroupBy bool // columnInAggregate is the column found within an aggregate function, // comprised of the relation and attribute. 
// It is set in ExpressionColumn, and accessed/reset in @@ -185,10 +188,11 @@ func (c *sqlContext) copy() sqlContext { copy(colsOutsideAgg, c._columnsOutsideAggregate) return sqlContext{ - joinedRelations: joinedRelations, - outerRelations: outerRelations, - ctes: c.ctes, - joinedTables: c.joinedTables, + joinedRelations: joinedRelations, + outerRelations: outerRelations, + ctes: c.ctes, + joinedTables: c.joinedTables, + _containsAggregateWithoutGroupBy: c._containsAggregateWithoutGroupBy, // we want to carry this over } } @@ -419,7 +423,7 @@ func (s *sqlAnalyzer) expressionTypeErr(e Expression) *types.DataType { // if expression is a receiver from a loop, it will be a map _, ok := e.Accept(s).(map[string]*types.DataType) if ok { - s.errs.AddErr(e, ErrType, "invalid usage of compound type, expected scalar value") + s.errs.AddErr(e, ErrType, "invalid usage of compound type. you must reference a field using $compound.field notation") return cast(e, types.UnknownType) } @@ -1263,6 +1267,13 @@ func (s *sqlAnalyzer) VisitSelectStatement(p0 *SelectStatement) any { } } + // we want to re-set the rel1 scope, since it is used in ordering, + // as well as grouping re-checks if the statement is not a compound select. + // e.g. "select a, b from t1 union select c, d from t2 order by a" + oldScope := s.sqlCtx + s.sqlCtx = rel1Scope + defer func() { s.sqlCtx = oldScope }() + // If it is not a compound select, we should use the scope from the first select core, // so that we can analyze joined tables in the order and limit clauses. It if is a compound // select, then we should flatten all joined tables into a single anonymous table. This can @@ -1310,7 +1321,12 @@ func (s *sqlAnalyzer) VisitSelectStatement(p0 *SelectStatement) any { // we order by all columns returned, in the order they are returned. // 3. If there is a group by clause, none of the above apply, and instead we order by // all columns specified in the group by. - if p0.SelectCores[0].GroupBy != nil { + // 4. 
If there is an aggregate clause with no group by, then no ordering is applied. + + // addressing point 4: if there is an aggregate clause with no group by, then no ordering is applied. + if s.sqlCtx._containsAggregateWithoutGroupBy { + // do nothing. + } else if p0.SelectCores[0].GroupBy != nil { // reset and visit the group by to get the columns var colsToOrder [][2]string for _, g := range p0.SelectCores[0].GroupBy { @@ -1364,14 +1380,15 @@ func (s *sqlAnalyzer) VisitSelectStatement(p0 *SelectStatement) any { } } - oldScope := s.sqlCtx - s.sqlCtx = rel1Scope - defer func() { s.sqlCtx = oldScope }() - // we need to inform the analyzer that we are in ordering s.sqlCtx._inOrdering = true s.sqlCtx._result = rel1 + // if the user is trying to order and there is an aggregate without group by, we should throw an error. + if s.sqlCtx._containsAggregateWithoutGroupBy && len(p0.Ordering) > 0 { + s.errs.AddErr(p0, ErrAggregate, "cannot use order by with aggregate function without group by") + return rel1 + } // analyze the ordering, limit, and offset for _, o := range p0.Ordering { o.Accept(s) @@ -1479,15 +1496,16 @@ func (s *sqlAnalyzer) VisitSelectCore(p0 *SelectCore) any { // columns in having must be in the group by if not in aggregate for _, col := range s.sqlCtx._columnsOutsideAggregate { if _, ok := colsInGroupBy[col]; !ok { - s.errs.AddErr(p0.Having, ErrAggregate, "column used in having must be in group by") + s.errs.AddErr(p0.Having, ErrAggregate, "column used in having must be in group by, or must be in aggregate function") } } - if s.sqlCtx._columnInAggregate != nil { - if _, ok := colsInGroupBy[*s.sqlCtx._columnInAggregate]; !ok { - s.errs.AddErr(p0.Having, ErrAggregate, "cannot use column in aggregate if not in group by") - } - } + // COMMENTING THIS OUT: if a column is in an aggregate in the having, then it is ok if it is not in the group by + // if s.sqlCtx._columnInAggregate != nil { + // if _, ok := colsInGroupBy[*s.sqlCtx._columnInAggregate]; !ok { + // 
s.errs.AddErr(p0.Having, ErrAggregate, "cannot use column in having if not in group by or in aggregate function") + // } + // } s.expect(p0.Having, havingType, types.BoolType) } @@ -1519,6 +1537,7 @@ func (s *sqlAnalyzer) VisitSelectCore(p0 *SelectCore) any { if len(p0.Columns) != 1 { s.errs.AddErr(c, ErrAggregate, "cannot return multiple values in SELECT that uses aggregate function and no group by") } + s.sqlCtx._containsAggregateWithoutGroupBy = true } else if hasGroupBy { // if column used in aggregate, ensure it is not in group by if s.sqlCtx._columnInAggregate != nil { @@ -2045,7 +2064,9 @@ type loopTargetTracker struct { // language can execute sql statements, it uses the sqlAnalyzer. type procedureAnalyzer struct { sqlAnalyzer - procCtx *procedureContext + procCtx *procedureContext + // procResult stores data that the analyzer will return with the parsed procedure. + // The information is used by the code generator to generate the plpgsql code. procResult struct { // allLoopReceivers tracks all loop receivers that have occurred over the lifetime // of the procedure. 
This is used to generate variables to hold the loop target @@ -2104,7 +2125,7 @@ func (p *procedureAnalyzer) VisitProcedureStmtDeclaration(p0 *ProcedureStmtDecla if p.variableExists(p0.Variable.String()) { p.errs.AddErr(p0, ErrVariableAlreadyDeclared, p0.Variable.String()) - return nil + return zeroProcedureReturn() } // TODO: we need to figure out how to undeclare a variable if it is declared in a loop/if block @@ -2115,7 +2136,7 @@ func (p *procedureAnalyzer) VisitProcedureStmtDeclaration(p0 *ProcedureStmtDecla // now that it is declared, we can visit it p0.Variable.Accept(p) - return nil + return zeroProcedureReturn() } func (p *procedureAnalyzer) VisitProcedureStmtAssignment(p0 *ProcedureStmtAssign) any { @@ -2123,7 +2144,7 @@ func (p *procedureAnalyzer) VisitProcedureStmtAssignment(p0 *ProcedureStmtAssign dt, ok := p0.Value.Accept(p).(*types.DataType) if !ok { p.expressionTypeErr(p0.Value) - return nil + return zeroProcedureReturn() } _, ok = p.variables[p0.Variable.String()] @@ -2131,7 +2152,7 @@ func (p *procedureAnalyzer) VisitProcedureStmtAssignment(p0 *ProcedureStmtAssign // if it does not exist, we can declare it here. p.variables[p0.Variable.String()] = dt p.markDeclared(p0.Variable, p0.Variable.String(), dt) - return nil + return zeroProcedureReturn() } // the type can be inferred from the value. 
@@ -2140,7 +2161,7 @@ func (p *procedureAnalyzer) VisitProcedureStmtAssignment(p0 *ProcedureStmtAssign if p0.Type != nil { if !p0.Type.Equals(dt) { p.errs.AddErr(p0, ErrType, "declared type: %s, inferred type: %s", p0.Type.String(), dt.String()) - return nil + return zeroProcedureReturn() } } @@ -2148,14 +2169,14 @@ func (p *procedureAnalyzer) VisitProcedureStmtAssignment(p0 *ProcedureStmtAssign dt2, ok := p0.Variable.Accept(p).(*types.DataType) if !ok { p.expressionTypeErr(p0.Variable) - return nil + return zeroProcedureReturn() } if !dt2.Equals(dt) { p.typeErr(p0, dt2, dt) } - return nil + return zeroProcedureReturn() } func (p *procedureAnalyzer) VisitProcedureStmtCall(p0 *ProcedureStmtCall) any { @@ -2176,11 +2197,16 @@ func (p *procedureAnalyzer) VisitProcedureStmtCall(p0 *ProcedureStmtCall) any { returns2, ok := p0.Call.Accept(p).([]*types.DataType) if !ok { p.errs.AddErr(p0.Call, ErrType, "expected function/procedure to return one or more variables") - return nil + return zeroProcedureReturn() } callReturns = returns2 } + // if calling the `error` function, then this branch will return + exits := false + if p0.Call.FunctionName() == "error" { + exits = true + } // if calling a non-view procedure, the above will set the sqlResult to be mutative // if this procedure is a view, we should throw an error. @@ -2192,7 +2218,7 @@ func (p *procedureAnalyzer) VisitProcedureStmtCall(p0 *ProcedureStmtCall) any { // we do not have more receivers than return values. 
if len(p0.Receivers) != len(callReturns) { p.errs.AddErr(p0, ErrResultShape, `function/procedure "%s" returns %d value(s), statement expects %d value(s)`, p0.Call.FunctionName(), len(callReturns), len(p0.Receivers)) - return nil + return zeroProcedureReturn() } for i, r := range p0.Receivers { @@ -2222,7 +2248,9 @@ func (p *procedureAnalyzer) VisitProcedureStmtCall(p0 *ProcedureStmtCall) any { } } - return nil + return &procedureStmtResult{ + willReturn: exits, + } } // VisitProcedureStmtForLoop visits a for loop statement. @@ -2236,7 +2264,7 @@ func (p *procedureAnalyzer) VisitProcedureStmtForLoop(p0 *ProcedureStmtForLoop) // check to make sure the receiver has not already been declared if p.variableExists(p0.Receiver.String()) { p.errs.AddErr(p0.Receiver, ErrVariableAlreadyDeclared, p0.Receiver.String()) - return nil + return zeroProcedureReturn() } tracker := &loopTargetTracker{ @@ -2272,16 +2300,31 @@ func (p *procedureAnalyzer) VisitProcedureStmtForLoop(p0 *ProcedureStmtForLoop) for _, t := range p.procResult.allLoopReceivers { if t.name.String() == p0.Receiver.String() { p.errs.AddErr(p0.Receiver, ErrVariableAlreadyDeclared, p0.Receiver.String()) - return nil + return zeroProcedureReturn() } } p.procCtx.activeLoopReceivers = append([]string{tracker.name.String()}, p.procCtx.activeLoopReceivers...) p.procResult.allLoopReceivers = append(p.procResult.allLoopReceivers, tracker) + // returns tracks whether this loop is guaranteed to exit. + returns := false + canBreakPrematurely := false // we will now visit the statements in the loop. for _, stmt := range p0.Body { - stmt.Accept(p) + res := stmt.Accept(p).(*procedureStmtResult) + if res.canBreak { + canBreakPrematurely = true + } + if res.willReturn { + returns = true + } + } + // if it is possible for a for loop to break prematurely, then it is possible + // that it does not include a return, and so we need to inform the caller + // that it does not guarantee a return. 
+ if canBreakPrematurely { + returns = false } // pop the loop target @@ -2297,7 +2340,9 @@ func (p *procedureAnalyzer) VisitProcedureStmtForLoop(p0 *ProcedureStmtForLoop) delete(p.variables, p0.Receiver.String()) } - return nil + return &procedureStmtResult{ + willReturn: returns, + } } func (p *procedureAnalyzer) VisitLoopTermRange(p0 *LoopTermRange) any { @@ -2348,33 +2393,66 @@ func (p *procedureAnalyzer) VisitLoopTermVariable(p0 *LoopTermVariable) any { } func (p *procedureAnalyzer) VisitProcedureStmtIf(p0 *ProcedureStmtIf) any { + canBreak := false + + allThensReturn := true for _, c := range p0.IfThens { - c.Accept(p) + res := c.Accept(p).(*procedureStmtResult) + if !res.willReturn { + allThensReturn = false + } + if res.canBreak { + canBreak = true + } } + // initialize to true, so that if else does not exist, we know we still exit. + // It gets set to false if we encounter an else block. + elseReturns := true if p0.Else != nil { + elseReturns = false for _, stmt := range p0.Else { - stmt.Accept(p) + res := stmt.Accept(p).(*procedureStmtResult) + if res.willReturn { + elseReturns = true + } + if res.canBreak { + canBreak = true + } } } - return nil + return &procedureStmtResult{ + willReturn: allThensReturn && elseReturns, + canBreak: canBreak, + } } func (p *procedureAnalyzer) VisitIfThen(p0 *IfThen) any { dt, ok := p0.If.Accept(p).(*types.DataType) if !ok { p.expressionTypeErr(p0.If) - return nil + return zeroProcedureReturn() } p.expect(p0.If, dt, types.BoolType) + canBreak := false + returns := false for _, stmt := range p0.Then { - stmt.Accept(p) + res := stmt.Accept(p).(*procedureStmtResult) + if res.willReturn { + returns = true + } + if res.canBreak { + canBreak = true + } } - return nil + return &procedureStmtResult{ + willReturn: returns, + canBreak: canBreak, + } } func (p *procedureAnalyzer) VisitProcedureStmtSQL(p0 *ProcedureStmtSQL) any { @@ -2386,7 +2464,7 @@ func (p *procedureAnalyzer) VisitProcedureStmtSQL(p0 *ProcedureStmtSQL) any { 
panic("expected query to return attributes") } - return nil + return zeroProcedureReturn() } func (p *procedureAnalyzer) VisitProcedureStmtBreak(p0 *ProcedureStmtBreak) any { @@ -2394,20 +2472,26 @@ func (p *procedureAnalyzer) VisitProcedureStmtBreak(p0 *ProcedureStmtBreak) any p.errs.AddErr(p0, ErrBreak, "break statement outside of loop") } - return nil + return &procedureStmtResult{ + canBreak: true, + } } func (p *procedureAnalyzer) VisitProcedureStmtReturn(p0 *ProcedureStmtReturn) any { if p.procCtx.procedureDefinition.Returns == nil { p.errs.AddErr(p0, ErrFunctionSignature, "procedure does not return any values") - return nil + return &procedureStmtResult{ + willReturn: true, + } } returns := p.procCtx.procedureDefinition.Returns if p0.SQL != nil { if !returns.IsTable { p.errs.AddErr(p0, ErrReturn, "procedure expects scalar returns, cannot return SQL statement") - return nil + return &procedureStmtResult{ + willReturn: true, + } } p.startSQLAnalyze() @@ -2420,7 +2504,9 @@ func (p *procedureAnalyzer) VisitProcedureStmtReturn(p0 *ProcedureStmtReturn) an if len(res) != len(returns.Fields) { p.errs.AddErr(p0, ErrReturn, "expected %d return table columns, received %d", len(returns.Fields), len(res)) - return nil + return &procedureStmtResult{ + willReturn: true, + } } // we will compare the return types to the procedure definition @@ -2435,22 +2521,31 @@ func (p *procedureAnalyzer) VisitProcedureStmtReturn(p0 *ProcedureStmtReturn) an } } - return nil + return &procedureStmtResult{ + willReturn: true, + } } if returns.IsTable { p.errs.AddErr(p0, ErrReturn, "procedure expects table returns, cannot return scalar values") - return nil + return &procedureStmtResult{ + willReturn: true, + } } if len(p0.Values) != len(returns.Fields) { p.errs.AddErr(p0, ErrReturn, "expected %d return values, received %d", len(returns.Fields), len(p0.Values)) - return nil + return &procedureStmtResult{ + willReturn: true, + } } for i, v := range p0.Values { dt, ok := 
v.Accept(p).(*types.DataType) if !ok { - return p.expressionTypeErr(v) + p.expressionTypeErr(v) + return &procedureStmtResult{ + willReturn: true, + } } if !dt.Equals(returns.Fields[i].Type) { @@ -2458,29 +2553,40 @@ func (p *procedureAnalyzer) VisitProcedureStmtReturn(p0 *ProcedureStmtReturn) an } } - return nil + return &procedureStmtResult{ + willReturn: true, + } } func (p *procedureAnalyzer) VisitProcedureStmtReturnNext(p0 *ProcedureStmtReturnNext) any { if p.procCtx.procedureDefinition.Returns == nil { p.errs.AddErr(p0, ErrFunctionSignature, "procedure does not return any values") - return nil + return &procedureStmtResult{ + willReturn: true, + } } if !p.procCtx.procedureDefinition.Returns.IsTable { p.errs.AddErr(p0, ErrReturn, "procedure expects scalar returns, cannot return next") - return nil + return &procedureStmtResult{ + willReturn: true, + } } if len(p0.Values) != len(p.procCtx.procedureDefinition.Returns.Fields) { p.errs.AddErr(p0, ErrReturn, "expected %d return values, received %d", len(p.procCtx.procedureDefinition.Returns.Fields), len(p0.Values)) - return nil + return &procedureStmtResult{ + willReturn: true, + } } for i, v := range p0.Values { dt, ok := v.Accept(p).(*types.DataType) if !ok { - return p.expressionTypeErr(v) + p.expressionTypeErr(v) + return &procedureStmtResult{ + willReturn: true, + } } if !dt.Equals(p.procCtx.procedureDefinition.Returns.Fields[i].Type) { @@ -2488,5 +2594,26 @@ func (p *procedureAnalyzer) VisitProcedureStmtReturnNext(p0 *ProcedureStmtReturn } } - return nil + return &procedureStmtResult{ + willReturn: true, + } +} + +// zeroProcedureReturn creates a new procedure return with all 0 values. +func zeroProcedureReturn() *procedureStmtResult { + return &procedureStmtResult{} +} + +// procedureStmtResult is returned from each procedure statement visit. +type procedureStmtResult struct { + // willReturn is true if the statement contains a return statement that it will + // always hit. 
This is used to determine if a path will exit a procedure. + // it is used to tell whether or not a statement can potentially exit a procedure, + // since all procedures that have an expected return must always return that value. + // It only tells us whether or not a return is guaranteed to be hit from a statement. + // The return types are checked at the point of the return statement. + willReturn bool + // canBreak is true if the statement can break the for loop it is in. + // For example, an IF statement that breaks a for loop will set canBreak to true. + canBreak bool } diff --git a/parse/parse.go b/parse/parse.go index 985f8d8f1..ecf30472f 100644 --- a/parse/parse.go +++ b/parse/parse.go @@ -187,8 +187,18 @@ func analyzeProcedureAST(proc *types.Procedure, schema *types.Schema, ast []Proc } // visit the AST + returns := false for _, stmt := range res.AST { - stmt.Accept(visitor) + res := stmt.Accept(visitor).(*procedureStmtResult) + if res.willReturn { + returns = true + } + } + + // if the procedure is expecting a return that is not a table, and it does not guarantee + // returning a value, we should add an error. 
+ if proc.Returns != nil && !returns && !proc.Returns.IsTable { + errLis.AddErr(res.AST[len(res.AST)-1], ErrReturn, "procedure does not return a value") } for k, v := range visitor.procResult.allVariables { diff --git a/parse/parse_test.go b/parse/parse_test.go index bba4bd267..c9d485655 100644 --- a/parse/parse_test.go +++ b/parse/parse_test.go @@ -1198,6 +1198,253 @@ func Test_Procedure(t *testing.T) { }, }, }, + { + name: "missing return values", + returns: &types.ProcedureReturn{ + Fields: []*types.NamedType{ + { + Name: "id", + Type: types.IntType, + }, + { + Name: "name", + Type: types.TextType, + }, + }, + }, + proc: `return 1;`, + err: parse.ErrReturn, + }, + { + name: "no return values", + returns: &types.ProcedureReturn{ + Fields: []*types.NamedType{{ + Name: "id", + Type: types.IntType, + }}, + }, + proc: `$a := 1;`, + err: parse.ErrReturn, + }, + { + name: "if/then missing return in one branch", + returns: &types.ProcedureReturn{ + Fields: []*types.NamedType{{ + Name: "id", + Type: types.IntType, + }}, + }, + proc: ` + if true { + return 1; + } else { + $a := 1; + } + `, + err: parse.ErrReturn, + }, + { + name: "for loop with if return", + returns: &types.ProcedureReturn{ + Fields: []*types.NamedType{{ + Name: "id", + Type: types.IntType, + }}, + }, + proc: ` + $arr := [1,2,3]; + for $i in $arr { + if $i == -1 { + break; + } + return $i; + } + `, + err: parse.ErrReturn, + }, + { + name: "nested for loop", + returns: &types.ProcedureReturn{ + Fields: []*types.NamedType{{ + Name: "id", + Type: types.IntType, + }}, + }, + proc: ` + $arr int[]; + for $i in $arr { + for $j in 1..$i { + break; // only breaks the inner loop + } + + return $i; // this will always exit on first $i iteration + } + `, + want: &parse.ProcedureParseResult{ + Variables: map[string]*types.DataType{ + "$arr": types.ArrayType(types.IntType), + "$i": types.IntType, + "$j": types.IntType, + }, + AST: []parse.ProcedureStmt{ + &parse.ProcedureStmtDeclaration{ + Variable: exprVar("$arr"), + 
Type: types.ArrayType(types.IntType), + }, + &parse.ProcedureStmtForLoop{ + Receiver: exprVar("$i"), + LoopTerm: &parse.LoopTermVariable{ + Variable: exprVar("$arr"), + }, + Body: []parse.ProcedureStmt{ + &parse.ProcedureStmtForLoop{ + Receiver: exprVar("$j"), + LoopTerm: &parse.LoopTermRange{ + Start: exprLit(1), + End: exprVar("$i"), + }, + Body: []parse.ProcedureStmt{ + &parse.ProcedureStmtBreak{}, + }, + }, + &parse.ProcedureStmtReturn{ + Values: []parse.Expression{exprVar("$i")}, + }, + }, + }, + }, + }, + }, + { + name: "returns table incorrect", + returns: &types.ProcedureReturn{ + Fields: []*types.NamedType{{ + Name: "id", + Type: types.IntType, + }}, + }, + proc: `return select id from users;`, // this is intentional- plpgsql treats this as a table return + err: parse.ErrReturn, + }, + { + name: "returns table correct", + returns: &types.ProcedureReturn{ + IsTable: true, + Fields: []*types.NamedType{{ + Name: "id", + Type: types.IntType, + }}, + }, + proc: `return select 1 as id;`, + want: &parse.ProcedureParseResult{ + AST: []parse.ProcedureStmt{ + &parse.ProcedureStmtReturn{ + SQL: &parse.SQLStatement{ + SQL: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnExpression{ + Expression: exprLit(1), + Alias: "id", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "returns next incorrect", + returns: &types.ProcedureReturn{ + Fields: []*types.NamedType{{ + Name: "id", + Type: types.IntType, + }}, + }, + proc: `$a int[]; + for $row in $a { + return next $row; + } + `, + err: parse.ErrReturn, + }, + { + name: "returns next correct", + returns: &types.ProcedureReturn{ + IsTable: true, + Fields: []*types.NamedType{{ + Name: "id", + Type: types.IntType, + }}, + }, + proc: ` + for $row in select * from get_all_user_ids() { + return next $row.id; + } + `, + want: &parse.ProcedureParseResult{ + CompoundVariables: map[string]struct{}{ + "$row": {}, + }, + AST: []parse.ProcedureStmt{ + 
&parse.ProcedureStmtForLoop{ + Receiver: exprVar("$row"), + LoopTerm: &parse.LoopTermSQL{ + Statement: &parse.SQLStatement{ + SQL: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{&parse.ResultColumnWildcard{}}, + From: &parse.RelationFunctionCall{ + FunctionCall: &parse.ExpressionFunctionCall{ + Name: "get_all_user_ids", + }, + }, + }, + }, + }, + }, + }, + Body: []parse.ProcedureStmt{ + &parse.ProcedureStmtReturnNext{ + Values: []parse.Expression{ + &parse.ExpressionFieldAccess{ + Record: exprVar("$row"), + Field: "id", + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "error func exits", + proc: `error('error message');`, + returns: &types.ProcedureReturn{ + Fields: []*types.NamedType{{ + Name: "id", + Type: types.IntType, + }}, + }, + want: &parse.ProcedureParseResult{ + AST: []parse.ProcedureStmt{ + &parse.ProcedureStmtCall{ + Call: &parse.ExpressionFunctionCall{ + Name: "error", + Args: []parse.Expression{ + exprLit("error message"), + }, + }, + }, + }, + }, + }, } for _, tt := range tests { @@ -1320,18 +1567,14 @@ func Test_SQL(t *testing.T) { Columns: []parse.ResultColumn{ &parse.ResultColumnWildcard{}, &parse.ResultColumnExpression{ - Expression: &parse.ExpressionColumn{ - Column: "id", - }, - Alias: "i", + Expression: exprColumn("", "id"), + Alias: "i", }, &parse.ResultColumnExpression{ Expression: &parse.ExpressionFunctionCall{ Name: "length", Args: []parse.Expression{ - &parse.ExpressionColumn{ - Column: "username", - }, + exprColumn("", "username"), }, }, Alias: "name_len", @@ -1342,10 +1585,7 @@ func Test_SQL(t *testing.T) { Alias: "u", }, Where: &parse.ExpressionComparison{ - Left: &parse.ExpressionColumn{ - Table: "u", - Column: "id", - }, + Left: exprColumn("u", "id"), Operator: parse.ComparisonOperatorEqual, Right: &parse.ExpressionLiteral{ Type: types.IntType, @@ -1357,10 +1597,7 @@ func Test_SQL(t *testing.T) { // apply default ordering Ordering: []*parse.OrderingTerm{ { - Expression: 
&parse.ExpressionColumn{ - Table: "u", - Column: "id", - }, + Expression: exprColumn("u", "id"), }, }, }, @@ -1396,18 +1633,14 @@ func Test_SQL(t *testing.T) { { Columns: []parse.ResultColumn{ &parse.ResultColumnExpression{ - Expression: &parse.ExpressionColumn{ - Column: "id", - }, + Expression: exprColumn("", "id"), }, }, From: &parse.RelationTable{ Table: "users", }, Where: &parse.ExpressionComparison{ - Left: &parse.ExpressionColumn{ - Column: "username", - }, + Left: exprColumn("", "username"), Operator: parse.ComparisonOperatorEqual, Right: &parse.ExpressionLiteral{ Type: types.TextType, @@ -1423,10 +1656,7 @@ func Test_SQL(t *testing.T) { // apply default ordering Ordering: []*parse.OrderingTerm{ { - Expression: &parse.ExpressionColumn{ - Table: "users", - Column: "id", - }, + Expression: exprColumn("users", "id"), }, }, }, @@ -1447,18 +1677,12 @@ func Test_SQL(t *testing.T) { { Columns: []parse.ResultColumn{ &parse.ResultColumnExpression{ - Expression: &parse.ExpressionColumn{ - Column: "id", - Table: "p", - }, - Alias: "id", + Expression: exprColumn("p", "id"), + Alias: "id", }, &parse.ResultColumnExpression{ - Expression: &parse.ExpressionColumn{ - Column: "username", - Table: "u", - }, - Alias: "author", + Expression: exprColumn("u", "username"), + Alias: "author", }, }, From: &parse.RelationTable{ @@ -1473,23 +1697,14 @@ func Test_SQL(t *testing.T) { Alias: "u", }, On: &parse.ExpressionComparison{ - Left: &parse.ExpressionColumn{ - Column: "author_id", - Table: "p", - }, + Left: exprColumn("p", "author_id"), Operator: parse.ComparisonOperatorEqual, - Right: &parse.ExpressionColumn{ - Column: "id", - Table: "u", - }, + Right: exprColumn("u", "id"), }, }, }, Where: &parse.ExpressionComparison{ - Left: &parse.ExpressionColumn{ - Column: "username", - Table: "u", - }, + Left: exprColumn("u", "username"), Operator: parse.ComparisonOperatorEqual, Right: &parse.ExpressionLiteral{ Type: types.TextType, @@ -1501,25 +1716,16 @@ func Test_SQL(t *testing.T) { 
Ordering: []*parse.OrderingTerm{ { - Expression: &parse.ExpressionColumn{ - Table: "u", - Column: "username", - }, - Order: parse.OrderTypeDesc, - Nulls: parse.NullOrderLast, + Expression: exprColumn("u", "username"), + Order: parse.OrderTypeDesc, + Nulls: parse.NullOrderLast, }, // apply default ordering { - Expression: &parse.ExpressionColumn{ - Table: "p", - Column: "id", - }, + Expression: exprColumn("p", "id"), }, { - Expression: &parse.ExpressionColumn{ - Table: "u", - Column: "id", - }, + Expression: exprColumn("u", "id"), }, }, }, @@ -1532,9 +1738,7 @@ func Test_SQL(t *testing.T) { SQL: &parse.DeleteStatement{ Table: "users", Where: &parse.ExpressionComparison{ - Left: &parse.ExpressionColumn{ - Column: "id", - }, + Left: exprColumn("", "id"), Operator: parse.ComparisonOperatorEqual, Right: &parse.ExpressionLiteral{ Type: types.IntType, @@ -1565,15 +1769,9 @@ func Test_SQL(t *testing.T) { { Column: "id", Value: &parse.ExpressionArithmetic{ - Left: &parse.ExpressionColumn{ - Column: "id", - Table: "users", - }, + Left: exprColumn("users", "id"), Operator: parse.ArithmeticOperatorAdd, - Right: &parse.ExpressionColumn{ - Column: "id", - Table: "excluded", - }, + Right: exprColumn("excluded", "id"), }, }, }, @@ -1613,6 +1811,180 @@ func Test_SQL(t *testing.T) { INNER JOIN (SELECT id as uid FROM users WHERE id = 1) ON p.author_id = uid;`, err: parse.ErrUnnamedJoin, }, + { + name: "compound select", + sql: `SELECT * FROM users union SELECT * FROM users;`, + want: &parse.SQLStatement{ + SQL: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnWildcard{}, + }, + From: &parse.RelationTable{ + Table: "users", + }, + }, + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnWildcard{}, + }, + From: &parse.RelationTable{ + Table: "users", + }, + }, + }, + CompoundOperators: []parse.CompoundOperator{ + parse.CompoundOperatorUnion, + }, + // apply default ordering + Ordering: []*parse.OrderingTerm{ + { 
+ Expression: exprColumn("", "id"), + }, + { + Expression: exprColumn("", "username"), + }, + }, + }, + }, + }, + { + name: "compound with mismatched shape", + sql: `SELECT username, id FROM users union SELECT id, username FROM users;`, + err: parse.ErrResultShape, + }, + { + name: "group by", + sql: `SELECT u.username, count(u.id) FROM users as u GROUP BY u.username HAVING count(u.id) > 1;`, + want: &parse.SQLStatement{ + SQL: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnExpression{ + Expression: exprColumn("u", "username"), + }, + &parse.ResultColumnExpression{ + Expression: &parse.ExpressionFunctionCall{ + Name: "count", + Args: []parse.Expression{ + exprColumn("u", "id"), + }, + }, + }, + }, + From: &parse.RelationTable{ + Table: "users", + Alias: "u", + }, + GroupBy: []parse.Expression{ + exprColumn("u", "username"), + }, + Having: &parse.ExpressionComparison{ + Left: &parse.ExpressionFunctionCall{ + Name: "count", + Args: []parse.Expression{ + exprColumn("u", "id"), + }, + }, + Operator: parse.ComparisonOperatorGreaterThan, + Right: exprLit(1), + }, + }, + }, + // apply default ordering + Ordering: []*parse.OrderingTerm{ + { + Expression: exprColumn("u", "username"), + }, + }, + }, + }, + }, + { + name: "group by with having, having is in group by clause", + // there's a much easier way to write this query, but this is just to test the parser + sql: `SELECT username FROM users GROUP BY username HAVING length(username) > 1;`, + want: &parse.SQLStatement{ + SQL: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnExpression{ + Expression: exprColumn("", "username"), + }, + }, + From: &parse.RelationTable{ + Table: "users", + }, + GroupBy: []parse.Expression{ + exprColumn("", "username"), + }, + Having: &parse.ExpressionComparison{ + Left: &parse.ExpressionFunctionCall{ + Name: "length", + Args: []parse.Expression{ + 
exprColumn("", "username"), + }, + }, + Operator: parse.ComparisonOperatorGreaterThan, + Right: exprLit(1), + }, + }, + }, + // apply default ordering + Ordering: []*parse.OrderingTerm{ + { + Expression: exprColumn("users", "username"), + }, + }, + }, + }, + }, + { + name: "group by with having, invalid column without aggregate", + sql: `SELECT u.username, count(u.id) FROM users as u GROUP BY u.username HAVING u.id > 1;`, + err: parse.ErrAggregate, + }, + { + name: "compound select with group by", + sql: `SELECT u.username, count(u.id) FROM users as u GROUP BY u.username HAVING count(u.id) > 1 UNION SELECT u.username, count(u.id) FROM users as u GROUP BY u.username HAVING count(u.id) > 1;`, + err: parse.ErrAggregate, + }, + { + name: "aggregate with no group by returns many columns", + sql: `SELECT count(u.id), u.username FROM users as u;`, + err: parse.ErrAggregate, + }, + { + name: "aggregate with no group by returns one column", + sql: `SELECT count(*) FROM users;`, + want: &parse.SQLStatement{ + SQL: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnExpression{ + Expression: &parse.ExpressionFunctionCall{ + Name: "count", + Star: true, + }, + }, + }, + From: &parse.RelationTable{ + Table: "users", + }, + }, + }, + }, + }, + }, + { + name: "aggregate with no group by and ordering fails", + sql: `SELECT count(*) FROM users order by count(*) DESC;`, + err: parse.ErrAggregate, + }, } for _, tt := range tests { @@ -1646,6 +2018,13 @@ func Test_SQL(t *testing.T) { } } +func exprColumn(t, c string) *parse.ExpressionColumn { + return &parse.ExpressionColumn{ + Table: t, + Column: c, + } +} + // deepCompare deep compares the values of two nodes. // It ignores the parseTypes.Node field. func deepCompare(node1, node2 any) bool {