Skip to content

Commit

Permalink
feat: support dialect-aware query conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
atzoum committed Sep 10, 2024
1 parent 8cab1fd commit 7a1a31d
Show file tree
Hide file tree
Showing 44 changed files with 1,124 additions and 30 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
github.com/google/uuid v1.6.0
github.com/lib/pq v1.10.9
github.com/ory/dockertest/v3 v3.11.0
github.com/rudderlabs/goqu/v10 v10.3.0
github.com/rudderlabs/rudder-go-kit v0.40.0
github.com/rudderlabs/sql-tunnels v0.1.7
github.com/samber/lo v1.47.0
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg6
github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 h1:OBhqkivkhkMqLPymWEppkm7vgPQY2XsHoEkaMQ0AdZY=
github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU=
github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
Expand Down Expand Up @@ -299,6 +301,8 @@ github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/f
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
github.com/rudderlabs/goqu/v10 v10.3.0 h1:FaZioS8fRYJVYoLO5lieuyiVvEHhmQ6jP5sKwPIcKSs=
github.com/rudderlabs/goqu/v10 v10.3.0/go.mod h1:LH2vI5gGHBxEQuESqFyk5ZA2anGINc8o25hbidDWOYw=
github.com/rudderlabs/rudder-go-kit v0.40.0 h1:Vk/NZm2DUuOiMmTSKUWYQVbXkl4If9KdGQOjNpXCPC4=
github.com/rudderlabs/rudder-go-kit v0.40.0/go.mod h1:GtOYIFfVvNcXabgGytoGdsjdpKTH6PipFIom0bY94WQ=
github.com/rudderlabs/sql-tunnels v0.1.7 h1:wDCRl6zY4M5gfWazf7XkSTGQS3yjBzUiUgEMBIfHNDA=
Expand Down
68 changes: 52 additions & 16 deletions sqlconnect/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"time"

"github.com/rudderlabs/goqu/v10"
)

var (
Expand Down Expand Up @@ -98,23 +101,56 @@ type JsonRowMapper interface {
JSONRowMapper() RowMapper[map[string]any]
}

type Dialect interface {
// QuoteTable quotes a table name
QuoteTable(table RelationRef) string
type (
Dialect interface {
// QuoteTable quotes a table name
QuoteTable(table RelationRef) string

// QuoteIdentifier quotes an identifier, e.g. a column name
QuoteIdentifier(name string) string
// QuoteIdentifier quotes an identifier, e.g. a column name
QuoteIdentifier(name string) string

// FormatTableName formats a table name, typically by lower or upper casing it, depending on the database
//
// Deprecated: to be removed in future versions, since its behaviour is not consistent across databases, e.g. using lowercase for BigQuery while it shouldn't.
// If you want to have a consistent behaviour across databases, use [NormaliseIdentifier] and [ParseRelationRef] instead.
FormatTableName(name string) string
// FormatTableName formats a table name, typically by lower or upper casing it, depending on the database
//
// Deprecated: to be removed in future versions, since its behaviour is not consistent across databases, e.g. using lowercase for BigQuery while it shouldn't.
// If you want to have a consistent behaviour across databases, use [NormaliseIdentifier] and [ParseRelationRef] instead.
FormatTableName(name string) string

// NormaliseIdentifier normalises the identifier's parts that are unquoted, typically by lower or upper casing them, depending on the database
NormaliseIdentifier(identifier string) string
// NormaliseIdentifier normalises the identifier's parts that are unquoted, typically by lower or upper casing them, depending on the database
NormaliseIdentifier(identifier string) string

// ParseRelationRef parses a string into a RelationRef after normalising the identifier and stripping out surrounding quotes.
// The result is a RelationRef with case-sensitive fields, i.e. it can be safely quoted (see [QuoteTable] and, for instance, used for matching against the database's information schema.
ParseRelationRef(identifier string) (RelationRef, error)
}
// ParseRelationRef parses a string into a RelationRef after normalising the identifier and stripping out surrounding quotes.
// The result is a RelationRef with case-sensitive fields, i.e. it can be safely quoted (see [QuoteTable] and, for instance, used for matching against the database's information schema.
ParseRelationRef(identifier string) (RelationRef, error)

// QueryCondition returns a dialect-specific query condition sql string for the provided identifier, operator and value(s).
//
// E.g. QueryCondition("age", "gt", 18) returns "age > 18"
//
// Each operator has a different number of arguments, e.g. [eq] requires one argument, [in] requires at least one argument, etc.
// See [op] package for the list of supported operators
QueryCondition(identifier, operator string, args ...any) (sql string, err error)

// GoquExpressionToSQL converts an Expression to a SQL string
GoquExpressionToSQL(expression GoquExpression) (sql string, err error)

// Expressions returns the dialect-specific expressions
Expressions() Expressions
}

// GoquExpression represents a goqu expression
GoquExpression = goqu.Expression

// Expressions provides dialect-specific expressions
Expressions interface {
// TimestampAdd returns an expression that adds the interval to the timestamp value.
// The value can either be a string literal (column, timestamp, function etc.) or a [time.Time] value.
TimestampAdd(timeValue any, interval int, unit string) (Expression, error)
}

// Expression represents a dialect-specific expression.
// One can get the expression's SQL string by calling [String()] on it.
Expression interface {
GoquExpression() GoquExpression
fmt.Stringer
}
)
7 changes: 7 additions & 0 deletions sqlconnect/internal/base/dbopts.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ func WithDialect(dialect sqlconnect.Dialect) Option {
}
}

