Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(databricks): normalize column names #241

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sqlconnect/internal/databricks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ type Config struct {
SessionParams map[string]string `json:"sessionParams"`

UseLegacyMappings bool `json:"useLegacyMappings"`
// SkipColumnNormalization skips normalizing column names during ListColumns and ListColumnsForSqlQuery.
// Databricks returns column names from the information schema with their case preserved, even though column identifiers are case-insensitive.
// As a result, table names come back from Databricks already normalized, whereas column names do not.
// To avoid this inconsistency, column names are normalized by default.
// Set this flag to true to skip the normalization. Note that the JSON key is spelled "skipColumnNormalisation" (with an "s").
SkipColumnNormalization bool `json:"skipColumnNormalisation"`
}

func (c *Config) Parse(input json.RawMessage) error {
Expand Down
6 changes: 4 additions & 2 deletions sqlconnect/internal/databricks/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ func NewDB(configJson json.RawMessage) (*DB, error) {
return cmds
}),
),
informationSchema: informationSchema,
informationSchema: informationSchema,
skipColumnNormalization: config.SkipColumnNormalization,
}, nil
}

Expand All @@ -132,7 +133,8 @@ func init() {

// DB is the databricks database handle. It embeds *base.DB and overrides
// selected methods (e.g. ListColumns, ListColumnsForSqlQuery) with
// databricks-specific behaviour.
type DB struct {
	*base.DB
	// informationSchema is populated by NewDB.
	// NOTE(review): presumably indicates whether the information schema is
	// available/used for metadata queries — confirm against NewDB.
	informationSchema bool
	// skipColumnNormalization, when true, makes ListColumns and
	// ListColumnsForSqlQuery return column names exactly as reported by the
	// embedded base DB instead of passing them through NormaliseIdentifier.
	skipColumnNormalization bool
}

func getColumnTypeMapper(config Config) func(base.ColumnType) string {
Expand Down
22 changes: 21 additions & 1 deletion sqlconnect/internal/databricks/tableadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"fmt"
"strings"

"github.com/samber/lo"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

Expand All @@ -19,7 +21,25 @@
return nil, fmt.Errorf("catalog %s not found", relation.Catalog)
}
}
return db.DB.ListColumns(ctx, relation)
cols, err := db.DB.ListColumns(ctx, relation)
if db.skipColumnNormalization {
return cols, err
}

Check warning on line 27 in sqlconnect/internal/databricks/tableadmin.go

View check run for this annotation

Codecov / codecov/patch

sqlconnect/internal/databricks/tableadmin.go#L26-L27

Added lines #L26 - L27 were not covered by tests
return lo.Map(cols, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
col.Name = db.NormaliseIdentifier(col.Name)
return col
}), err
}

func (db *DB) ListColumnsForSqlQuery(ctx context.Context, sql string) ([]sqlconnect.ColumnRef, error) {
cols, err := db.DB.ListColumnsForSqlQuery(ctx, sql)
if db.skipColumnNormalization {
return cols, err
}

Check warning on line 38 in sqlconnect/internal/databricks/tableadmin.go

View check run for this annotation

Codecov / codecov/patch

sqlconnect/internal/databricks/tableadmin.go#L37-L38

Added lines #L37 - L38 were not covered by tests
return lo.Map(cols, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
col.Name = db.NormaliseIdentifier(col.Name)
return col
}), err
}

// RenameTable in databricks falls back to MoveTable if rename is not supported
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package integrationtest

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -609,6 +610,75 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe
{Name: formatfn("c2"), Type: "string"},
}, "it should return the correct columns")
})

// Creates a table with one unquoted and one quoted mixed-case column and
// checks that ListColumns (with and without catalog) and
// ListColumnsForSqlQuery report both columns under their normalized names.
t.Run("list columns with mixed case", func(t *testing.T) {
	unquotedColumn := "cOluMnA"
	normalizedUnquotedColumn := db.NormaliseIdentifier(unquotedColumn)
	quotedColumn := db.QuoteIdentifier("QuOted_CoLuMnB")
	normalizedQuotedColumn := db.NormaliseIdentifier(quotedColumn)
	// ParseRelationRef is applied to the quoted column so that its Name field
	// can serve as the normalized column name without the surrounding quotes.
	parsedRel, err := db.ParseRelationRef(quotedColumn)
	require.NoError(t, err, "it should be able to parse a quoted column")
	normalizedQuotedColumnWithoutQuotes := parsedRel.Name

	tableIdentifier := db.QuoteIdentifier(schema.Name) + "." + db.QuoteIdentifier("table_mixed_case")
	_, err = db.Exec(fmt.Sprintf("CREATE TABLE %[1]s (%[2]s int, %[3]s int)", tableIdentifier, unquotedColumn, quotedColumn))
	require.NoErrorf(t, err, "it should be able to create a quoted table: %s", tableIdentifier)

	table, err := db.ParseRelationRef(tableIdentifier)
	require.NoError(t, err, "it should be able to parse a quoted table")

	t.Run("without catalog", func(t *testing.T) {
		columns, err := db.ListColumns(ctx, table)
		// Assert on the error first so that a failing call is reported as such
		// rather than as a misleading column-count or raw-type failure.
		require.NoError(t, err, "it should be able to list columns")
		columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
			require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
			// RawType is cleared so the comparison below only covers Name and Type.
			col.RawType = ""
			return col
		})
		require.Len(t, columns, 2, "it should return the correct number of columns")
		require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
			{Name: normalizedUnquotedColumn, Type: "int"},
			{Name: normalizedQuotedColumnWithoutQuotes, Type: "int"},
		}, "it should return the correct columns")

		// No rows are inserted in this subtest, so a query that accepts the
		// normalized column names is expected to end in sql.ErrNoRows.
		var c1, c2 int
		err = db.QueryRow(fmt.Sprintf("SELECT %[1]s, %[2]s FROM %[3]s", normalizedUnquotedColumn, normalizedQuotedColumn, tableIdentifier)).Scan(&c1, &c2)
		require.ErrorIs(t, err, sql.ErrNoRows, "it should get a no rows error (supports normalised column names)")
	})

	t.Run("with catalog", func(t *testing.T) {
		tableWithCatalog := table
		tableWithCatalog.Catalog = currentCatalog
		columns, err := db.ListColumns(ctx, tableWithCatalog)
		require.NoErrorf(t, err, "it should be able to list columns for %s", tableWithCatalog)
		columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
			require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
			col.RawType = ""
			return col
		})

		require.Len(t, columns, 2, "it should return the correct number of columns")
		require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
			{Name: normalizedUnquotedColumn, Type: "int"},
			{Name: normalizedQuotedColumnWithoutQuotes, Type: "int"},
		}, "it should return the correct columns")
	})

	t.Run("for sql query", func(t *testing.T) {
		columns, err := db.ListColumnsForSqlQuery(ctx, "SELECT * FROM "+tableIdentifier)
		// Same ordering as the other subtests: error first, then the results.
		require.NoError(t, err, "it should be able to list columns")
		columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
			require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
			col.RawType = ""
			return col
		})
		require.Len(t, columns, 2, "it should return the correct number of columns")
		require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
			{Name: normalizedUnquotedColumn, Type: "int"},
			{Name: normalizedQuotedColumnWithoutQuotes, Type: "int"},
		}, "it should return the correct columns")
	})
})
})

t.Run("list columns for sql query", func(t *testing.T) {
Expand Down
Loading