diff --git a/driver.go b/driver.go index 6a9dd33..db5b759 100644 --- a/driver.go +++ b/driver.go @@ -5,6 +5,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "google.golang.org/api/bigquery/v2" "sync" "github.com/mattn/go-sqlite3" @@ -133,6 +134,11 @@ func (c *ZetaSQLiteConn) AddNamePath(path string) error { return c.analyzer.AddNamePath(path) } +func (c *ZetaSQLiteConn) SetQueryParameters(parameters []*bigquery.QueryParameter) { + c.analyzer.SetQueryParameters(parameters) + +} + func (s *ZetaSQLiteConn) CheckNamedValue(value *driver.NamedValue) error { return nil } diff --git a/driver_test.go b/driver_test.go index 30d8030..192ccec 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3,6 +3,8 @@ package zetasqlite_test import ( "context" "database/sql" + "fmt" + "google.golang.org/api/bigquery/v2" "testing" "github.com/google/go-cmp/cmp" @@ -62,6 +64,224 @@ CREATE VIEW IF NOT EXISTS SingerNames AS SELECT FirstName || ' ' || LastName AS } } +func configureParameters(conn *sql.Conn, parameters []*bigquery.QueryParameter) error { + if err := conn.Raw(func(c interface{}) error { + zetasqliteConn, ok := c.(*zetasqlite.ZetaSQLiteConn) + if !ok { + return fmt.Errorf("failed to get ZetaSQLiteConn from %T", c) + } + zetasqliteConn.SetQueryParameters(parameters) + return nil + }); err != nil { + return fmt.Errorf("failed to setup query parameters: %s", err) + } + return nil +} + +func TestNamedParameters(t *testing.T) { + ctx := context.Background() + db, err := sql.Open("zetasqlite", ":memory:") + if err != nil { + t.Fatal(err) + } + if _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) +)`); err != nil { + t.Fatal(err) + } + conn, err := db.Conn(ctx) + if _, err := conn.ExecContext(ctx, `INSERT Singers (SingerId, FirstName, LastName) VALUES (1, 'John', 'Titor')`); err != nil { + t.Fatal(err) + } + if err != nil { + t.Fatal(err) + } + t.Run("test multiple statements named params", func(t *testing.T) { + err = configureParameters(conn, []*bigquery.QueryParameter{ + { + Name: "id", + ParameterType: &bigquery.QueryParameterType{ + Type: "INT64", + }, + ParameterValue: &bigquery.QueryParameterValue{ + Value: "1", + }, + }, + { + Name: "name", + ParameterType: &bigquery.QueryParameterType{ + Type: "STRING", + }, + ParameterValue: &bigquery.QueryParameterValue{ + Value: "John", + }, + }, + }) + row := conn.QueryRowContext(ctx, "SELECT SingerID, FirstName, LastName FROM Singers WHERE SingerId = @id OR (@name is null OR FirstName = @name)", 1, "John") + if row.Err() != nil { + t.Fatal(row.Err()) + } + var ( + singerID int64 + firstName string + lastName string + ) + if err := row.Scan(&singerID, &firstName, &lastName); err != nil { + t.Fatal(err) + } + if singerID != 1 || firstName != "John" || lastName != "Titor" { + t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName) + } + }) + + t.Run("test array type", func(t *testing.T) { + err = configureParameters(conn, []*bigquery.QueryParameter{ + { + Name: "names", + ParameterType: &bigquery.QueryParameterType{ + Type: "ARRAY", + ArrayType: &bigquery.QueryParameterType{ + Type: "STRING", + }, + }, + ParameterValue: &bigquery.QueryParameterValue{ + ArrayValues: []*bigquery.QueryParameterValue{ + {Value: "John"}, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + row := conn.QueryRowContext(ctx, "SELECT SingerID, FirstName, LastName FROM Singers WHERE FirstName IN UNNEST(@names)", []string{ + "John", + }) + if row.Err() != nil { + t.Fatal(row.Err()) + } + var ( + singerID int64 + firstName string + lastName string + ) + if err := row.Scan(&singerID, &firstName, &lastName); err != nil { + t.Fatal(err) + } + if singerID != 1 || firstName != "John" || lastName != "Titor" { + t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName) + } + }) + + t.Run("test struct type", func(t *testing.T) { + err = configureParameters(conn, []*bigquery.QueryParameter{ + { + Name: "names", + ParameterType: &bigquery.QueryParameterType{ + Type: "STRUCT", + StructTypes: []*bigquery.QueryParameterTypeStructTypes{ + {Name: "first", Type: &bigquery.QueryParameterType{Type: "STRING"}}, + }, + }, + ParameterValue: &bigquery.QueryParameterValue{ + StructValues: map[string]bigquery.QueryParameterValue{ + "first": {Value: "John"}, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + row := conn.QueryRowContext(ctx, "SELECT SingerID, FirstName, LastName FROM Singers WHERE FirstName = @names.first", map[string]string{ + "first": "John", + }) + if row.Err() != nil { + t.Fatal(row.Err()) + } + var ( + singerID int64 + firstName string + lastName string + ) + if err := row.Scan(&singerID, &firstName, &lastName); err != nil { + t.Fatal(err) + } + if singerID != 1 || firstName != "John" || lastName != "Titor" { + t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName) + } + }) + + t.Run("test parameter pollution type", func(t *testing.T) { + param := "test_param" + // re-using the same parameter name should with different types works across queries + err = configureParameters(conn, []*bigquery.QueryParameter{ + { + Name: param, + ParameterType: &bigquery.QueryParameterType{ + Type: "STRUCT", + StructTypes: []*bigquery.QueryParameterTypeStructTypes{ + {Name: "first", Type: &bigquery.QueryParameterType{Type: "STRING"}}, + }, + }, + ParameterValue: &bigquery.QueryParameterValue{ + StructValues: map[string]bigquery.QueryParameterValue{ + "first": {Value: "John"}, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + row := conn.QueryRowContext(ctx, "SELECT SingerID, FirstName, LastName FROM Singers WHERE FirstName = @test_param.first", map[string]string{ + "first": "John", + }) + if row.Err() != nil { + t.Fatal(row.Err()) + } + var ( + singerID int64 + firstName string + lastName string + ) + if err := row.Scan(&singerID, &firstName, &lastName); err != nil { + t.Fatal(err) + } + if singerID != 1 || firstName != "John" || lastName != "Titor" { + t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName) + } + err = configureParameters(conn, []*bigquery.QueryParameter{ + { + Name: param, + ParameterType: &bigquery.QueryParameterType{ + Type: "STRING", + }, + ParameterValue: &bigquery.QueryParameterValue{ + Value: "John", + }, + }, + }) + if err != nil { + t.Fatal(err) + } + row = conn.QueryRowContext(ctx, "SELECT SingerID, FirstName, LastName FROM Singers WHERE FirstName = @test_param", "John") + if row.Err() != nil { + t.Fatal(row.Err()) + } + if err := row.Scan(&singerID, &firstName, &lastName); err != nil { + t.Fatal(err) + } + if singerID != 1 || firstName != "John" || lastName != "Titor" { + t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName) + } + }) +} + func TestRegisterCustomDriver(t *testing.T) { sql.Register("zetasqlite-custom", &zetasqlite.ZetaSQLiteDriver{ ConnectHook: func(conn *zetasqlite.ZetaSQLiteConn) error { diff --git a/internal/analyzer.go b/internal/analyzer.go index fa0aac6..dfd7318 100644 --- a/internal/analyzer.go +++ b/internal/analyzer.go @@ -1,6 +1,8 @@ package internal import ( + "google.golang.org/api/bigquery/v2" + "context" "database/sql/driver" "fmt" @@ -18,6 +20,7 @@ type Analyzer struct { isExplainMode bool catalog *Catalog opt *zetasql.AnalyzerOptions + queryParameters []*bigquery.QueryParameter } func NewAnalyzer(catalog *Catalog) (*Analyzer, error) { @@ -118,6 +121,16 @@ func (a *Analyzer) NamePath() []string { return a.namePath.path } +func (a *Analyzer) SetQueryParameters(parameters []*bigquery.QueryParameter) { + a.queryParameters = parameters +} + +func (a *Analyzer) PopQueryParameters() []*bigquery.QueryParameter { + parameters := a.queryParameters + a.SetQueryParameters(nil) + return parameters +} + func (a *Analyzer) SetNamePath(path []string) error { return a.namePath.setPath(path) } @@ -183,6 +196,29 @@ func (a *Analyzer) getParameterMode(stmt parsed_ast.StatementNode) (zetasql.Para type StmtActionFunc func() (StmtAction, error) +func (a *Analyzer) configureQueryParameters(options *zetasql.AnalyzerOptions) error { + parameters := a.PopQueryParameters() + for _, parameter := range parameters { + parameterType, err := ZetaSQLTypeFromBigQueryType(parameter.ParameterType) + if err != nil { + return err + } + + if parameter.Name == "" { + err = options.AddPositionalQueryParameter(parameterType) + if err != nil { + return err + } + } else { + err = options.AddQueryParameter(parameter.Name, parameterType) + if err != nil { + return err + } + } + } + return nil +} + func (a *Analyzer) Analyze(ctx context.Context, conn *Conn, query string, args []driver.NamedValue) ([]StmtActionFunc, error) { if err := a.catalog.Sync(ctx, conn); err != nil { return nil, fmt.Errorf("failed to sync catalog: %w", err) @@ -195,6 +231,14 @@ func (a *Analyzer) Analyze(ctx context.Context, conn *Conn, query string, args [ for _, spec := range a.catalog.getFunctions(a.namePath) { funcMap[spec.FuncName()] = spec } + options, err := newAnalyzerOptions() + if err != nil { + return nil, fmt.Errorf("failed to initialize analyzer options") + } + err = a.configureQueryParameters(options) + if err != nil { + return nil, fmt.Errorf("failed to configure query parameter types: %s", err) + } actionFuncs := make([]StmtActionFunc, 0, len(stmts)) for _, stmt := range stmts { stmt := stmt @@ -203,12 +247,12 @@ func (a *Analyzer) Analyze(ctx context.Context, conn *Conn, query string, args [ if err != nil { return nil, err } - a.opt.SetParameterMode(mode) + options.SetParameterMode(mode) out, err := zetasql.AnalyzeStatementFromParserAST( query, stmt, a.catalog, - a.opt, + options, ) if err != nil { return nil, fmt.Errorf("failed to analyze: %w", err) @@ -228,6 +272,74 @@ func (a *Analyzer) Analyze(ctx context.Context, conn *Conn, query string, args [ return actionFuncs, nil } +func ZetaSQLTypeFromBigQueryType(t *bigquery.QueryParameterType) (types.Type, error) { + // Generates ZetaSQL annotated types from a list of bigquery query parameters + if t.Type == "ARRAY" { + element, err := ZetaSQLTypeFromBigQueryType(t.ArrayType) + if err != nil { + return nil, err + } + return types.NewArrayType(element) + } + + if t.Type == "STRUCT" { + fields := []*types.StructField{} + for _, field := range t.StructTypes { + element, err := ZetaSQLTypeFromBigQueryType(field.Type) + if err != nil { + return nil, err + } + + fields = append(fields, types.NewStructField(field.Name, element)) + } + return types.NewStructType(fields) + } + + var zetasqlType types.Type + switch t.Type { + case "INT32": + zetasqlType = types.Int32Type() + case "INT64": + zetasqlType = types.Int64Type() + case "UINT32": + zetasqlType = types.Uint32Type() + case "UINT64": + zetasqlType = types.Uint64Type() + case "BOOL": + zetasqlType = types.BoolType() + case "FLOAT", "FLOAT32": + zetasqlType = types.FloatType() + case "FLOAT64", "DOUBLE": + zetasqlType = types.DoubleType() + case "STRING": + zetasqlType = types.StringType() + case "BYTES": + zetasqlType = types.BytesType() + case "DATE": + zetasqlType = types.DateType() + case "TIMESTAMP": + zetasqlType = types.TimestampType() + case "TIME": + zetasqlType = types.TimeType() + case "DATETIME": + zetasqlType = types.DatetimeType() + case "GEOGRAPHY": + zetasqlType = types.GeographyType() + case "NUMERIC", "DECIMAL": + zetasqlType = types.NumericType() + case "BIGDECIMAL", "BIGNUMERIC": + zetasqlType = types.BigNumericType() + case "JSON": + zetasqlType = types.JsonType() + case "INTERVAL": + zetasqlType = types.IntervalType() + default: + return nil, fmt.Errorf("unsupported query parameter type: %s", t.Type) + } + return zetasqlType, nil + +} + func (a *Analyzer) context( ctx context.Context, funcMap map[string]*FunctionSpec,