// WithGoquDialect sets the goqu dialect for the client
func WithGoquDialect(gqd *GoquDialect) Option {
return func(db *DB) {
db.Dialect = &dialect{gqd}
}
}

// WithSQLCommandsOverride allows for overriding some of the sql commands that the client uses
func WithSQLCommandsOverride(override func(defaultCommands SQLCommands) SQLCommands) Option {
return func(db *DB) {
Expand Down
4 changes: 3 additions & 1 deletion sqlconnect/internal/base/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

type dialect struct{}
type dialect struct {
*GoquDialect
}

// QuoteTable quotes a table name
func (d dialect) QuoteTable(table sqlconnect.RelationRef) string {
Expand Down
159 changes: 159 additions & 0 deletions sqlconnect/internal/base/goqu_dialect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package base

import (
"fmt"

"github.com/samber/lo"

"github.com/rudderlabs/goqu/v10"
"github.com/rudderlabs/goqu/v10/exp"
"github.com/rudderlabs/goqu/v10/sqlgen"
"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

func NewGoquDialect(dialect string, o *sqlgen.SQLDialectOptions, expressions *Expressions) *GoquDialect {
return &GoquDialect{
esg: sqlgen.NewExpressionSQLGenerator(dialect, o),
expressions: expressions,
}
}

type GoquDialect struct {
esg sqlgen.ExpressionSQLGenerator
expressions *Expressions
}

type Expressions struct {
TimestampAdd func(time any, interval int, unit string) goqu.Expression
}

func (gq *GoquDialect) QueryCondition(identifier, operator string, args ...any) (sql string, err error) {
args = lo.Map(args, func(a any, _ int) any {
if s, ok := a.(sqlconnect.Expression); ok {
return s.GoquExpression()
}
return a
})
var expr goqu.Expression
switch operator {
case "eq":
if len(args) != 1 {
return "", fmt.Errorf("eq operator requires exactly one argument, got %d", len(args))
}
expr = goqu.C(identifier).Eq(args[0])
case "neq":
if len(args) != 1 {
return "", fmt.Errorf("neq operator requires exactly one argument, got %d", len(args))
}
expr = goqu.C(identifier).Neq(args[0])
case "in":
if len(args) == 0 {
return "", fmt.Errorf("in operator requires at least one argument")
}
expr = goqu.C(identifier).In(args...)
case "notin":
if len(args) == 0 {
return "", fmt.Errorf("notin operator requires at least one argument")
}
expr = goqu.C(identifier).NotIn(args...)
case "like":
if len(args) != 1 {
return "", fmt.Errorf("like operator requires exactly one argument, got %d", len(args))
}
expr = goqu.C(identifier).Like(args[0])
case "notlike":
if len(args) != 1 {
return "", fmt.Errorf("notlike operator requires exactly one argument, got %d", len(args))
}
expr = goqu.C(identifier).NotLike(args[0])
case "isset":
if len(args) != 0 {
return "", fmt.Errorf("isset operator requires no arguments, got %d", len(args))
}
expr = goqu.C(identifier).IsNotNull()
case "notset":
if len(args) != 0 {
return "", fmt.Errorf("isnotset operator requires no arguments, got %d", len(args))
}
expr = goqu.C(identifier).IsNull()
case "gt":
if len(args) != 1 {
return "", fmt.Errorf("gt operator requires exactly one argument, got %d", len(args))
}
expr = goqu.C(identifier).Gt(args[0])
case "gte":
if len(args) != 1 {
return "", fmt.Errorf("gte operator requires exactly one argument, got %d", len(args))
}
expr = goqu.C(identifier).Gte(args[0])
case "lt":
if len(args) != 1 {
return "", fmt.Errorf("lt operator requires exactly one argument, got %d", len(args))
}
expr = goqu.C(identifier).Lt(args[0])
case "lte":
if len(args) != 1 {
return "", fmt.Errorf("lte operator requires exactly one argument, got %d", len(args))
}
expr = goqu.C(identifier).Lte(args[0])
case "between":
if len(args) != 2 {
return "", fmt.Errorf("between operator requires exactly two arguments, got %d", len(args))
}
expr = goqu.C(identifier).Between(exp.NewRangeVal(args[0], args[1]))
case "notbetween":
if len(args) != 2 {
return "", fmt.Errorf("notbetween operator requires exactly two arguments, got %d", len(args))
}
expr = goqu.C(identifier).NotBetween(exp.NewRangeVal(args[0], args[1]))
default:
return "", fmt.Errorf("unsupported operator: %s", operator)
}

return gq.GoquExpressionToSQL(expr)
}

func (gq *GoquDialect) GoquExpressionToSQL(expression sqlconnect.GoquExpression) (sql string, err error) {
sql, _, err = sqlgen.GenerateExpressionSQL(gq.esg, false, expression)
return
}

func (gq *GoquDialect) Expressions() sqlconnect.Expressions {
return gq
}

func (gq *GoquDialect) TimestampAdd(timeValue any, interval int, unit string) (sqlconnect.Expression, error) {
switch unit {
case "second", "minute", "hour", "day", "month", "year":
case "week":
unit = "day"
interval *= 7
default:
return nil, fmt.Errorf("unsupported unit: %s", unit)

Check warning on line 132 in sqlconnect/internal/base/goqu_dialect.go

View check run for this annotation

Codecov / codecov/patch

sqlconnect/internal/base/goqu_dialect.go#L128-L132

Added lines #L128 - L132 were not covered by tests
}

var v any
switch timeValue := timeValue.(type) {
case string:
v = goqu.L(timeValue)
default:
v = timeValue
}

goquExpression := gq.expressions.TimestampAdd(v, interval, unit)
sql, _, err := sqlgen.GenerateExpressionSQL(gq.esg, false, goquExpression)
return &expression{Expression: goquExpression, sql: sql}, err
}

type expression struct {
goqu.Expression
sql string
}

func (e *expression) GoquExpression() goqu.Expression {
return e.Expression
}

func (e *expression) String() string {
return e.sql

Check warning on line 158 in sqlconnect/internal/base/goqu_dialect.go

View check run for this annotation

Codecov / codecov/patch

sqlconnect/internal/base/goqu_dialect.go#L157-L158

Added lines #L157 - L158 were not covered by tests
}
2 changes: 1 addition & 1 deletion sqlconnect/internal/bigquery/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewDB(configJSON json.RawMessage) (*DB, error) {
DB: base.NewDB(
db,
sshtunnel.NoTunnelCloser,
base.WithDialect(dialect{}),
base.WithDialect(dialect{base.NewGoquDialect(DatabaseType, GoquDialectOptions(), GoquExpressions())}),
base.WithColumnTypeMapper(getColumnTypeMapper(config)),
base.WithJsonRowMapper(getJonRowMapper(config)),
base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands {
Expand Down
4 changes: 3 additions & 1 deletion sqlconnect/internal/bigquery/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base"
)

type dialect struct{}
type dialect struct {
*base.GoquDialect
}

var (
escape = regexp.MustCompile("('|\"|`)")
Expand Down
25 changes: 25 additions & 0 deletions sqlconnect/internal/bigquery/goqu_dialect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package bigquery

import (
"fmt"
"strings"

"github.com/rudderlabs/goqu/v10"
"github.com/rudderlabs/goqu/v10/sqlgen"
"github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base"
)

func GoquDialectOptions() *sqlgen.SQLDialectOptions {
o := sqlgen.DefaultDialectOptions()
o.QuoteIdentifiers = false
o.QuoteRune = '`'
return o
}

func GoquExpressions() *base.Expressions {
return &base.Expressions{
TimestampAdd: func(timeValue any, interval int, unit string) goqu.Expression {
return goqu.L(fmt.Sprintf("TIMESTAMP_ADD(?, INTERVAL %d %s)", interval, strings.ToUpper(unit)), timeValue)
},
}
}
35 changes: 35 additions & 0 deletions sqlconnect/internal/bigquery/goqu_dialect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package bigquery

import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/rudderlabs/goqu/v10"
"github.com/rudderlabs/goqu/v10/sqlgen"
)

func TestExpressions(t *testing.T) {
expressions := GoquExpressions()

t.Run("TimestampAdd", func(t *testing.T) {
t.Run("literal", func(t *testing.T) {
require.Equal(t, "TIMESTAMP_ADD(CURRENT_DATE, INTERVAL 1 HOUR)", toSQL(t, expressions.TimestampAdd(goqu.L("CURRENT_DATE"), 1, "hour")))
require.Equal(t, "TIMESTAMP_ADD('2020-01-01', INTERVAL 1 DAY)", toSQL(t, expressions.TimestampAdd(goqu.L("'2020-01-01'"), 1, "day")))
})

t.Run("time", func(t *testing.T) {
now, err := time.Parse(time.RFC3339, "2020-01-01T00:00:00Z")
require.NoError(t, err)
require.Equal(t, "TIMESTAMP_ADD('2020-01-01T00:00:00Z', INTERVAL 1 DAY)", toSQL(t, expressions.TimestampAdd(now, 1, "day")))
})
})
}

func toSQL(t *testing.T, expression interface{}) string {
esg := sqlgen.NewExpressionSQLGenerator(DatabaseType, GoquDialectOptions())
sql, _, err := sqlgen.GenerateExpressionSQL(esg, false, expression)
require.NoError(t, err)
return sql
}
12 changes: 12 additions & 0 deletions sqlconnect/internal/bigquery/testdata/goqu-test-seed.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
CREATE TABLE `{{.schema}}`.`goqu_test` (
_string STRING(10),
_int INT,
_float BIGNUMERIC(2,1),
_boolean BOOLEAN,
_timestamp TIMESTAMP
);

INSERT INTO `{{.schema}}`.`goqu_test`
(_string, _int, _float, _boolean, _timestamp)
VALUES
('string', 1, 1.1, TRUE, '2021-01-01T00:00:00Z');
2 changes: 1 addition & 1 deletion sqlconnect/internal/databricks/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func NewDB(configJson json.RawMessage) (*DB, error) {
DB: base.NewDB(
db,
tunnelCloser,
base.WithDialect(dialect{}),
base.WithDialect(dialect{base.NewGoquDialect(DatabaseType, GoquDialectOptions(), GoquExpressions())}),
base.WithColumnTypeMapper(getColumnTypeMapper(config)),
base.WithJsonRowMapper(getJonRowMapper(config)),
base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands {
Expand Down
Loading

0 comments on commit 7a1a31d

Please sign in to comment.