diff --git a/sqlconnect/internal/base/db.go b/sqlconnect/internal/base/db.go index 617efe4..9969bb4 100644 --- a/sqlconnect/internal/base/db.go +++ b/sqlconnect/internal/base/db.go @@ -32,7 +32,7 @@ func NewDB(db *sql.DB, tunnelCloser func() error, opts ...Option) *DB { return "SELECT schema_name FROM information_schema.schemata", "schema_name" }, SchemaExists: func(schema UnquotedIdentifier) string { - return fmt.Sprintf("SELECT schema_name FROM information_schema.schemata where schema_name = '%[1]s'", schema) + return fmt.Sprintf("SELECT schema_name FROM information_schema.schemata where schema_name = '%[1]s'", EscapeSqlString(schema)) }, DropSchema: func(schema QuotedIdentifier) string { return fmt.Sprintf("DROP SCHEMA %[1]s CASCADE", schema) }, CreateTestTable: func(table QuotedIdentifier) string { @@ -40,21 +40,21 @@ func NewDB(db *sql.DB, tunnelCloser func() error, opts ...Option) *DB { }, ListTables: func(schema UnquotedIdentifier) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ - {A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%[1]s'", schema), B: "table_name"}, + {A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%[1]s'", EscapeSqlString(schema)), B: "table_name"}, } }, ListTablesWithPrefix: func(schema UnquotedIdentifier, prefix string) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ - {A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' AND table_name LIKE '%[2]s'", schema, prefix+"%"), B: "table_name"}, + {A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' AND table_name LIKE '%[2]s'", EscapeSqlString(schema), prefix+"%"), B: "table_name"}, } }, TableExists: func(schema, table UnquotedIdentifier) string { - return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[2]s'", schema, table) + return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[2]s'", EscapeSqlString(schema), EscapeSqlString(table)) }, ListColumns: func(catalog, schema, table UnquotedIdentifier) (string, string, string) { - stmt := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", schema, table) + stmt := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", EscapeSqlString(schema), EscapeSqlString(table)) if catalog != "" { - stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", catalog) + stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", EscapeSqlString(catalog)) } return stmt + " ORDER BY ordinal_position ASC", "column_name", "data_type" }, diff --git a/sqlconnect/internal/base/dialect.go b/sqlconnect/internal/base/dialect.go index fefbda6..cd7aa5e 100644 --- a/sqlconnect/internal/base/dialect.go +++ b/sqlconnect/internal/base/dialect.go @@ -19,7 +19,7 @@ func (d dialect) QuoteTable(table sqlconnect.RelationRef) string { // QuoteIdentifier quotes an identifier, e.g. a column name func (d dialect) QuoteIdentifier(name string) string { - return fmt.Sprintf(`"%s"`, name) + return fmt.Sprintf(`"%s"`, strings.ReplaceAll(name, `"`, `""`)) } // FormatTableName formats a table name, typically by lower or upper casing it, depending on the database @@ -99,3 +99,8 @@ func doNormaliseIdentifier(identifier string, quote rune, normF func(string) str } return result.String() } + +// EscapeSqlString escapes a string for use in SQL, e.g. by doubling single quotes +func EscapeSqlString(value UnquotedIdentifier) string { + return strings.ReplaceAll(string(value), "'", "''") +} diff --git a/sqlconnect/internal/bigquery/db.go b/sqlconnect/internal/bigquery/db.go index 2231201..839c733 100644 --- a/sqlconnect/internal/bigquery/db.go +++ b/sqlconnect/internal/bigquery/db.go @@ -55,12 +55,12 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { } } cmds.TableExists = func(schema, table base.UnquotedIdentifier) string { - return fmt.Sprintf("SELECT table_name FROM `%[1]s`.INFORMATION_SCHEMA.TABLES WHERE table_name = '%[2]s'", schema, table) + return fmt.Sprintf("SELECT table_name FROM `%[1]s`.INFORMATION_SCHEMA.TABLES WHERE table_name = '%[2]s'", schema, base.EscapeSqlString(table)) } cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) { - stmt := fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, table) + stmt := fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, base.EscapeSqlString(table)) if catalog != "" { - stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", catalog) + stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", base.EscapeSqlString(catalog)) } return stmt, "column_name", "data_type" } diff --git a/sqlconnect/internal/bigquery/dialect.go b/sqlconnect/internal/bigquery/dialect.go index 0998ebb..84f1fe8 100644 --- a/sqlconnect/internal/bigquery/dialect.go +++ b/sqlconnect/internal/bigquery/dialect.go @@ -1,6 +1,7 @@ package bigquery import ( + "regexp" "strings" "github.com/rudderlabs/sqlconnect-go/sqlconnect" @@ -9,6 +10,11 @@ import ( type dialect struct{} +var ( + escape = regexp.MustCompile("('|\"|`)") + unescape = regexp.MustCompile("\\\\('|\")") +) + // QuoteTable quotes a table name func (d dialect) QuoteTable(table sqlconnect.RelationRef) string { if table.Schema != "" { @@ -19,7 +25,7 @@ func (d dialect) QuoteTable(table sqlconnect.RelationRef) string { // QuoteIdentifier quotes an identifier, e.g. a column name func (d dialect) QuoteIdentifier(name string) string { - return "`" + name + "`" + return "`" + escape.ReplaceAllString(name, "\\$1") + "`" } // FormatTableName formats a table name, typically by lower or upper casing it, depending on the database @@ -31,11 +37,25 @@ var identityFn = func(s string) string { return s } // NormaliseIdentifier normalises identifier parts that are unquoted, typically by lower or upper casing them, depending on the database func (d dialect) NormaliseIdentifier(identifier string) string { - return base.NormaliseIdentifier(identifier, '`', identityFn) + return escapeSpecial(base.NormaliseIdentifier(unescapeSpecial(identifier), '`', identityFn)) } // ParseRelationRef parses a string into a RelationRef after normalising the identifier and stripping out surrounding quotes. // The result is a RelationRef with case-sensitive fields, i.e. it can be safely quoted (see [QuoteTable] and, for instance, used for matching against the database's information schema. func (d dialect) ParseRelationRef(identifier string) (sqlconnect.RelationRef, error) { - return base.ParseRelationRef(identifier, '`', identityFn) + return base.ParseRelationRef(unescapeSpecial(identifier), '`', identityFn) +} + +// unescapeSpecial unescapes special characters in an identifier and replaces escaped backticks with a double backtick +func unescapeSpecial(identifier string) string { + identifier = strings.ReplaceAll(identifier, "\\`", "``") + return unescape.ReplaceAllString(identifier, "$1") +} + +// escapeSpecial escapes special characters in an identifier and replaces double backticks with an escaped backtick +func escapeSpecial(identifier string) string { + identifier = strings.ReplaceAll(identifier, "``", "\\`") + identifier = strings.ReplaceAll(identifier, "'", "\\'") + identifier = strings.ReplaceAll(identifier, "\"", "\\\"") + return identifier } diff --git a/sqlconnect/internal/bigquery/dialect_test.go b/sqlconnect/internal/bigquery/dialect_test.go index 533cbe8..110ffcd 100644 --- a/sqlconnect/internal/bigquery/dialect_test.go +++ b/sqlconnect/internal/bigquery/dialect_test.go @@ -18,6 +18,8 @@ func TestDialect(t *testing.T) { t.Run("quote identifier", func(t *testing.T) { quoted := d.QuoteIdentifier("column") require.Equal(t, "`column`", quoted, "column name should be quoted with backticks") + + require.Equal(t, "`col\\`umn`", d.QuoteIdentifier("col`umn"), "column name with backtick should be escaped") }) t.Run("quote table", func(t *testing.T) { @@ -41,8 +43,8 @@ func TestDialect(t *testing.T) { normalised = d.NormaliseIdentifier("TaBle.`ColUmn`") require.Equal(t, "TaBle.`ColUmn`", normalised, "non quoted parts should be normalised") - normalised = d.NormaliseIdentifier("`Sh``EmA`.TABLE.`ColUmn`") - require.Equal(t, "`Sh``EmA`.TABLE.`ColUmn`", normalised, "non quoted parts should be normalised") + normalised = d.NormaliseIdentifier("`Sh\\`EmA`.TABLE.`Co\\'lUmn`") + require.Equal(t, "`Sh\\`EmA`.TABLE.`Co\\'lUmn`", normalised, "non quoted parts should be normalised") }) t.Run("parse relation", func(t *testing.T) { @@ -62,8 +64,8 @@ func TestDialect(t *testing.T) { require.NoError(t, err) require.Equal(t, sqlconnect.RelationRef{Schema: "ScHeMA", Name: "TaBle"}, parsed) - parsed, err = d.ParseRelationRef("`CaTa``LoG`.ScHeMA.`TaBle`") + parsed, err = d.ParseRelationRef("`CaTa``LoG`.ScHeMA.`TaB\\`\\\"\\'le`") require.NoError(t, err) - require.Equal(t, sqlconnect.RelationRef{Catalog: "CaTa`LoG", Schema: "ScHeMA", Name: "TaBle"}, parsed) + require.Equal(t, sqlconnect.RelationRef{Catalog: "CaTa`LoG", Schema: "ScHeMA", Name: "TaB`\"'le"}, parsed) }) } diff --git a/sqlconnect/internal/bigquery/integration_test.go b/sqlconnect/internal/bigquery/integration_test.go index 71ff0b4..7225dc2 100644 --- a/sqlconnect/internal/bigquery/integration_test.go +++ b/sqlconnect/internal/bigquery/integration_test.go @@ -24,7 +24,8 @@ func TestBigqueryDB(t *testing.T) { []byte(configJSON), strings.ToLower, integrationtest.Options{ - LegacySupport: true, + LegacySupport: true, + SpecialCharactersInQuotedTable: "-", }, ) } diff --git a/sqlconnect/internal/databricks/db.go b/sqlconnect/internal/databricks/db.go index 8309e8e..31ae96f 100644 --- a/sqlconnect/internal/databricks/db.go +++ b/sqlconnect/internal/databricks/db.go @@ -77,14 +77,16 @@ func NewDB(configJson json.RawMessage) (*DB, error) { return "SELECT current_catalog()" } cmds.ListSchemas = func() (string, string) { return "SHOW SCHEMAS", "schema_name" } - cmds.SchemaExists = func(schema base.UnquotedIdentifier) string { return fmt.Sprintf(`SHOW SCHEMAS LIKE '%s'`, schema) } + cmds.SchemaExists = func(schema base.UnquotedIdentifier) string { + return fmt.Sprintf(`SHOW SCHEMAS LIKE '%s'`, base.EscapeSqlString(schema)) + } cmds.CreateTestTable = func(table base.QuotedIdentifier) string { return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %[1]s (c1 INT, c2 STRING)", table) } cmds.ListTables = func(schema base.UnquotedIdentifier) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ - {A: fmt.Sprintf("SHOW TABLES IN `%s`", schema), B: "tableName"}, + {A: fmt.Sprintf("SHOW TABLES IN `%s`", base.EscapeSqlString(schema)), B: "tableName"}, } } cmds.ListTablesWithPrefix = func(schema base.UnquotedIdentifier, prefix string) []lo.Tuple2[string, string] { @@ -93,7 +95,7 @@ func NewDB(configJson json.RawMessage) (*DB, error) { } } cmds.TableExists = func(schema, table base.UnquotedIdentifier) string { - return fmt.Sprintf("SHOW TABLES IN `%[1]s` LIKE '%[2]s'", schema, table) + return fmt.Sprintf("SHOW TABLES IN `%[1]s` LIKE '%[2]s'", schema, base.EscapeSqlString(table)) } cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) { if catalog == "" || !informationSchema { diff --git a/sqlconnect/internal/databricks/dialect.go b/sqlconnect/internal/databricks/dialect.go index 392f1d2..359cc2a 100644 --- a/sqlconnect/internal/databricks/dialect.go +++ b/sqlconnect/internal/databricks/dialect.go @@ -19,7 +19,7 @@ func (d dialect) QuoteTable(table sqlconnect.RelationRef) string { // QuoteIdentifier quotes an identifier, e.g. a column name func (d dialect) QuoteIdentifier(name string) string { - return "`" + name + "`" + return "`" + strings.ReplaceAll(name, "`", "``") + "`" } // FormatTableName formats a table name, typically by lower or upper casing it, depending on the database @@ -27,13 +27,17 @@ func (d dialect) FormatTableName(name string) string { return strings.ToLower(name) } -// NormaliseIdentifier normalises identifier parts that are unquoted, typically by lower or upper casing them, depending on the database +// NormaliseIdentifier normalises all identifier parts by lower casing them. func (d dialect) NormaliseIdentifier(identifier string) string { - return base.NormaliseIdentifier(identifier, '`', strings.ToLower) + // Identifiers are case-insensitive + // https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html#:~:text=Identifiers%20are%20case%2Dinsensitive + // Unity Catalog stores all object names as lowercase + // https://docs.databricks.com/en/sql/language-manual/sql-ref-names.html#:~:text=Unity%20Catalog%20stores%20all%20object%20names%20as%20lowercase + return strings.ToLower(identifier) } // ParseRelationRef parses a string into a RelationRef after normalising the identifier and stripping out surrounding quotes. // The result is a RelationRef with case-sensitive fields, i.e. it can be safely quoted (see [QuoteTable] and, for instance, used for matching against the database's information schema. func (d dialect) ParseRelationRef(identifier string) (sqlconnect.RelationRef, error) { - return base.ParseRelationRef(identifier, '`', strings.ToLower) + return base.ParseRelationRef(strings.ToLower(identifier), '`', strings.ToLower) } diff --git a/sqlconnect/internal/databricks/dialect_test.go b/sqlconnect/internal/databricks/dialect_test.go index 19a35e5..9e81772 100644 --- a/sqlconnect/internal/databricks/dialect_test.go +++ b/sqlconnect/internal/databricks/dialect_test.go @@ -36,13 +36,13 @@ func TestDialect(t *testing.T) { require.Equal(t, "column", normalised, "column name should be normalised to lowercase") normalised = d.NormaliseIdentifier("`ColUmn`") - require.Equal(t, "`ColUmn`", normalised, "quoted column name should not be normalised") + require.Equal(t, "`column`", normalised, "quoted column name should be normalised to lowercase") normalised = d.NormaliseIdentifier("TaBle.`ColUmn`") - require.Equal(t, "table.`ColUmn`", normalised, "non quoted parts should be normalised") + require.Equal(t, "table.`column`", normalised, "all parts should be normalised") normalised = d.NormaliseIdentifier("`Sh``EmA`.TABLE.`ColUmn`") - require.Equal(t, "`Sh``EmA`.table.`ColUmn`", normalised, "non quoted parts should be normalised") + require.Equal(t, "`sh``ema`.table.`column`", normalised, "all parts should be normalised to lowercase") }) t.Run("parse relation", func(t *testing.T) { @@ -56,14 +56,14 @@ func TestDialect(t *testing.T) { parsed, err = d.ParseRelationRef("`TaBle`") require.NoError(t, err) - require.Equal(t, sqlconnect.RelationRef{Name: "TaBle"}, parsed) + require.Equal(t, sqlconnect.RelationRef{Name: "table"}, parsed) parsed, err = d.ParseRelationRef("ScHeMA.`TaBle`") require.NoError(t, err) - require.Equal(t, sqlconnect.RelationRef{Schema: "schema", Name: "TaBle"}, parsed) + require.Equal(t, sqlconnect.RelationRef{Schema: "schema", Name: "table"}, parsed) parsed, err = d.ParseRelationRef("`CaTa``LoG`.ScHeMA.`TaBle`") require.NoError(t, err) - require.Equal(t, sqlconnect.RelationRef{Catalog: "CaTa`LoG", Schema: "schema", Name: "TaBle"}, parsed) + require.Equal(t, sqlconnect.RelationRef{Catalog: "cata`log", Schema: "schema", Name: "table"}, parsed) }) } diff --git a/sqlconnect/internal/databricks/integration_test.go b/sqlconnect/internal/databricks/integration_test.go index 4e5653a..71c4f99 100644 --- a/sqlconnect/internal/databricks/integration_test.go +++ b/sqlconnect/internal/databricks/integration_test.go @@ -43,7 +43,8 @@ func TestDatabricksDB(t *testing.T) { []byte(configJSON), strings.ToLower, integrationtest.Options{ - LegacySupport: true, + LegacySupport: true, + SpecialCharactersInQuotedTable: "`-", }, ) }) @@ -63,7 +64,8 @@ func TestDatabricksDB(t *testing.T) { []byte(configJSON), strings.ToLower, integrationtest.Options{ - LegacySupport: true, + LegacySupport: true, + SpecialCharactersInQuotedTable: "`-", }, ) diff --git a/sqlconnect/internal/integration_test/db_integration_test_scenario.go b/sqlconnect/internal/integration_test/db_integration_test_scenario.go index d67da97..2eba8f7 100644 --- a/sqlconnect/internal/integration_test/db_integration_test_scenario.go +++ b/sqlconnect/internal/integration_test/db_integration_test_scenario.go @@ -28,6 +28,8 @@ type Options struct { IncludesViewsInListTables bool + SpecialCharactersInQuotedTable string // special characters to test in quoted table identifiers (default: ,",',`") + ExtraTests func(t *testing.T, db sqlconnect.DB) } @@ -144,17 +146,58 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe }) t.Run("dialect", func(t *testing.T) { - // Create an unquoted table - unquotedTable := "UnQuoted_TablE" - identifier := db.QuoteIdentifier(schema.Name) + "." + unquotedTable - _, err := db.Exec("CREATE TABLE " + identifier + " (c1 int)") - require.NoError(t, err, "it should be able to create an unquoted table") - - table, err := db.ParseRelationRef(identifier) - require.NoError(t, err, "it should be able to parse an unquoted table") - exists, err := db.TableExists(ctx, table) - require.NoError(t, err, "it should be able to check if a table exists") - require.True(t, exists, "it should return true for a table that exists") + t.Run("with unquoted table", func(t *testing.T) { + identifier := db.QuoteIdentifier(schema.Name) + "." + "UnQuoted_TablE" + _, err := db.Exec("CREATE TABLE " + identifier + " (c1 int)") + require.NoError(t, err, "it should be able to create an unquoted table") + + table, err := db.ParseRelationRef(identifier) + require.NoError(t, err, "it should be able to parse an unquoted table") + + alltables, err := db.ListTables(ctx, schema) + require.NoError(t, err, "it should be able to list tables") + + exists, err := db.TableExists(ctx, table) + require.NoErrorf(t, err, "it should be able to check if a table exists: %s allTables: %+v", table, alltables) + require.Truef(t, exists, "it should return true for a table that exists: %s allTables: %+v", table, alltables) + }) + + t.Run("with quoted table", func(t *testing.T) { + identifier := db.QuoteIdentifier(schema.Name) + "." + db.QuoteIdentifier("Quoted_TablE") + _, err := db.Exec("CREATE TABLE " + identifier + " (c1 int)") + require.NoErrorf(t, err, "it should be able to create a quoted table: %s", identifier) + + table, err := db.ParseRelationRef(identifier) + require.NoError(t, err, "it should be able to parse a quoted table") + + alltables, err := db.ListTables(ctx, schema) + require.NoError(t, err, "it should be able to list tables") + + exists, err := db.TableExists(ctx, table) + require.NoErrorf(t, err, "it should be able to check if a table exists: %s allTables: %+v", table, alltables) + require.Truef(t, exists, "it should return true for a table that exists: %s allTables: %+v", table, alltables) + }) + + t.Run("with quoted table and special characters", func(t *testing.T) { + specialCharacters := " \"`'" + if len(opts.SpecialCharactersInQuotedTable) > 0 { + specialCharacters = opts.SpecialCharactersInQuotedTable + } + + identifier := db.QuoteIdentifier(schema.Name) + "." + db.QuoteIdentifier("Quoted_TablE"+specialCharacters) + _, err := db.Exec("CREATE TABLE " + identifier + " (c1 int)") + require.NoErrorf(t, err, "it should be able to create a quoted table: %s", identifier) + + table, err := db.ParseRelationRef(identifier) + require.NoError(t, err, "it should be able to parse a quoted table") + + alltables, err := db.ListTables(ctx, schema) + require.NoError(t, err, "it should be able to list tables") + + exists, err := db.TableExists(ctx, table) + require.NoErrorf(t, err, "it should be able to check if a table exists: %s allTables: %+v", table, alltables) + require.Truef(t, exists, "it should return true for a table that exists: %s allTables: %+v", table, alltables) + }) }) t.Run("table admin", func(t *testing.T) { diff --git a/sqlconnect/internal/mysql/dialect.go b/sqlconnect/internal/mysql/dialect.go index 7248a1e..be9a0ed 100644 --- a/sqlconnect/internal/mysql/dialect.go +++ b/sqlconnect/internal/mysql/dialect.go @@ -19,7 +19,7 @@ func (d dialect) QuoteTable(table sqlconnect.RelationRef) string { // QuoteIdentifier quotes an identifier, e.g. a column name func (d dialect) QuoteIdentifier(name string) string { - return "`" + name + "`" + return "`" + strings.ReplaceAll(name, "`", "``") + "`" } // FormatTableName formats a table name, typically by lower or upper casing it, depending on the database diff --git a/sqlconnect/internal/redshift/db.go b/sqlconnect/internal/redshift/db.go index 9e477b5..999a15b 100644 --- a/sqlconnect/internal/redshift/db.go +++ b/sqlconnect/internal/redshift/db.go @@ -1,9 +1,11 @@ package redshift import ( + "context" "database/sql" "encoding/json" "fmt" + "time" _ "github.com/lib/pq" // postgres driver "github.com/samber/lo" @@ -38,11 +40,19 @@ func NewDB(credentialsJSON json.RawMessage) (*DB, error) { if err != nil { return nil, err } + var caseSensitive string + + func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _ = db.QueryRowContext(ctx, "show enable_case_sensitive_identifier").Scan(&caseSensitive) + }() return &DB{ DB: base.NewDB( db, tunnelCloser, + base.WithDialect(dialect{caseSensitive: caseSensitive == "on"}), base.WithColumnTypeMappings(getColumnTypeMappings(useLegacyMappings)), base.WithJsonRowMapper(getJonRowMapper(useLegacyMappings)), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { @@ -53,25 +63,25 @@ func NewDB(credentialsJSON json.RawMessage) (*DB, error) { return "SELECT schema_name FROM svv_all_schemas", "schema_name" } cmds.SchemaExists = func(schema base.UnquotedIdentifier) string { - return fmt.Sprintf("SELECT schema_name FROM svv_all_schemas WHERE schema_name = '%[1]s'", schema) + return fmt.Sprintf("SELECT schema_name FROM svv_all_schemas WHERE schema_name = '%[1]s'", base.EscapeSqlString(schema)) } cmds.ListTables = func(schema base.UnquotedIdentifier) (sqlAndColumnNamePairs []lo.Tuple2[string, string]) { return []lo.Tuple2[string, string]{ - {A: fmt.Sprintf("SELECT table_name FROM svv_all_tables WHERE schema_name = '%[1]s'", schema), B: "table_name"}, + {A: fmt.Sprintf("SELECT table_name FROM svv_all_tables WHERE schema_name = '%[1]s'", base.EscapeSqlString(schema)), B: "table_name"}, } } cmds.ListTablesWithPrefix = func(schema base.UnquotedIdentifier, prefix string) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ - {A: fmt.Sprintf("SELECT table_name FROM svv_all_tables WHERE schema_name='%[1]s' AND table_name LIKE '%[2]s'", schema, prefix+"%"), B: "table_name"}, + {A: fmt.Sprintf("SELECT table_name FROM svv_all_tables WHERE schema_name='%[1]s' AND table_name LIKE '%[2]s'", base.EscapeSqlString(schema), prefix+"%"), B: "table_name"}, } } cmds.TableExists = func(schema, table base.UnquotedIdentifier) string { - return fmt.Sprintf("SELECT table_name FROM svv_all_tables WHERE schema_name='%[1]s' and table_name = '%[2]s'", schema, table) + return fmt.Sprintf("SELECT table_name FROM svv_all_tables WHERE schema_name='%[1]s' and table_name = '%[2]s'", base.EscapeSqlString(schema), base.EscapeSqlString(table)) } cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) { - stmt := fmt.Sprintf("SELECT column_name, data_type FROM SVV_ALL_COLUMNS WHERE schema_name = '%[1]s' AND table_name = '%[2]s'", schema, table) + stmt := fmt.Sprintf("SELECT column_name, data_type FROM SVV_ALL_COLUMNS WHERE schema_name = '%[1]s' AND table_name = '%[2]s'", base.EscapeSqlString(schema), base.EscapeSqlString(table)) if catalog != "" { - stmt += fmt.Sprintf(" AND database_name = '%[1]s'", catalog) + stmt += fmt.Sprintf(" AND database_name = '%[1]s'", base.EscapeSqlString(catalog)) } return stmt + " ORDER BY ordinal_position ASC", "column_name", "data_type" } diff --git a/sqlconnect/internal/redshift/dialect.go b/sqlconnect/internal/redshift/dialect.go new file mode 100644 index 0000000..56edd97 --- /dev/null +++ b/sqlconnect/internal/redshift/dialect.go @@ -0,0 +1,47 @@ +package redshift + +import ( + "fmt" + "strings" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base" +) + +type dialect struct { + caseSensitive bool +} + +// QuoteTable quotes a table name +func (d dialect) QuoteTable(table sqlconnect.RelationRef) string { + if table.Schema != "" { + return d.QuoteIdentifier(table.Schema) + "." + d.QuoteIdentifier(table.Name) + } + return d.QuoteIdentifier(table.Name) +} + +// QuoteIdentifier quotes an identifier, e.g. a column name +func (d dialect) QuoteIdentifier(name string) string { + return fmt.Sprintf(`"%s"`, strings.ReplaceAll(name, `"`, `""`)) +} + +// FormatTableName formats a table name, typically by lower or upper casing it, depending on the database +func (d dialect) FormatTableName(name string) string { + return strings.ToLower(name) +} + +// NormaliseIdentifier normalises all identifier parts by lower casing them. +func (d dialect) NormaliseIdentifier(identifier string) string { + if d.caseSensitive { + return base.NormaliseIdentifier(identifier, '"', strings.ToLower) + } + // ASCII letters in standard and delimited identifiers are case-insensitive and are folded to lowercase in the database + // https://docs.aws.amazon.com/redshift/latest/dg/r_names.html#:~:text=ASCII%20letters%20in%20standard%20and%20delimited%20identifiers%20are%20case%2Dinsensitive%20and%20are%20folded%20to%20lowercase%20in%20the%20database + return strings.ToLower(identifier) +} + +// ParseRelationRef parses a string into a RelationRef after normalising the identifier and stripping out surrounding quotes. +// The result is a RelationRef with case-sensitive fields, i.e. it can be safely quoted (see [QuoteTable] and, for instance, used for matching against the database's information schema. +func (d dialect) ParseRelationRef(identifier string) (sqlconnect.RelationRef, error) { + return base.ParseRelationRef(strings.ToLower(identifier), '"', strings.ToLower) +} diff --git a/sqlconnect/internal/redshift/dialect_test.go b/sqlconnect/internal/redshift/dialect_test.go new file mode 100644 index 0000000..1ab1fed --- /dev/null +++ b/sqlconnect/internal/redshift/dialect_test.go @@ -0,0 +1,69 @@ +package redshift + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestDialect(t *testing.T) { + var d dialect + t.Run("format table", func(t *testing.T) { + formatted := d.FormatTableName("TaBle") + require.Equal(t, "table", formatted, "table name should be lowercased") + }) + + t.Run("quote identifier", func(t *testing.T) { + quoted := d.QuoteIdentifier("column") + require.Equal(t, `"column"`, quoted, "column name should be quoted with double quotes") + }) + + t.Run("quote table", func(t *testing.T) { + quoted := d.QuoteTable(sqlconnect.NewRelationRef("table")) + require.Equal(t, `"table"`, quoted, "table name should be quoted with double quotes") + + quoted = d.QuoteTable(sqlconnect.NewSchemaTableRef("schema", "table")) + require.Equal(t, `"schema"."table"`, quoted, "schema and table name should be quoted with double quotes") + }) + + t.Run("normalise identifier", func(t *testing.T) { + normalised := d.NormaliseIdentifier("column") + require.Equal(t, "column", normalised, "column name should be normalised to lowercase") + + normalised = d.NormaliseIdentifier("COLUMN") + require.Equal(t, "column", normalised, "column name should be normalised to lowercase") + + normalised = d.NormaliseIdentifier(`"ColUmn"`) + require.Equal(t, `"column"`, normalised, "quoted column name should be normalised to lowercase") + + normalised = d.NormaliseIdentifier(`TaBle."ColUmn"`) + require.Equal(t, `table."column"`, normalised, "all parts should be normalised") + + normalised = d.NormaliseIdentifier(`"Sh""EmA".TABLE."ColUmn"`) + require.Equal(t, `"sh""ema".table."column"`, normalised, "all parts should be normalised") + }) + + t.Run("parse relation", func(t *testing.T) { + parsed, err := d.ParseRelationRef("table") + require.NoError(t, err) + require.Equal(t, sqlconnect.RelationRef{Name: "table"}, parsed) + + parsed, err = d.ParseRelationRef("TABLE") + require.NoError(t, err) + require.Equal(t, sqlconnect.RelationRef{Name: "table"}, parsed) + + parsed, err = d.ParseRelationRef(`"TaBle"`) + require.NoError(t, err) + require.Equal(t, sqlconnect.RelationRef{Name: `table`}, parsed) + + parsed, err = d.ParseRelationRef(`ScHeMA."TaBle"`) + require.NoError(t, err) + require.Equal(t, sqlconnect.RelationRef{Schema: "schema", Name: "table"}, parsed) + + parsed, err = d.ParseRelationRef(`"CaTa""LoG".ScHeMA."TaBle"`) + require.NoError(t, err) + require.Equal(t, sqlconnect.RelationRef{Catalog: "cata\"log", Schema: "schema", Name: "table"}, parsed) + }) +} diff --git a/sqlconnect/internal/snowflake/db.go b/sqlconnect/internal/snowflake/db.go index c02c32e..f73d821 100644 --- a/sqlconnect/internal/snowflake/db.go +++ b/sqlconnect/internal/snowflake/db.go @@ -47,7 +47,7 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { } cmds.ListSchemas = func() (string, string) { return "SHOW TERSE SCHEMAS", "name" } cmds.SchemaExists = func(schema base.UnquotedIdentifier) string { - return fmt.Sprintf("SHOW TERSE SCHEMAS LIKE '%[1]s'", schema) + return fmt.Sprintf("SHOW TERSE SCHEMAS LIKE '%[1]s'", base.EscapeSqlString(schema)) } cmds.ListTables = func(schema base.UnquotedIdentifier) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ @@ -61,7 +61,7 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { } } cmds.TableExists = func(schema, table base.UnquotedIdentifier) string { - return fmt.Sprintf("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '%[1]s' AND TABLE_NAME = '%[2]s'", schema, table) + return fmt.Sprintf("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '%[1]s' AND TABLE_NAME = '%[2]s'", base.EscapeSqlString(schema), base.EscapeSqlString(table)) } cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) { if catalog != "" { diff --git a/sqlconnect/internal/snowflake/dialect.go b/sqlconnect/internal/snowflake/dialect.go index 7781576..60f5bf3 100644 --- a/sqlconnect/internal/snowflake/dialect.go +++ b/sqlconnect/internal/snowflake/dialect.go @@ -19,7 +19,7 @@ func (d dialect) QuoteTable(table sqlconnect.RelationRef) string { // QuoteIdentifier quotes an identifier, e.g. a column name func (d dialect) QuoteIdentifier(name string) string { - return `"` + name + `"` + return `"` + strings.ReplaceAll(name, `"`, `""`) + `"` } // FormatTableName formats a table name, typically by lower or upper casing it, depending on the database diff --git a/sqlconnect/internal/trino/db.go b/sqlconnect/internal/trino/db.go index 0054a75..c3cdd95 100644 --- a/sqlconnect/internal/trino/db.go +++ b/sqlconnect/internal/trino/db.go @@ -45,6 +45,7 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { DB: base.NewDB( db, tunnelCloser, + base.WithDialect(dialect{}), base.WithColumnTypeMapper(columnTypeMapper), base.WithJsonRowMapper(jsonRowMapper), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { @@ -59,7 +60,7 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { } } cmds.TableExists = func(schema, table base.UnquotedIdentifier) string { - return fmt.Sprintf(`SHOW TABLES FROM "%[1]s" LIKE '%[2]s'`, schema, table) + return fmt.Sprintf(`SHOW TABLES FROM "%[1]s" LIKE '%[2]s'`, schema, base.EscapeSqlString(table)) } cmds.TruncateTable = func(table base.QuotedIdentifier) string { return fmt.Sprintf(`DELETE FROM %[1]s`, table) diff --git a/sqlconnect/internal/trino/dialect.go b/sqlconnect/internal/trino/dialect.go new file mode 100644 index 0000000..0d870c9 --- /dev/null +++ b/sqlconnect/internal/trino/dialect.go @@ -0,0 +1,42 @@ +package trino + +import ( + "fmt" + "strings" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base" +) + +type dialect struct{} + +// QuoteTable quotes a table name +func (d dialect) QuoteTable(table sqlconnect.RelationRef) string { + if table.Schema != "" { + return d.QuoteIdentifier(table.Schema) + "." + d.QuoteIdentifier(table.Name) + } + return d.QuoteIdentifier(table.Name) +} + +// QuoteIdentifier quotes an identifier, e.g. a column name +func (d dialect) QuoteIdentifier(name string) string { + return fmt.Sprintf(`"%s"`, strings.ReplaceAll(name, `"`, `""`)) +} + +// FormatTableName formats a table name, typically by lower or upper casing it, depending on the database +func (d dialect) FormatTableName(name string) string { + return strings.ToLower(name) +} + +// NormaliseIdentifier normalises all identifier parts by lower casing them. +func (d dialect) NormaliseIdentifier(identifier string) string { + // Identifiers are not treated as case sensitive. + // https://trino.io/docs/current/language/reserved.html#:~:text=Identifiers%20are%20not%20treated%20as%20case%20sensitive. + return strings.ToLower(identifier) +} + +// ParseRelationRef parses a string into a RelationRef after normalising the identifier and stripping out surrounding quotes. +// The result is a RelationRef with case-sensitive fields, i.e. it can be safely quoted (see [QuoteTable] and, for instance, used for matching against the database's information schema. +func (d dialect) ParseRelationRef(identifier string) (sqlconnect.RelationRef, error) { + return base.ParseRelationRef(strings.ToLower(identifier), '"', strings.ToLower) +} diff --git a/sqlconnect/internal/trino/dialect_test.go b/sqlconnect/internal/trino/dialect_test.go new file mode 100644 index 0000000..c2eb010 --- /dev/null +++ b/sqlconnect/internal/trino/dialect_test.go @@ -0,0 +1,69 @@ +package trino + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestDialect(t *testing.T) { + var d dialect + t.Run("format table", func(t *testing.T) { + formatted := d.FormatTableName("TaBle") + require.Equal(t, "table", formatted, "table name should be lowercased") + }) + + t.Run("quote identifier", func(t *testing.T) { + quoted := d.QuoteIdentifier("column") + require.Equal(t, `"column"`, quoted, "column name should be quoted with double quotes") + }) + + t.Run("quote table", func(t *testing.T) { + quoted := d.QuoteTable(sqlconnect.NewRelationRef("table")) + require.Equal(t, `"table"`, quoted, "table name should be quoted with double quotes") + + quoted = d.QuoteTable(sqlconnect.NewSchemaTableRef("schema", "table")) + require.Equal(t, `"schema"."table"`, quoted, "schema and table name should be quoted with double quotes") + }) + + t.Run("normalise identifier", func(t *testing.T) { + normalised := d.NormaliseIdentifier("column") + require.Equal(t, "column", normalised, "column name should be normalised to lowercase") + + normalised = d.NormaliseIdentifier("COLUMN") + require.Equal(t, "column", normalised, "column name should be normalised to lowercase") + + normalised = d.NormaliseIdentifier(`"ColUmn"`) + require.Equal(t, `"column"`, normalised, "quoted column name should be normalised to lowercase") + + normalised = d.NormaliseIdentifier(`TaBle."ColUmn"`) + require.Equal(t, `table."column"`, normalised, "all parts should be normalised") + + normalised = d.NormaliseIdentifier(`"Sh""EmA".TABLE."ColUmn"`) + require.Equal(t, `"sh""ema".table."column"`, normalised, "all parts should be normalised") + }) + + t.Run("parse relation", func(t *testing.T) { + parsed, err := d.ParseRelationRef("table") + require.NoError(t, err) + require.Equal(t, sqlconnect.RelationRef{Name: "table"}, parsed) + + parsed, err = d.ParseRelationRef("TABLE") + require.NoError(t, err) + require.Equal(t, sqlconnect.RelationRef{Name: "table"}, parsed) + + parsed, err = d.ParseRelationRef(`"TaBle"`) + require.NoError(t, err) + require.Equal(t, sqlconnect.RelationRef{Name: `table`}, parsed) + + parsed, err = d.ParseRelationRef(`ScHeMA."TaBle"`) + require.NoError(t, err) + require.Equal(t, sqlconnect.RelationRef{Schema: "schema", Name: "table"}, parsed) + + parsed, err = d.ParseRelationRef(`"CaTa""LoG".ScHeMA."TaBle"`) + require.NoError(t, err) + require.Equal(t, sqlconnect.RelationRef{Catalog: "cata\"log", Schema: "schema", Name: "table"}, parsed) + }) +} diff --git a/sqlconnect/internal/trino/integration_test.go b/sqlconnect/internal/trino/integration_test.go index fc8acc2..1de5ed5 100644 --- a/sqlconnect/internal/trino/integration_test.go +++ b/sqlconnect/internal/trino/integration_test.go @@ -24,7 +24,9 @@ func TestTrinoDB(t *testing.T) { trino.DatabaseType, []byte(configJSON), strings.ToLower, - integrationtest.Options{}, + integrationtest.Options{ + SpecialCharactersInQuotedTable: "_12", // No special characters allowed in table names :/ + }, ) integrationtest.TestSshTunnelScenarios(t, trino.DatabaseType, []byte(configJSON)) diff --git a/sqlconnect/querydef_test.go b/sqlconnect/querydef_test.go index ed9c8ac..024c699 100644 --- a/sqlconnect/querydef_test.go +++ b/sqlconnect/querydef_test.go @@ -2,6 +2,7 @@ package sqlconnect_test import ( "fmt" + "strings" "testing" "github.com/stretchr/testify/require" @@ -54,7 +55,7 @@ func (d testDialect) FormatTableName(name string) string { } func (d testDialect) QuoteIdentifier(name string) string { - return fmt.Sprintf(`"%s"`, name) + return fmt.Sprintf(`"%s"`, strings.ReplaceAll(name, `"`, `""`)) } func (d testDialect) QuoteTable(relation sqlconnect.RelationRef) string {