Skip to content

Commit

Permalink
add partial ALTER TABLE support (adding column and DEFAULT) goccy#113 (
Browse files Browse the repository at this point in the history
  • Loading branch information
myhau authored Jun 13, 2024
1 parent b6dd6a5 commit 9e293a3
Show file tree
Hide file tree
Showing 5 changed files with 383 additions and 9 deletions.
193 changes: 184 additions & 9 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,196 @@ import (
zetasqlite "github.com/goccy/go-zetasqlite"
)

func TestDriverAlter(t *testing.T) {
db, err := sql.Open("zetasqlite", ":memory:")
if err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`
CREATE TABLE IF NOT EXISTS Artists (
SingerId INT64 NOT NULL,
FirstName STRING(1024),
LastName STRING(1024),
SingerInfo BYTES(MAX)
)
`); err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`INSERT Artists (SingerId, FirstName, LastName) VALUES (1, 'John', 'Titor')`); err != nil {
t.Fatal(err)
}
row := db.QueryRow(`SELECT SingerId, FirstName, LastName FROM Artists WHERE SingerId = @id`, 1)
if row.Err() != nil {
t.Fatal(row.Err())
}
var (
singerID int64
firstName string
lastName string
)
if err := row.Scan(&singerID, &firstName, &lastName); err != nil {
t.Fatal(err)
}
if singerID != 1 || firstName != "John" || lastName != "Titor" {
t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName)
}

if _, err := db.Exec(`
CREATE VIEW IF NOT EXISTS
SingerNames AS SELECT FirstName || ' ' || LastName AS Name
FROM Artists
`); err != nil {
t.Fatal(err)
}

viewRow := db.QueryRow(`SELECT Name FROM SingerNames LIMIT 1`)
if viewRow.Err() != nil {
t.Fatal(viewRow.Err())
}

var name string

if err := viewRow.Scan(&name); err != nil {
t.Fatal(err)
}
if name != "John Titor" {
t.Fatalf("failed to find view row")
}

// Test ALTER TABLE SET OPTIONS
if _, err := db.Exec(`ALTER TABLE Artists SET OPTIONS (description="Famous Artists")`); err != nil {
t.Fatal(err)
}

// Test ALTER TABLE ADD COLUMN
if _, err := db.Exec(`ALTER TABLE Artists ADD COLUMN Age INT64, ADD COLUMN IsSingle BOOL`); err != nil {
t.Fatal(err)
}

// Verify the changes
row = db.QueryRow(`
SELECT SingerId, FirstName, LastName, Age, IsSingle
FROM Artists
WHERE SingerId = @id`,
1,
)
if row.Err() != nil {
t.Fatal(row.Err())
}

var age sql.NullInt64
var isSingle sql.NullBool
if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle); err != nil {
t.Fatal(err)
}
if singerID != 1 || firstName != "John" || lastName != "Titor" || age.Valid || isSingle.Valid {
t.Fatalf("failed to find row after ALTER TABLE statements")
}

if _, err := db.Exec(`
INSERT Artists (SingerId, FirstName, LastName, Age, IsSingle)
VALUES (2, 'Mike', 'Bit', 11, TRUE)
`); err != nil {
t.Fatal(err)
}
row = db.QueryRow(`
SELECT SingerId, FirstName, LastName, Age, isSingle
FROM Artists
WHERE SingerId = @id AND isSingle IS NOT NULL`,
2,
)
if row.Err() != nil {
t.Fatal(row.Err())
}
if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle); err != nil {
t.Fatal(err)
}
if singerID != 2 || firstName != "Mike" || lastName != "Bit" || age.Int64 != 11 || isSingle.Bool != true {
t.Fatalf("Failed to find row %v %v %v %v %v", singerID, firstName, lastName, age, isSingle)
}

if _, err := db.Exec(`
ALTER TABLE Artists
ADD COLUMN Nationality STRING
`); err != nil {
t.Fatal(err)
}

if _, err := db.Exec(`
ALTER TABLE Artists
ALTER COLUMN Nationality SET DEFAULT 'Unknown'
`); err != nil {
t.Fatal(err)
}

// Verify the changes
row = db.QueryRow(`
SELECT SingerID, FirstName, LastName, Age, IsSingle, Nationality
FROM Artists
WHERE SingerId = @id`,
2,
)
if row.Err() != nil {
t.Fatal(row.Err())
}

var nationality sql.NullString
if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle, &nationality); err != nil {
t.Fatal(err)
}

if singerID != 2 || firstName != "Mike" || lastName != "Bit" || age.Int64 != 11 || isSingle.Bool != true || nationality.Valid {
t.Fatalf("failed to find row after multi-action ALTER TABLE statement")
}

