diff --git a/sqlconnect/internal/databricks/config.go b/sqlconnect/internal/databricks/config.go
index 4b10ac7..09c872e 100644
--- a/sqlconnect/internal/databricks/config.go
+++ b/sqlconnect/internal/databricks/config.go
@@ -27,6 +27,14 @@ type Config struct {
 	SessionParams map[string]string `json:"sessionParams"`
 
 	UseLegacyMappings bool `json:"useLegacyMappings"`
+	// SkipColumnNormalization disables the normalization of column names returned by
+	// ListColumns and ListColumnsForSqlQuery.
+	// Databricks identifiers are case insensitive, yet the information schema returns column
+	// names with their original casing, whereas table names are returned already normalized.
+	// To avoid this inconsistency column names are normalized by default; set this flag to
+	// true to receive the column names unmodified.
+	// NOTE: the JSON key uses the British spelling ("skipColumnNormalisation").
+	SkipColumnNormalization bool `json:"skipColumnNormalisation"`
 }
 
 func (c *Config) Parse(input json.RawMessage) error {
diff --git a/sqlconnect/internal/databricks/db.go b/sqlconnect/internal/databricks/db.go
index cd781f6..d1349c5 100644
--- a/sqlconnect/internal/databricks/db.go
+++ b/sqlconnect/internal/databricks/db.go
@@ -120,7 +120,8 @@ func NewDB(configJson json.RawMessage) (*DB, error) {
 				return cmds
 			}),
 		),
-		informationSchema: informationSchema,
+		informationSchema:       informationSchema,
+		skipColumnNormalization: config.SkipColumnNormalization,
 	}, nil
 }
 
@@ -132,7 +133,8 @@ func init() {
 
 type DB struct {
 	*base.DB
-	informationSchema bool
+	informationSchema       bool
+	skipColumnNormalization bool
 }
 
 func getColumnTypeMapper(config Config) func(base.ColumnType) string {
diff --git a/sqlconnect/internal/databricks/tableadmin.go b/sqlconnect/internal/databricks/tableadmin.go
index d43c050..1fabf6d 100644
--- a/sqlconnect/internal/databricks/tableadmin.go
+++ b/sqlconnect/internal/databricks/tableadmin.go
@@ -5,6 +5,8 @@ import (
 	"fmt"
 	"strings"
 
+	"github.com/samber/lo"
+
 	"github.com/rudderlabs/sqlconnect-go/sqlconnect"
 )
 
@@ -19,7 +21,33 @@ func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef)
 			return nil, fmt.Errorf("catalog %s not found", relation.Catalog)
 		}
 	}
-	return db.DB.ListColumns(ctx, relation)
+	cols, err := db.DB.ListColumns(ctx, relation)
+	if err != nil {
+		return nil, err
+	}
+	return db.normaliseColumns(cols), nil
+}
+
+// ListColumnsForSqlQuery returns the columns produced by the given sql query, with their names
+// normalised unless the SkipColumnNormalization config option is set.
+func (db *DB) ListColumnsForSqlQuery(ctx context.Context, sql string) ([]sqlconnect.ColumnRef, error) {
+	cols, err := db.DB.ListColumnsForSqlQuery(ctx, sql)
+	if err != nil {
+		return nil, err
+	}
+	return db.normaliseColumns(cols), nil
+}
+
+// normaliseColumns returns cols with every column name passed through NormaliseIdentifier,
+// or cols unchanged when column normalization is skipped by configuration.
+func (db *DB) normaliseColumns(cols []sqlconnect.ColumnRef) []sqlconnect.ColumnRef {
+	if db.skipColumnNormalization {
+		return cols
+	}
+	return lo.Map(cols, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
+		col.Name = db.NormaliseIdentifier(col.Name)
+		return col
+	})
 }
 
 // RenameTable in databricks falls back to MoveTable if rename is not supported
