-
Notifications
You must be signed in to change notification settings - Fork 12
/
analyze.go
2783 lines (2351 loc) · 83.4 KB
/
analyze.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
package parse
import (
"fmt"
"maps"
"github.com/kwilteam/kwil-db/core/types"
)
/*
this file performs analysis of SQL and procedures. It performs several main types of validation:
1. Type checking: it ensures that all statements and expressions return correct types.
This is critical because plpgsql only throws type errors at runtime, which is really bad
for a smart contract language.
2. Deterministic ordering: it ensures that all queries have deterministic ordering, even if
not specified by the query writer. It adds necessary ordering clauses to achieve this.
3. Aggregate checks: it ensures that aggregate functions are used correctly, and that they
can not be used to create non-determinism, and also that they return errors that would otherwise
be thrown by Postgres at runtime.
4. Mutative checks: it analyzes whether or not a procedure / sql statement is attempting to
modify state. It does not return an error if it does, but will return a boolean indicating
whether or not it is mutative. This can be used by callers to ensure that VIEW procedures
are not mutative, which would otherwise only be checked at runtime when executing them with
a read-only transaction.
5. Contextual statement checks: Procedure statements that can only be used in certain contexts
(e.g. loop breaks and RETURN NEXT) are checked to ensure that they are only used in loops.
6. PLPGSQL Variable Declarations: It analyzes what variables should be declared at the top
of a PLPGSQL statement, and what types they should be.
7. Cartesian Join Checks: All joins must be joined using =, and one side of the join condition
must be a unique column with no other math applied to it. Primary keys are also counted as unique,
unless it is a compound primary key.
DETERMINISTIC ORDERING RULES:
If a SELECT statement is a simple select (e.g. does not use compound operators):
1. All joined tables that are physical (and not subqueries or procedure calls) are ordered by their primary keys,
in the order they are joined.
2. If a SELECT has a DISTINCT clause, it will order by all columns being returned. The reason
for this can be seen here: https://stackoverflow.com/questions/3492093/does-postgresql-support-distinct-on-multiple-columns.
All previous rules do not apply.
3. If a SELECT has a GROUP BY clause, all columns specified in the GROUP BY clause will be ordered.
All previous rules do not apply.
If a SELECT statement is a compound select (e.g. uses UNION, UNION ALL, INTERSECT, EXCEPT):
1. All returned columns are ordered by their position in the SELECT statement.
2. If any compound SELECT statement has a GROUP BY, then it will return an error.
This is a remnant of SQLite's rudimentary indexing, but these queries are fairly uncommon,
and therefore not allowed for the time being
AGGREGATE FUNCTION RULES:
1. Aggregate functions can only be used in the SELECT clause, and not in the WHERE clause.
2. All columns referenced in HAVING or return columns must be in the GROUP BY clause, unless they are
in an aggregate function.
3. All columns used within aggregate functions cannot be specified in the GROUP BY clause.
4. If there is an aggregate in the return columns and no GROUP BY clause, then there can only
be one column in the return columns (the column with the aggregate function).
*/
// unknownPosition builds the Position used for AST nodes that have no location
// in the parsed statement because the analyzer added them itself (for example
// when it injects ordering terms). The position is flagged as set, and all
// four line/column fields are -1.
func unknownPosition() Position {
	var pos Position
	pos.IsSet = true
	pos.StartLine, pos.StartCol = -1, -1
	pos.EndLine, pos.EndCol = -1, -1
	return pos
}
// blockContext is the context for the current block. This can be an action,
// procedure, or sql block. It is embedded in sqlAnalyzer, so its fields are
// reachable directly from the analyzer.
type blockContext struct {
	// schema is the current schema. It is consulted to resolve procedures and
	// foreign procedures by name.
	schema *types.Schema
	// variables holds information about all variable declarations in the block.
	// It holds both user variables like $arg1, $arg2, and contextual variables,
	// like @caller and @txid. It will be populated while the analyzer is running,
	// but is prepopulated with the procedure's arguments and contextual variables.
	// The key is the variable name as written in the source; the value is its type.
	variables map[string]*types.DataType
	// anonymousVariables holds information about all anonymous variable declarations in the block.
	// Anonymous variables are objects with fields, such as the receiver of loops.
	// The outer key is the variable name; the inner map goes from field name to
	// that field's data type.
	// The map will be populated while the analyzer is running.
	anonymousVariables map[string]map[string]*types.DataType
	// errs is used for passing errors back to the caller. Analysis code reports
	// problems through errs.AddErr instead of returning Go errors.
	errs *errorListener
}
// variableExists reports whether name is declared in the current block, either
// as a regular (typed) variable or as an anonymous compound variable.
func (b *blockContext) variableExists(name string) bool {
	if _, declared := b.variables[name]; declared {
		return true
	}

	_, anonymous := b.anonymousVariables[name]
	return anonymous
}
// copyVariables returns shallow clones of the block's variables map and of its
// anonymousVariables map, in that order.
//
// Only the top level of anonymousVariables is cloned; the inner field maps are
// shared with the original. That is sufficient because an anonymous variable
// can only be declared as a loop receiver (for $row in SELECT ...), and such a
// variable can only be declared once per procedure, so its set of fields never
// changes after it has been recorded.
func (b *blockContext) copyVariables() (map[string]*types.DataType, map[string]map[string]*types.DataType) {
	vars := maps.Clone(b.variables)
	anonymous := maps.Clone(b.anonymousVariables)
	return vars, anonymous
}
// sqlContext is the context of the current SQL statement. It tracks which
// relations are visible, a handful of flags describing where in the statement
// the analyzer currently is, and a set of short-lived "temp" values (prefixed
// with _) that individual visitor methods use to communicate with each other.
// Nested scopes (subqueries) are handled with scope / popScope, which chain
// contexts together through outerScope.
type sqlContext struct {
	// joinedRelations tracks all relations joined on the current SQL statement.
	joinedRelations []*Relation
	// outerRelations are relations that are not joined on the scope, but are available.
	// These are typically outer queries in a subquery. Calling these will be a correlated subquery.
	outerRelations []*Relation
	// joinedTables maps all used table names/aliases to their table definitions.
	// The tables named here are also included in joinedRelations, but not
	// all joinedRelations are in this map. This map ONLY includes actual SQL
	// tables joined in this context, not joined subqueries or procedure calls.
	// These are used for default ordering.
	joinedTables map[string]*types.Table
	// ctes are the common table expressions in the current scope.
	ctes []*Relation
	// outerScope is the scope of the outer query. It is set by scope and
	// consumed by popScope.
	outerScope *sqlContext
	// isInlineAction is true if the visitor is analyzing a SQL expression within an in-line
	// statement in an action. Several expression kinds are rejected when it is set.
	isInlineAction bool
	// inConflict is true if we are in an ON CONFLICT clause.
	inConflict bool
	// targetTable is the name (or alias) of the table being inserted, updated, or deleted to/from.
	// It is not set if we are not in an insert, update, or delete statement.
	targetTable string
	// hasAnonymousTable is true if an unnamed table has been joined. If this is true,
	// it can be the only table joined in a select statement.
	hasAnonymousTable bool
	// inSelect is true if we are in a select statement.
	inSelect bool
	// inLoneSQL is true if this is being parsed as a lone SQL query, and
	// not as part of an action or procedure. This allows us to bypass certain
	// checks, such as that variables are declared as part of the procedure.
	inLoneSQL bool

	// The fields below are temp values: they are temporary and not even saved
	// within the same scope. They are used in highly specific contexts, and
	// shouldn't be relied on unless specifically documented. All temp values
	// are denoted with a leading _.

	// _inAggregate is true if we are within an aggregate function's parameters.
	// It should only be used in ExpressionColumn, and set in ExpressionFunctionCall.
	_inAggregate bool
	// _containsAggregate is true if the current expression contains an aggregate function.
	// It is set in ExpressionFunctionCall, and accessed/reset in SelectCore.
	_containsAggregate bool
	// _containsAggregateWithoutGroupBy is true if the current expression contains an aggregate function,
	// but there is no GROUP BY clause. This is set in SelectCore, and accessed in SelectStatement.
	// Unlike the other temp values it is carried over by copy and is not
	// cleared by setTempValuesToZero.
	_containsAggregateWithoutGroupBy bool
	// _columnInAggregate is the column found within an aggregate function,
	// comprised of the relation and attribute.
	// It is set in ExpressionColumn, and accessed/reset in
	// SelectCore. It is nil if none are found.
	_columnInAggregate *[2]string
	// _columnsOutsideAggregate are columns found outside of an aggregate function.
	// It is set in ExpressionColumn, and accessed/reset in
	// SelectCore.
	_columnsOutsideAggregate [][2]string
	// _inOrdering is true if we are in an ordering clause.
	_inOrdering bool
	// _result is the result of a query. It is only set when analyzing the ordering clause.
	_result []*Attribute
}
// newSQLContext returns a zero-valued sqlContext whose joinedTables map has
// been allocated, so that tables can be joined without a nil-map write.
func newSQLContext() sqlContext {
	var ctx sqlContext
	ctx.joinedTables = make(map[string]*types.Table)
	return ctx
}
// setTempValuesToZero clears the temp (underscore-prefixed) values of the
// context: the aggregate and ordering flags, the columns collected inside and
// outside of aggregates, and the stored query result.
//
// _containsAggregateWithoutGroupBy is the one temp value that this method
// leaves untouched.
func (c *sqlContext) setTempValuesToZero() {
	// boolean markers
	c._inAggregate, c._containsAggregate, c._inOrdering = false, false, false

	// collected columns and results
	c._columnInAggregate = nil
	c._columnsOutsideAggregate = nil
	c._result = nil
}
// copy returns a partial copy of the sqlContext.
//
// The joined and outer relations are deep-copied (each through Relation.Copy).
// ctes and joinedTables are shared with the receiver rather than cloned: ctes
// are not modified after creation, and joinedTables is passed by reference, so
// a table joined through the copy is visible through the original as well.
// _containsAggregateWithoutGroupBy is deliberately carried over.
//
// Nothing else is copied: the outer scope, the positional flags
// (isInlineAction, inConflict, inSelect, ...), targetTable and every other
// temp value are left at their zero values in the result.
func (c *sqlContext) copy() sqlContext {
	joinedRelations := make([]*Relation, len(c.joinedRelations))
	for i, r := range c.joinedRelations {
		joinedRelations[i] = r.Copy()
	}

	outerRelations := make([]*Relation, len(c.outerRelations))
	for i, r := range c.outerRelations {
		outerRelations[i] = r.Copy()
	}

	// ctes don't need to be copied right now since they are not modified.
	return sqlContext{
		joinedRelations: joinedRelations,
		outerRelations:  outerRelations,
		ctes:            c.ctes,
		joinedTables:    c.joinedTables,

		_containsAggregateWithoutGroupBy: c._containsAggregateWithoutGroupBy, // we want to carry this over
	}
}
// joinRelation appends r to the relations joined in this context. It returns
// ErrTableAlreadyJoined, leaving the context unchanged, if a relation with the
// same name has already been joined.
func (c *sqlContext) joinRelation(r *Relation) error {
	if _, exists := c.getJoinedRelation(r.Name); exists {
		return ErrTableAlreadyJoined
	}

	c.joinedRelations = append(c.joinedRelations, r)
	return nil
}
// join records the table t under the given name (or alias) in joinedTables.
// It returns ErrTableAlreadyJoined, without modifying the map, if that name is
// already in use.
func (c *sqlContext) join(name string, t *types.Table) error {
	if _, exists := c.joinedTables[name]; exists {
		return ErrTableAlreadyJoined
	}

	c.joinedTables[name] = t
	return nil
}
// getJoinedRelation looks up a joined relation by name. The boolean is false
// (and the relation nil) when no joined relation has that name. If several
// relations share the name, the first one in join order wins.
func (c *sqlContext) getJoinedRelation(name string) (*Relation, bool) {
	for i := range c.joinedRelations {
		if rel := c.joinedRelations[i]; rel.Name == name {
			return rel, true
		}
	}

	return nil, false
}
// getOuterRelation looks up a relation by name among the outer relations, i.e.
// those that are visible in this scope without being joined in it. The boolean
// is false (and the relation nil) when there is no match.
func (c *sqlContext) getOuterRelation(name string) (*Relation, bool) {
	var match *Relation
	for _, candidate := range c.outerRelations {
		if candidate.Name != name {
			continue
		}
		match = candidate
		break
	}

	if match == nil {
		return nil, false
	}
	return match, true
}
// The following special table names track table names that mean something in
// the context of the SQL statement.
const (
	// tableExcluded is the relation name that findAttribute redirects to the
	// statement's target table. It is only accepted inside an ON CONFLICT clause.
	tableExcluded = "excluded"
)
// findAttribute resolves a column reference to an attribute.
//
// When relation is empty the column is searched for in every joined relation
// (outer relations are NOT considered, which matches Postgres' behavior):
//   - no match yields ErrUnknownColumn,
//   - exactly one match yields that attribute and the name of its relation,
//   - more than one match yields ErrAmbiguousColumn.
//
// When relation is given, it is looked up among the joined relations first and
// among the outer relations second; ErrUnknownTable is returned if neither has
// it, and ErrUnknownColumn if the relation has no such column. The special
// relation name "excluded" is redirected to the statement's target table, and
// is rejected with ErrInvalidExcludedTable outside of an ON CONFLICT clause.
//
// On failure both an error and a short message (the offending column,
// relation, or relation.column) are returned, because callers forward them to
// the error listener. On success msg is empty and err is nil.
func (c *sqlContext) findAttribute(relation string, column string) (relName string, attr *Attribute, msg string, err error) {
	if relation == "" {
		// unqualified reference: count the matches across all joined relations,
		// remembering the first matching attribute and the relation that
		// produced the most recent match.
		matches := 0
		var first *Attribute
		for _, rel := range c.joinedRelations {
			for _, candidate := range rel.Attributes {
				if candidate.Name != column {
					continue
				}
				if matches == 0 {
					first = candidate
				}
				relName = rel.Name
				matches++
			}
		}

		if matches == 0 {
			return "", nil, column, ErrUnknownColumn
		}
		if matches > 1 {
			return "", nil, column, ErrAmbiguousColumn
		}
		return relName, first, "", nil
	}

	// the excluded pseudo-table always has the shape of the row that failed to
	// insert, so it is resolved against the target table instead.
	if relation == tableExcluded {
		if !c.inConflict {
			// excluded can only be used in an ON CONFLICT clause
			return "", nil, relation, fmt.Errorf("%w: excluded table can only be used in an ON CONFLICT clause", ErrInvalidExcludedTable)
		}
		relation = c.targetTable
	}

	// joined relations take precedence over outer relations.
	rel, found := c.getJoinedRelation(relation)
	if !found {
		if rel, found = c.getOuterRelation(relation); !found {
			return "", nil, relation, ErrUnknownTable
		}
	}

	for _, candidate := range rel.Attributes {
		if candidate.Name == column {
			return rel.Name, candidate, "", nil
		}
	}

	return "", nil, relation + "." + column, ErrUnknownColumn
}
// scope opens a nested scope (e.g. for a subquery).
//
// A snapshot of the current context is saved as the new outerScope, so that
// popScope can restore it later. The receiver itself then becomes the inner
// scope: everything that was joined becomes an outer relation (available for
// correlated references), the joined relations and joined tables start out
// empty, and the temp values are cleared via setTempValuesToZero.
//
// The snapshot carries ctes and inLoneSQL along with the other non-temp
// values, so that both are still in place after the matching popScope.
func (c *sqlContext) scope() {
	c2 := &sqlContext{
		joinedRelations: make([]*Relation, len(c.joinedRelations)),
		outerRelations:  make([]*Relation, len(c.outerRelations)),
		joinedTables:    make(map[string]*types.Table),
		// we do not need to deep copy ctes since they are not ever modified;
		// sharing the slice is enough.
		ctes:              c.ctes,
		targetTable:       c.targetTable,
		isInlineAction:    c.isInlineAction,
		inConflict:        c.inConflict,
		inSelect:          c.inSelect,
		inLoneSQL:         c.inLoneSQL,
		hasAnonymousTable: c.hasAnonymousTable,
	}

	// copy all non-temp values
	for i, r := range c.outerRelations {
		c2.outerRelations[i] = r.Copy()
	}
	for i, r := range c.joinedRelations {
		c2.joinedRelations[i] = r.Copy()
	}
	for k, t := range c.joinedTables {
		c2.joinedTables[k] = t.Copy()
	}

	// move joined relations to the outside
	c.outerRelations = append(c.outerRelations, c.joinedRelations...)

	// zero everything else
	c.joinedRelations = nil
	c.joinedTables = make(map[string]*types.Table)
	c.setTempValuesToZero()

	// we do NOT change the inAction, inConflict, or targetTable values,
	// since these apply in all nested scopes.

	// we do not alter inSelect, but we do alter hasAnonymousTable.
	// NOTE(review): this clears the flag on the saved outer scope (c2), while
	// the new inner scope (c) keeps its previous value. Confirm that this is
	// the intended side; behavior is left exactly as it was.
	c2.hasAnonymousTable = false

	c2.outerScope = c.outerScope
	c.outerScope = c2
}
// popScope leaves the current scope by overwriting the receiver with the
// context saved in outerScope. There is no nil check: it must only be called
// after a matching call to scope.
func (c *sqlContext) popScope() {
	outer := c.outerScope
	*c = *outer
}
/*
this visitor breaks down nodes into 4 different types:
- Expressions: expressions simply return *Attribute. The name on all of these will be empty UNLESS it is a column reference.
- CommonTableExpressions: the only node that can directly add tables to outerRelations slice.
*/
// sqlAnalyzer visits SQL nodes and analyzes them. Expression visitors return
// the resulting type as an `any`, and problems are reported through the
// embedded blockContext's error listener.
type sqlAnalyzer struct {
	// UnimplementedSqlVisitor is embedded so the analyzer does not have to
	// define every visitor method itself.
	UnimplementedSqlVisitor
	// blockContext supplies the schema, the declared variables and the error
	// listener for the block being analyzed.
	blockContext
	// sqlCtx is the state of the SQL statement currently being analyzed.
	sqlCtx sqlContext
	// sqlResult accumulates the outcome of the analysis (see sqlAnalyzeResult).
	sqlResult sqlAnalyzeResult
}
// reset returns the analyzer's SQL state to its initial condition: a fresh
// sqlContext and an empty result. The block context is left alone, since it
// does not change here.
func (s *sqlAnalyzer) reset() {
	var emptyResult sqlAnalyzeResult
	s.sqlCtx, s.sqlResult = newSQLContext(), emptyResult
}
// sqlAnalyzeResult is the outcome of analyzing SQL with the sqlAnalyzer.
type sqlAnalyzeResult struct {
	// Mutative is set to true when the analyzed SQL calls a procedure that is
	// not a view. NOTE(review): other visitors outside this excerpt presumably
	// set it as well (e.g. for INSERT/UPDATE/DELETE) — confirm.
	Mutative bool
}
// startSQLAnalyze prepares the analyzer for a new SQL statement by replacing
// its sqlContext with an empty one that has joinedTables allocated. It does
// not touch sqlResult or the block context.
func (s *sqlAnalyzer) startSQLAnalyze() {
	tables := make(map[string]*types.Table)
	s.sqlCtx = sqlContext{joinedTables: tables}
}
// endSQLAnalyze finishes the analysis of a SQL statement. It returns a pointer
// to a copy of the accumulated result and clears the sqlContext. The
// analyzer's own sqlResult is not reset by this method.
func (s *sqlAnalyzer) endSQLAnalyze() *sqlAnalyzeResult {
	out := new(sqlAnalyzeResult)
	*out = s.sqlResult

	s.sqlCtx = sqlContext{}
	return out
}
// compile-time check that *sqlAnalyzer satisfies the Visitor interface.
var _ Visitor = (*sqlAnalyzer)(nil)
// typeErr reports a mismatch between the types t1 and t2 at node as an ErrType
// on the error listener. It returns the node's explicit type cast if it has
// one, and the unknown type otherwise, so that analysis can carry on.
func (s *sqlAnalyzer) typeErr(node Node, t1, t2 *types.DataType) *types.DataType {
	const mismatchFormat = "%s != %s"
	s.errs.AddErr(node, ErrType, mismatchFormat, t1, t2)

	fallback := cast(node, types.UnknownType)
	return fallback
}
// expect records an ErrType at node when the type t does not equal the
// expected type. It reports only; it returns nothing and never aborts.
func (s *sqlAnalyzer) expect(node Node, t *types.DataType, expected *types.DataType) {
	if t.Equals(expected) {
		return
	}

	s.errs.AddErr(node, ErrType, "expected %s, received %s", expected, t)
}
// expectedNumeric records an ErrType at node when the type t is not numeric.
// It reports only; it returns nothing and never aborts.
func (s *sqlAnalyzer) expectedNumeric(node Node, t *types.DataType) {
	if t.IsNumeric() {
		return
	}

	s.errs.AddErr(node, ErrType, "expected numeric type, received %s", t)
}
// expressionTypeErr is the error path for an expression that was expected to
// evaluate to a single *types.DataType but evaluated to something else. It
// visits the expression again to find out what it actually yields, records an
// ErrType with a message tailored to that shape, and returns the expression's
// type cast if it has one, or the unknown type otherwise.
//
// It panics if the expression does yield a *types.DataType, since calling it
// in that situation is a misuse by the caller.
func (s *sqlAnalyzer) expressionTypeErr(e Expression) *types.DataType {
	// subject names the offender for the error message. The usual trigger is a
	// function/procedure with an incompatible return, so the call's name is
	// used when e is a call; any other expression is just an "expression".
	subject := func() string {
		call, isCall := e.(ExpressionCall)
		if !isCall {
			return "expression"
		}
		return fmt.Sprintf(`function/procedure "%s"`, call.FunctionName())
	}

	result := e.Accept(s)
	switch v := result.(type) {
	case *types.DataType:
		// a plain scalar (e.g. "'hello'" or "abs(-1)"), or a procedure that
		// returns exactly one scalar: exactly the case this function must never
		// be called for.
		panic("api misuse: expressionTypeErr should only be called when the expression does not return a *types.DataType")
	case map[string]*types.DataType:
		// a compound value, such as the receiver of
		// "for $row in select * from table"
		s.errs.AddErr(e, ErrType, "invalid usage of compound type. you must reference a field using $compound.field notation")
	case []*types.DataType:
		// a procedure that returns several scalar values
		s.errs.AddErr(e, ErrType, "expected %s to return a single value, returns %d values", subject(), len(v))
	case *returnsTable:
		// a procedure that returns a table
		s.errs.AddErr(e, ErrType, "%s returns table, not scalar values", subject())
	case nil:
		// a procedure that returns nothing
		s.errs.AddErr(e, ErrType, "%s does not return any value", subject())
	default:
		// a shape this analyzer does not know about
		s.errs.AddErr(e, ErrType, "internal bug: could not infer expected type")
	}

	return cast(e, types.UnknownType)
}
// cast will return the type case if one exists. If not, it will simply
// return the passed type.
func cast(castable any, fallback *types.DataType) *types.DataType {
if castable == nil {
return fallback
}
c, ok := castable.(interface{ GetTypeCast() *types.DataType })
if !ok {
return fallback
}
if c.GetTypeCast() == nil {
return fallback
}
return c.GetTypeCast()
}
// VisitExpressionLiteral returns the type of a literal.
//
// If the user wrote a type cast, that cast is the result. A null literal
// without a cast keeps its own type, since nothing can be cast to null. Any
// other literal without a cast has its own type written into TypeCast, so
// that the type is asserted explicitly rather than left for Postgres to
// detect, possibly incorrectly.
func (s *sqlAnalyzer) VisitExpressionLiteral(p0 *ExpressionLiteral) any {
	if p0.TypeCast != nil || p0.Type.EqualsStrict(types.NullType) {
		return cast(p0, p0.Type)
	}

	// no user cast and not null: pin the literal to its detected type.
	p0.TypeCast = p0.Type
	return p0.TypeCast
}
// VisitExpressionFunctionCall analyzes a call by name. The name is looked up
// in the built-in Functions map first and, failing that, as a procedure of the
// current schema.
//
// For a procedure, the statement is marked mutative when the procedure is not
// a view, DISTINCT and * are rejected, the argument count and argument types
// are checked against the procedure's parameters, and the result is whatever
// returnProcedureReturnExpr yields for the procedure's declared returns.
//
// For a built-in, DISTINCT and * usage is validated, the aggregate temp values
// of the sqlContext are maintained, the arguments are validated through
// fn.ValidateArgs, and the function's return type (or the user's type cast) is
// returned. An untyped nil is returned when the function returns nothing.
//
// Every failure is recorded on the error listener; the visitor then returns
// the call's type cast if present, or the unknown type.
func (s *sqlAnalyzer) VisitExpressionFunctionCall(p0 *ExpressionFunctionCall) any {
	// function call should either be against a known function, or a procedure.
	fn, ok := Functions[p0.Name]
	if !ok {
		// if not found, it might be a schema procedure.
		proc, found := s.schema.FindProcedure(p0.Name)
		if !found {
			// the name is data, not a format string, so it goes through %s.
			s.errs.AddErr(p0, ErrUnknownFunctionOrProcedure, "%s", p0.Name)
			return cast(p0, types.UnknownType)
		}

		if !proc.IsView() {
			s.sqlResult.Mutative = true
		}

		// if it is a procedure, it cannot use distinct or *
		if p0.Distinct {
			s.errs.AddErr(p0, ErrFunctionSignature, "cannot use DISTINCT when calling procedure %s", p0.Name)
			return cast(p0, types.UnknownType)
		}
		if p0.Star {
			s.errs.AddErr(p0, ErrFunctionSignature, "cannot use * when calling procedure %s", p0.Name)
			return cast(p0, types.UnknownType)
		}

		// verify the inputs
		if len(p0.Args) != len(proc.Parameters) {
			s.errs.AddErr(p0, ErrFunctionSignature, "expected %d arguments, received %d", len(proc.Parameters), len(p0.Args))
			return cast(p0, types.UnknownType)
		}

		for i, arg := range p0.Args {
			dt, ok := arg.Accept(s).(*types.DataType)
			if !ok {
				return s.expressionTypeErr(arg)
			}

			if !dt.Equals(proc.Parameters[i].Type) {
				return s.typeErr(arg, dt, proc.Parameters[i].Type)
			}
		}

		return s.returnProcedureReturnExpr(p0, p0.Name, proc.Returns)
	}

	if s.sqlCtx._inOrdering && s.sqlCtx._inAggregate {
		s.errs.AddErr(p0, ErrOrdering, "cannot use aggregate functions in ORDER BY clause")
		return cast(p0, types.UnknownType)
	}

	// the function is a built in function. If using DISTINCT, it needs to be an aggregate
	// if using *, it needs to support it.
	if p0.Distinct && !fn.IsAggregate {
		s.errs.AddErr(p0, ErrFunctionSignature, "DISTINCT can only be used with aggregate functions")
		return cast(p0, types.UnknownType)
	}

	if fn.IsAggregate {
		// _inAggregate only holds while the aggregate's arguments are visited;
		// _containsAggregate stays set for SelectCore to read.
		s.sqlCtx._inAggregate = true
		s.sqlCtx._containsAggregate = true
		defer func() { s.sqlCtx._inAggregate = false }()
	}

	// if the function is called with *, we need to ensure it supports it.
	// If not, then we validate all args and return the type.
	var returnType *types.DataType
	if p0.Star {
		if fn.StarArgReturn == nil {
			s.errs.AddErr(p0, ErrFunctionSignature, "function does not support *")
			return cast(p0, types.UnknownType)
		}

		// if calling with *, it must have no args
		if len(p0.Args) != 0 {
			s.errs.AddErr(p0, ErrFunctionSignature, "function does not accept arguments when using *")
			return cast(p0, types.UnknownType)
		}

		returnType = fn.StarArgReturn
	} else {
		argTyps := make([]*types.DataType, len(p0.Args))
		for i, arg := range p0.Args {
			dt, ok := arg.Accept(s).(*types.DataType)
			if !ok {
				return s.expressionTypeErr(arg)
			}

			argTyps[i] = dt
		}

		var err error
		returnType, err = fn.ValidateArgs(argTyps)
		if err != nil {
			// the validation message is arbitrary text; passing it as an
			// argument keeps any '%' in it from being read as a format verb.
			s.errs.AddErr(p0, ErrFunctionSignature, "%s", err.Error())
			return cast(p0, types.UnknownType)
		}
	}

	// callers of this visitor know that a nil return means a function does not
	// return anything. We explicitly return nil instead of a nil *types.DataType
	if returnType == nil {
		return nil
	}

	return cast(p0, returnType)
}
// VisitExpressionForeignCall analyzes a call to a foreign procedure declared
// in the current schema.
//
// It reports an error when used in an in-line action statement (analysis still
// continues), requires the procedure to be declared, requires exactly two
// contextual arguments that are both text, and checks the argument count and
// argument types against the declared parameters. The result is whatever
// returnProcedureReturnExpr yields for the procedure's declared returns.
//
// Every failure is recorded on the error listener; on the failures that stop
// the analysis of this call, the call's type cast (if present) or the unknown
// type is returned.
func (s *sqlAnalyzer) VisitExpressionForeignCall(p0 *ExpressionForeignCall) any {
	if s.sqlCtx.isInlineAction {
		s.errs.AddErr(p0, ErrFunctionSignature, "foreign calls are not supported in in-line action statements")
	}

	// foreign call must be defined as a foreign procedure
	proc, found := s.schema.FindForeignProcedure(p0.Name)
	if !found {
		// the name is data, not a format string, so it goes through %s.
		s.errs.AddErr(p0, ErrUnknownFunctionOrProcedure, "%s", p0.Name)
		return cast(p0, types.UnknownType)
	}

	if len(p0.ContextualArgs) != 2 {
		s.errs.AddErr(p0, ErrFunctionSignature, "expected 2 contextual arguments, received %d", len(p0.ContextualArgs))
		return cast(p0, types.UnknownType)
	}

	// contextual args have to be strings
	for _, ctxArgs := range p0.ContextualArgs {
		dt, ok := ctxArgs.Accept(s).(*types.DataType)
		if !ok {
			return s.expressionTypeErr(ctxArgs)
		}

		s.expect(ctxArgs, dt, types.TextType)
	}

	// verify the inputs
	if len(p0.Args) != len(proc.Parameters) {
		s.errs.AddErr(p0, ErrFunctionSignature, "expected %d arguments, received %d", len(proc.Parameters), len(p0.Args))
		return cast(p0, types.UnknownType)
	}

	for i, arg := range p0.Args {
		dt, ok := arg.Accept(s).(*types.DataType)
		if !ok {
			return s.expressionTypeErr(arg)
		}

		// foreign procedure parameters are the data types themselves.
		if !dt.Equals(proc.Parameters[i]) {
			return s.typeErr(arg, dt, proc.Parameters[i])
		}
	}

	return s.returnProcedureReturnExpr(p0, p0.Name, proc.Returns)
}
// returnProcedureReturnExpr turns a procedure's declared return into the value
// an expression visitor hands back for a call to that procedure:
//
//   - no return declared (ret == nil): an untyped nil; a type cast on the call
//     is reported as an error.
//   - table return: a *returnsTable with one attribute per returned field.
//   - zero fields: an error is reported and the call's cast, or the unknown
//     type, is returned.
//   - exactly one field: that field's type, or the call's type cast.
//   - several fields: a []*types.DataType holding a copy of each field's type;
//     a type cast on the call is reported as an error.
func (s *sqlAnalyzer) returnProcedureReturnExpr(p0 ExpressionCall, procedureName string, ret *types.ProcedureReturn) any {
	// procedures that return nothing
	if ret == nil {
		if p0.GetTypeCast() != nil {
			s.errs.AddErr(p0, ErrType, "cannot typecast procedure %s because does not return a value", procedureName)
		}
		return nil
	}

	// procedures that return a table are handed back as a set of attributes.
	if ret.IsTable {
		attrs := make([]*Attribute, 0, len(ret.Fields))
		for _, field := range ret.Fields {
			attrs = append(attrs, &Attribute{Name: field.Name, Type: field.Type})
		}
		return &returnsTable{attrs: attrs}
	}

	fieldCount := len(ret.Fields)
	if fieldCount == 0 {
		s.errs.AddErr(p0, ErrFunctionSignature, "procedure %s does not return a value", procedureName)
		return cast(p0, types.UnknownType)
	}
	if fieldCount == 1 {
		return cast(p0, ret.Fields[0].Type)
	}

	// several scalar values: these cannot be cast as a whole.
	if p0.GetTypeCast() != nil {
		s.errs.AddErr(p0, ErrType, "cannot type cast multiple return values")
	}

	retVals := make([]*types.DataType, 0, fieldCount)
	for _, field := range ret.Fields {
		retVals = append(retVals, field.Type.Copy())
	}
	return retVals
}
// returnsTable is a special struct returned by returnProcedureReturnExpr when a procedure returns a table.
// It is used internally to detect when a procedure returns a table, so that we can properly throw type errors
// with helpful messages when a procedure returning a table is used in a position where a scalar value is expected.
type returnsTable struct {
	// attrs holds one attribute (name and type) per field of the returned table,
	// in the order the procedure declares them.
	attrs []*Attribute
}
// VisitExpressionVariable resolves a variable reference.
//
// A regular variable yields its declared type, or the reference's type cast
// if it has one. An anonymous (compound) variable yields its field map, a
// map[string]*types.DataType; a type cast on it is reported as an error. An
// unknown variable yields the type cast or the unknown type, and is reported
// as ErrUndeclaredVariable unless a lone SQL query is being analyzed, where
// undeclared variables are allowed.
func (s *sqlAnalyzer) VisitExpressionVariable(p0 *ExpressionVariable) any {
	name := p0.String()

	if dt, ok := s.blockContext.variables[name]; ok {
		return cast(p0, dt)
	}

	// if not found, it could be an anonymous variable.
	if anonVar, ok := s.blockContext.anonymousVariables[name]; ok {
		// if it is anonymous, we cannot type cast, since it is a compound type.
		if p0.GetTypeCast() != nil {
			s.errs.AddErr(p0, ErrType, "cannot type cast compound variable")
		}

		return anonVar
	}

	// if not found, then var does not exist.
	// for raw SQL queries, this is ok. For procedures and actions, this is an error.
	if !s.sqlCtx.inLoneSQL {
		// the name is data, not a format string, so it goes through %s.
		s.errs.AddErr(p0, ErrUndeclaredVariable, "%s", name)
	}

	return cast(p0, types.UnknownType)
}
// VisitExpressionArrayAccess analyzes an index (arr[i]) or range access on an
// array. Every index expression present must be an int, and the accessed
// expression must be an array. The result has the array's name and metadata:
// a single index yields a scalar of that type, while a range yields an array
// again. Array access is reported as an error in in-line action statements,
// but analysis still continues.
func (s *sqlAnalyzer) VisitExpressionArrayAccess(p0 *ExpressionArrayAccess) any {
	if s.sqlCtx.isInlineAction {
		s.errs.AddErr(p0, ErrAssignment, "array access is not supported in in-line action statements")
	}

	// without a single index this is a range access, whose result is an array.
	isRange := p0.Index == nil
	if isRange {
		for _, bound := range p0.FromTo {
			// a missing bound is simply skipped
			if bound == nil {
				continue
			}

			boundType, ok := bound.Accept(s).(*types.DataType)
			if !ok {
				return s.expressionTypeErr(bound)
			}
			if !boundType.Equals(types.IntType) {
				return s.typeErr(bound, boundType, types.IntType)
			}
		}
	} else {
		indexType, ok := p0.Index.Accept(s).(*types.DataType)
		if !ok {
			return s.expressionTypeErr(p0.Index)
		}
		if !indexType.Equals(types.IntType) {
			return s.typeErr(p0.Index, indexType, types.IntType)
		}
	}

	// the indexes are checked first; the accessed expression comes last.
	arrayType, ok := p0.Array.Accept(s).(*types.DataType)
	if !ok {
		return s.expressionTypeErr(p0.Array)
	}
	if !arrayType.IsArray {
		s.errs.AddErr(p0.Array, ErrType, "expected array")
		return cast(p0, types.UnknownType)
	}

	resultType := &types.DataType{
		Name:     arrayType.Name,
		Metadata: arrayType.Metadata,
		IsArray:  isRange,
	}
	return cast(p0, resultType)
}
// VisitExpressionMakeArray analyzes an array constructor. The array must have
// at least one element, and every element must have the same type as the first
// one. The result is an array type with the first element's name and metadata
// (or the constructor's type cast, if it has one). Array instantiation is
// reported as an error in in-line action statements, but analysis still
// continues.
//
// Each element is visited exactly once, so a problem inside the first element
// is reported a single time.
func (s *sqlAnalyzer) VisitExpressionMakeArray(p0 *ExpressionMakeArray) any {
	if s.sqlCtx.isInlineAction {
		s.errs.AddErr(p0, ErrAssignment, "array instantiation is not supported in in-line action statements")
	}

	if len(p0.Values) == 0 {
		s.errs.AddErr(p0, ErrAssignment, "array instantiation must have at least one element")
		return cast(p0, types.UnknownType)
	}

	// the first element decides the element type of the array.
	first, ok := p0.Values[0].Accept(s).(*types.DataType)
	if !ok {
		return s.expressionTypeErr(p0.Values[0])
	}

	// the first element has already been analyzed; only the rest are compared
	// against it.
	for _, v := range p0.Values[1:] {
		typ, ok := v.Accept(s).(*types.DataType)
		if !ok {
			return s.expressionTypeErr(v)
		}

		if !typ.Equals(first) {
			return s.typeErr(v, typ, first)
		}
	}

	return cast(p0, &types.DataType{
		Name:     first.Name,
		Metadata: first.Metadata,
		IsArray:  true,
	})
}
// VisitExpressionFieldAccess analyzes record.field. The record must evaluate
// to a compound value (currently only anonymous variables declared as loop
// receivers), and the field must be one of its fields. The result is the
// field's type, or the access's type cast if it has one. Field access is
// reported as an error in in-line action statements, but analysis still
// continues.
func (s *sqlAnalyzer) VisitExpressionFieldAccess(p0 *ExpressionFieldAccess) any {
	if s.sqlCtx.isInlineAction {
		s.errs.AddErr(p0, ErrAssignment, "field access is not supported in in-line action statements")
	}

	fields, isCompound := p0.Record.Accept(s).(map[string]*types.DataType)
	if !isCompound {
		s.errs.AddErr(p0.Record, ErrType, "cannot access field on non-compound type")
		return cast(p0, types.UnknownType)
	}

	if fieldType, known := fields[p0.Field]; known {
		return cast(p0, fieldType)
	}

	s.errs.AddErr(p0, ErrType, "unknown field %s", p0.Field)
	return cast(p0, types.UnknownType)
}
// VisitExpressionParenthesized analyzes the wrapped expression and returns
// its data type through cast. If the inner expression does not evaluate to a
// *types.DataType, the result of expressionTypeErr for it is returned instead.
func (s *sqlAnalyzer) VisitExpressionParenthesized(p0 *ExpressionParenthesized) any {
	if inner, isType := p0.Inner.Accept(s).(*types.DataType); isType {
		return cast(p0, inner)
	}

	return s.expressionTypeErr(p0.Inner)
}
// VisitExpressionComparison type-checks a comparison between two expressions.
//
// Both operands must evaluate to *types.DataType values, and the left type
// must Equal the right type; the result is then types.BoolType passed through
// cast. A mismatch is reported via typeErr against the right operand. The
// right operand is not visited when the left one fails to produce a type.
func (s *sqlAnalyzer) VisitExpressionComparison(p0 *ExpressionComparison) any {
	lhs, isType := p0.Left.Accept(s).(*types.DataType)
	if !isType {
		return s.expressionTypeErr(p0.Left)
	}

	rhs, isType := p0.Right.Accept(s).(*types.DataType)
	if !isType {
		return s.expressionTypeErr(p0.Right)
	}

	if lhs.Equals(rhs) {
		return cast(p0, types.BoolType)
	}

	return s.typeErr(p0.Right, rhs, lhs)
}
// VisitExpressionLogical type-checks a logical expression over two operands.
//
// Both operands are visited first and must each evaluate to a
// *types.DataType; afterwards the left and then the right type must Equal
// types.BoolType, with the first failing operand reported through typeErr.
// On success the result is types.BoolType passed through cast. In an in-line
// action statement an ErrAssignment error is recorded before the operands
// are analyzed.
func (s *sqlAnalyzer) VisitExpressionLogical(p0 *ExpressionLogical) any {
	if s.sqlCtx.isInlineAction {
		s.errs.AddErr(p0, ErrAssignment, "logical expressions are not supported in in-line action statements")
	}

	lhs, lhsOK := p0.Left.Accept(s).(*types.DataType)
	if !lhsOK {
		return s.expressionTypeErr(p0.Left)
	}

	rhs, rhsOK := p0.Right.Accept(s).(*types.DataType)
	if !rhsOK {
		return s.expressionTypeErr(p0.Right)
	}

	// Cases are evaluated top to bottom, so the left operand is always
	// checked (and reported) before the right one.
	switch {
	case !lhs.Equals(types.BoolType):
		return s.typeErr(p0.Left, lhs, types.BoolType)
	case !rhs.Equals(types.BoolType):
		return s.typeErr(p0.Right, rhs, types.BoolType)
	}

	return cast(p0, types.BoolType)
}
// VisitExpressionArithmetic type-checks a binary arithmetic expression.
//
// Both operands must evaluate to *types.DataType values. For the concat
// operator both types must Equal types.TextType, otherwise an ErrType error
// is recorded and types.UnknownType is returned. For every other operator the
// left operand is passed to expectedNumeric. Finally the two types must Equal
// each other; the result is the left operand's type passed through cast.
func (s *sqlAnalyzer) VisitExpressionArithmetic(p0 *ExpressionArithmetic) any {
	lhs, lhsOK := p0.Left.Accept(s).(*types.DataType)
	if !lhsOK {
		return s.expressionTypeErr(p0.Left)
	}

	rhs, rhsOK := p0.Right.Accept(s).(*types.DataType)
	if !rhsOK {
		return s.expressionTypeErr(p0.Right)
	}

	// Operands must be numeric, except for concatenation, which is text-only.
	isConcat := p0.Operator == ArithmeticOperatorConcat
	switch {
	case !isConcat:
		s.expectedNumeric(p0.Left, lhs)
	case !(lhs.Equals(types.TextType) && rhs.Equals(types.TextType)):
		// Postgres allows concatenation on non-text types but this analyzer
		// does not, so a dedicated, more descriptive error is emitted here.
		// See the note at the top of:
		// https://www.postgresql.org/docs/16.1/functions-string.html
		s.errs.AddErr(p0.Left, ErrType, "concatenation only allowed on text types. received %s and %s", lhs.String(), rhs.String())
		return cast(p0, types.UnknownType)
	}

	// The operand types are compared only after the operator-specific checks
	// so that a failed concatenation gets the more helpful message above.
	if lhs.Equals(rhs) {
		return cast(p0, lhs)
	}

	return s.typeErr(p0.Right, rhs, lhs)
}
// VisitExpressionUnary type-checks a unary expression.
//
// The operand must evaluate to a *types.DataType. The positive and negative
// operators pass the operand to expectedNumeric; negation additionally
// rejects types.Uint256Type with an ErrType error and yields
// types.UnknownType. The not operator passes the operand to expect with
// types.BoolType. Otherwise the operand's own type is returned through cast.
// Any other operator causes a panic.
func (s *sqlAnalyzer) VisitExpressionUnary(p0 *ExpressionUnary) any {
	operand, isType := p0.Expression.Accept(s).(*types.DataType)
	if !isType {
		return s.expressionTypeErr(p0.Expression)
	}

	switch p0.Operator {
	case UnaryOperatorPos, UnaryOperatorNeg:
		s.expectedNumeric(p0.Expression, operand)

		// Only negation has the extra uint256 restriction.
		if p0.Operator == UnaryOperatorNeg && operand.Equals(types.Uint256Type) {
			s.errs.AddErr(p0.Expression, ErrType, "cannot negate uint256")
			return cast(p0, types.UnknownType)
		}
	case UnaryOperatorNot:
		s.expect(p0.Expression, operand, types.BoolType)
	default:
		panic("unknown unary operator")
	}

	return cast(p0, operand)
}
func (s *sqlAnalyzer) VisitExpressionColumn(p0 *ExpressionColumn) any {
if s.sqlCtx.isInlineAction {
s.errs.AddErr(p0, ErrAssignment, "column references are not supported in in-line action statements")
}
// there is a special case, where if we are within an ORDER BY clause,
// we can access columns in the result set. We should search that first
// before searching all joined tables, as result set columns with conflicting
// names are given precedence over joined tables.
if s.sqlCtx._inOrdering && p0.Table == "" {
attr := findAttribute(s.sqlCtx._result, p0.Column)
// short-circuit if we find the column, otherwise proceed to normal search
if attr != nil {
return cast(p0, attr.Type)
}
}
// if we are in an upsert and the column references a column name in the target table
// AND the table is not specified, we need to throw an ambiguity error. For conflict tables,
// the user HAS to specify whether the upsert value is from the existing table or excluded table.
if s.sqlCtx.inConflict && p0.Table == "" {
mainTbl, ok := s.sqlCtx.joinedTables[s.sqlCtx.targetTable]
// if not ok, then we are in a subquery or something else, and we can ignore this check.
if ok {
if _, ok = mainTbl.FindColumn(p0.Column); ok {
s.errs.AddErr(p0, ErrAmbiguousConflictTable, `upsert value is ambigous. specify whether the column is from "%s" or "%s"`, s.sqlCtx.targetTable, tableExcluded)
return cast(p0, types.UnknownType)
}
}
}