if _, err := db.Exec(`
INSERT Artists (SingerId, FirstName, LastName, Age, IsSingle)
VALUES (3, 'Mark', 'Byte', 12, FALSE)
`); err != nil {
t.Fatal(err)
}

// Verify the changes
row = db.QueryRow(`
SELECT SingerID, FirstName, LastName, Age, IsSingle, Nationality
FROM Artists
WHERE SingerId = @id`,
3,
)
if row.Err() != nil {
t.Fatal(row.Err())
}

if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle, &nationality); err != nil {
t.Fatal(err)
}
if singerID != 3 || firstName != "Mark" || lastName != "Byte" || age.Int64 != 12 || isSingle.Bool != false || nationality.String != "Unknown" {
t.Fatalf("failed to find row after multi-action ALTER TABLE statement")
}
}

func TestDriver(t *testing.T) {
db, err := sql.Open("zetasqlite", ":memory:")
if err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`
CREATE TABLE IF NOT EXISTS Singers (
SingerId INT64 NOT NULL,
FirstName STRING(1024),
LastName STRING(1024),
SingerInfo BYTES(MAX)
)`); err != nil {
CREATE TABLE IF NOT EXISTS Singers (
SingerId INT64 NOT NULL,
FirstName STRING(1024),
LastName STRING(1024),
SingerInfo BYTES(MAX)
)
`); err != nil {
t.Fatal(err)
}
if _, err := db.Exec(`INSERT Singers (SingerId, FirstName, LastName) VALUES (1, 'John', 'Titor')`); err != nil {
if _, err := db.Exec(`
INSERT Singers (SingerId, FirstName, LastName)
VALUES (1, 'John', 'Titor')
`); err != nil {
t.Fatal(err)
}
row := db.QueryRow("SELECT SingerID, FirstName, LastName FROM Singers WHERE SingerId = @id", 1)
row := db.QueryRow(`SELECT SingerID, FirstName, LastName FROM Singers WHERE SingerId = @id`, 1)
if row.Err() != nil {
t.Fatal(row.Err())
}
Expand All @@ -43,7 +215,10 @@ CREATE TABLE IF NOT EXISTS Singers (
t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName)
}
if _, err := db.Exec(`
CREATE VIEW IF NOT EXISTS SingerNames AS SELECT FirstName || ' ' || LastName AS Name FROM Singers`); err != nil {
CREATE VIEW IF NOT EXISTS
SingerNames AS SELECT FirstName || ' ' || LastName AS Name
FROM Singers
`); err != nil {
t.Fatal(err)
}