diff --git a/sqlconnect/internal/integration_test/db_integration_test_scenario.go b/sqlconnect/internal/integration_test/db_integration_test_scenario.go
index 036d847..1cc5c52 100644
--- a/sqlconnect/internal/integration_test/db_integration_test_scenario.go
+++ b/sqlconnect/internal/integration_test/db_integration_test_scenario.go
@@ -2,6 +2,7 @@ package integrationtest
 
 import (
 	"context"
+	"database/sql"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -609,6 +610,75 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe
 					{Name: formatfn("c2"), Type: "string"},
 				}, "it should return the correct columns")
 			})
+
+			t.Run("list columns with mixed case", func(t *testing.T) {
+				unquotedColumn := "cOluMnA"
+				normalizedUnquotedColumn := db.NormaliseIdentifier(unquotedColumn)
+				quotedColumn := db.QuoteIdentifier("QuOted_CoLuMnB")
+				normalizedQuotedColumn := db.NormaliseIdentifier(quotedColumn)
+				parsedRel, err := db.ParseRelationRef(quotedColumn)
+				require.NoError(t, err, "it should be able to parse a quoted column")
+				normalizedQuotedColumnWithoutQuotes := parsedRel.Name
+
+				tableIdentifier := db.QuoteIdentifier(schema.Name) + "." + db.QuoteIdentifier("table_mixed_case")
+				_, err = db.Exec(fmt.Sprintf("CREATE TABLE %[1]s (%[2]s int, %[3]s int)", tableIdentifier, unquotedColumn, quotedColumn))
+				require.NoErrorf(t, err, "it should be able to create a quoted table: %s", tableIdentifier)
+
+				table, err := db.ParseRelationRef(tableIdentifier)
+				require.NoError(t, err, "it should be able to parse a quoted table")
+
+				t.Run("without catalog", func(t *testing.T) {
+					columns, err := db.ListColumns(ctx, table)
+					require.NoError(t, err, "it should be able to list columns")
+					columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
+						require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
+						col.RawType = ""
+						return col
+					})
+					require.Len(t, columns, 2, "it should return the correct number of columns")
+					require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
+						{Name: normalizedUnquotedColumn, Type: "int"},
+						{Name: normalizedQuotedColumnWithoutQuotes, Type: "int"},
+					}, "it should return the correct columns")
+
+					var c1, c2 int
+					err = db.QueryRow(fmt.Sprintf("SELECT %[1]s, %[2]s FROM %[3]s", normalizedUnquotedColumn, normalizedQuotedColumn, tableIdentifier)).Scan(&c1, &c2)
+					require.ErrorIs(t, err, sql.ErrNoRows, "it should get a no rows error (supports normalised column names)")
+				})
+
+				t.Run("with catalog", func(t *testing.T) {
+					tableWithCatalog := table
+					tableWithCatalog.Catalog = currentCatalog
+					columns, err := db.ListColumns(ctx, tableWithCatalog)
+					require.NoErrorf(t, err, "it should be able to list columns for %s", tableWithCatalog)
+					columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
+						require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
+						col.RawType = ""
+						return col
+					})
+
+					require.Len(t, columns, 2, "it should return the correct number of columns")
+					require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
+						{Name: normalizedUnquotedColumn, Type: "int"},
+						{Name: normalizedQuotedColumnWithoutQuotes, Type: "int"},
+					}, "it should return the correct columns")
+				})
+
+				t.Run("for sql query", func(t *testing.T) {
+					columns, err := db.ListColumnsForSqlQuery(ctx, "SELECT * FROM "+tableIdentifier)
+					require.NoError(t, err, "it should be able to list columns")
+					columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
+						require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
+						col.RawType = ""
+						return col
+					})
+					require.Len(t, columns, 2, "it should return the correct number of columns")
+					require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
+						{Name: normalizedUnquotedColumn, Type: "int"},
+						{Name: normalizedQuotedColumnWithoutQuotes, Type: "int"},
+					}, "it should return the correct columns")
+				})
+			})
 		})
 
 		t.Run("list columns for sql query", func(t *testing.T) {