Expand Down
24 changes: 24 additions & 0 deletions internal/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) {
zetasql.FeatureV11WithOnSubquery,
zetasql.FeatureV13Pivot,
zetasql.FeatureV13Unpivot,
zetasql.FeatureV13ColumnDefaultValue,
})
langOpt.SetSupportedStatementKinds([]ast.Kind{
ast.BeginStmt,
Expand All @@ -87,6 +88,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) {
ast.DropStmt,
ast.TruncateStmt,
ast.CreateTableStmt,
ast.AlterTableStmt,
ast.CreateTableAsSelectStmt,
ast.CreateProcedureStmt,
ast.CreateFunctionStmt,
Expand Down Expand Up @@ -290,10 +292,32 @@ func (a *Analyzer) newStmtAction(ctx context.Context, query string, args []drive
return a.newBeginStmtAction(ctx, query, args, node)
case ast.CommitStmt:
return a.newCommitStmtAction(ctx, query, args, node)
case ast.AlterTableStmt:
return a.alterTableStmtAction(ctx, query, args, node.(*ast.AlterTableStmtNode))
}
return nil, fmt.Errorf("unsupported stmt %s", node.DebugString())
}

func (a *Analyzer) alterTableStmtAction(ctx context.Context, query string, args []driver.NamedValue, node *ast.AlterTableStmtNode) (*AlterTableStmtAction, error) {
spec, err := newAlterSpec(ctx, a.namePath, node)
if err != nil {
return nil, err
}
params := getParamsFromNode(node)
queryArgs, err := getArgsFromParams(args, params)
if err != nil {
return nil, err
}
return &AlterTableStmtAction{
query: query,
spec: spec,
node: node,
args: queryArgs,
rawArgs: args,
catalog: a.catalog,
}, nil
}

func (a *Analyzer) newCreateTableStmtAction(_ context.Context, query string, args []driver.NamedValue, node *ast.CreateTableStmtNode) (*CreateTableStmtAction, error) {
spec := newTableSpec(a.namePath, node)
params := getParamsFromNode(node)
Expand Down
36 changes: 36 additions & 0 deletions internal/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,42 @@ func (c *Catalog) addTableSpec(spec *TableSpec) error {
return nil
}

func (c *Catalog) modifyTableSpec(spec *AlterTableSpec) error {
tableName := spec.TableName()
foundSpecToUpdate, exists := c.tableMap[tableName]

if !exists {
return fmt.Errorf("table %s does not exist", tableName)
}

formattedPath := formatPath(spec.NamePath)

err := c.deleteTableSpecByName(formattedPath)
if err != nil {
return err
}

for _, column := range spec.ColumnsWithNewDefaultValue {
if foundSpecToUpdate.Column(column.ColumnName) == nil {
return fmt.Errorf("cannot update column %s to have a default value, table %s does not have this column", tableName, column.ColumnName)
}
}

addedColumns := make([]*ColumnSpec, len(foundSpecToUpdate.Columns))
copy(addedColumns, foundSpecToUpdate.Columns)
addedColumns = append(addedColumns, spec.AddedColumns...)

foundSpecToUpdate.Columns = addedColumns
foundSpecToUpdate.UpdatedAt = spec.UpdatedAt

err = c.addTableSpec(foundSpecToUpdate)
if err != nil {
return err
}

return nil
}

func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpec) error {
if len(spec.NamePath) > 1 {
subCatalogName := spec.NamePath[0]
Expand Down
64 changes: 64 additions & 0 deletions internal/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ type TableSpec struct {
CreatedAt time.Time `json:"createdAt"`
}

type ColumnWithDefaultSpec struct {
ColumnName string
DefaultValue string
}

type AlterTableSpec struct {
NamePath []string `json:"namePath"`
AddedColumns []*ColumnSpec `json:"addedColumns"`
ColumnsWithNewDefaultValue []*ColumnWithDefaultSpec `json:"columnsWithNewDefaultValue"`
UpdatedAt time.Time `json:"updatedAt"`
}

func (s *TableSpec) Column(name string) *ColumnSpec {
for _, col := range s.Columns {
if col.Name == name {
Expand All @@ -123,6 +135,10 @@ func (s *TableSpec) Column(name string) *ColumnSpec {
return nil
}

func (s *AlterTableSpec) TableName() string {
return formatPath(s.NamePath)
}

func (s *TableSpec) TableName() string {
return formatPath(s.NamePath)
}
Expand Down Expand Up @@ -513,6 +529,54 @@ func newPrimaryKey(key *ast.PrimaryKeyNode) []string {
return key.ColumnNameList()
}

func newAlterSpec(ctx context.Context, namePath *NamePath, stmt *ast.AlterTableStmtNode) (*AlterTableSpec, error) {
list := stmt.AlterActionList()
var columns []*ast.ColumnDefinitionNode
var columnsAddDefault []*ColumnWithDefaultSpec

var err error

for i := range list {
action := list[i]
if err != nil {
return nil, err
}
switch action.Kind() {
case ast.AddColumnAction | ast.AlterColumnSetDefaultAction:
err = fmt.Errorf("adding field with default value to an existing table schema is not supported")
case ast.AddColumnAction:
addColumnAction := action.(*ast.AddColumnActionNode)
columns = append(columns, addColumnAction.ColumnDefinition())
case ast.AlterColumnSetDefaultAction:
setDefaultAction := action.(*ast.AlterColumnSetDefaultActionNode)
columnName := setDefaultAction.Column()
defaultValueExpr := setDefaultAction.DefaultValue().Expression()
var defaultValue string
if defaultValueExpr != nil {
// TODO: figure out the timestamp thing here?
defaultValue, err = newNode(defaultValueExpr).FormatSQL(ctx) // assuming newNode has a method to format SQL
if err != nil {
return nil, fmt.Errorf("failed to format default value: %w", err)
}
}
columnsAddDefault = append(columnsAddDefault, &ColumnWithDefaultSpec{
ColumnName: columnName,
DefaultValue: defaultValue,
})
default:
err = fmt.Errorf("unknown alter action kind: %v", action.Kind())
}
}

now := time.Now()
return &AlterTableSpec{
NamePath: namePath.mergePath(stmt.NamePath()),
AddedColumns: newColumnsFromDef(columns),
ColumnsWithNewDefaultValue: columnsAddDefault,
UpdatedAt: now,
}, nil
}

func newTableSpec(namePath *NamePath, stmt *ast.CreateTableStmtNode) *TableSpec {
now := time.Now()
return &TableSpec{
Expand Down
Loading

0 comments on commit 9e293a3

Please sign in to comment.