Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dialect refactoring #914

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions database/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,28 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/pressly/goose/v3/internal/dialect"

"github.com/pressly/goose/v3/internal/dialect/dialectquery"
)

// Dialect is the type of database dialect.
type Dialect string
type Dialect = dialect.Dialect

var ErrUnknownDialect = dialect.ErrUnknownDialect

const (
DialectClickHouse Dialect = "clickhouse"
DialectMSSQL Dialect = "mssql"
DialectMySQL Dialect = "mysql"
DialectPostgres Dialect = "postgres"
DialectRedshift Dialect = "redshift"
DialectSQLite3 Dialect = "sqlite3"
DialectTiDB Dialect = "tidb"
DialectTurso Dialect = "turso"
DialectVertica Dialect = "vertica"
DialectYdB Dialect = "ydb"
DialectStarrocks Dialect = "starrocks"
DialectClickHouse Dialect = dialect.Clickhouse
DialectMSSQL Dialect = dialect.Mssql
DialectMySQL Dialect = dialect.Mysql
DialectPostgres Dialect = dialect.Postgres
DialectRedshift Dialect = dialect.Redshift
DialectSQLite3 Dialect = dialect.Sqlite3
DialectTiDB Dialect = dialect.Tidb
DialectTurso Dialect = dialect.Turso
DialectVertica Dialect = dialect.Vertica
DialectYdB Dialect = dialect.Ydb
DialectStarrocks Dialect = dialect.Starrocks
)

// NewStore returns a new [Store] implementation for the given dialect.
Expand All @@ -49,7 +52,7 @@ func NewStore(dialect Dialect, tablename string) (Store, error) {
}
querier, ok := lookup[dialect]
if !ok {
return nil, fmt.Errorf("unknown dialect: %q", dialect)
return nil, fmt.Errorf("%s: %w", dialect, ErrUnknownDialect)
}
return &store{
tablename: tablename,
Expand Down
78 changes: 34 additions & 44 deletions dialect.go
Original file line number Diff line number Diff line change
@@ -1,64 +1,54 @@
package goose

import (
"fmt"

"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/dialect"
)

// Dialect is the type of database dialect. It is an alias for [database.Dialect].
type Dialect = database.Dialect
// Dialect is the type of database dialect. It is an alias for [dialect.Dialect].
type Dialect = dialect.Dialect

const (
DialectClickHouse Dialect = database.DialectClickHouse
DialectMSSQL Dialect = database.DialectMSSQL
DialectMySQL Dialect = database.DialectMySQL
DialectPostgres Dialect = database.DialectPostgres
DialectRedshift Dialect = database.DialectRedshift
DialectSQLite3 Dialect = database.DialectSQLite3
DialectTiDB Dialect = database.DialectTiDB
DialectVertica Dialect = database.DialectVertica
DialectYdB Dialect = database.DialectYdB
DialectStarrocks Dialect = database.DialectStarrocks
DialectClickHouse Dialect = dialect.Clickhouse
DialectMSSQL Dialect = dialect.Mssql
DialectMySQL Dialect = dialect.Mysql
DialectPostgres Dialect = dialect.Postgres
DialectRedshift Dialect = dialect.Redshift
DialectSQLite3 Dialect = dialect.Sqlite3
DialectTiDB Dialect = dialect.Tidb
DialectVertica Dialect = dialect.Vertica
DialectYdB Dialect = dialect.Ydb
DialectTurso Dialect = dialect.Turso
DialectStarrocks Dialect = dialect.Starrocks
)

var ErrUnknownDialect = dialect.ErrUnknownDialect

// GetDialect gets the dialect used in the goose package.
var GetDialect = dialect.GetDialect

func init() {
store, _ = dialect.NewStore(dialect.Postgres)
}

var store dialect.Store

// SetDialect sets the dialect to use for the goose package.
func SetDialect(s string) error {
var d dialect.Dialect
switch s {
case "postgres", "pgx":
d = dialect.Postgres
case "mysql":
d = dialect.Mysql
case "sqlite3", "sqlite":
d = dialect.Sqlite3
case "mssql", "azuresql", "sqlserver":
d = dialect.Sqlserver
case "redshift":
d = dialect.Redshift
case "tidb":
d = dialect.Tidb
case "clickhouse":
d = dialect.Clickhouse
case "vertica":
d = dialect.Vertica
case "ydb":
d = dialect.Ydb
case "turso":
d = dialect.Turso
case "starrocks":
d = dialect.Starrocks
default:
return fmt.Errorf("%q: unknown dialect", s)
func SetDialect[D string | Dialect](d D) error {
var (
v Dialect
err error
)

switch t := any(d).(type) {
case string:
v, err = GetDialect(t)
if err != nil {
return err
}
case Dialect:
v = t
}
var err error
store, err = dialect.NewStore(d)

store, err = dialect.NewStore(v)
return err
}
41 changes: 41 additions & 0 deletions dialect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package goose_test

import (
"github.com/pressly/goose/v3"
"github.com/stretchr/testify/require"
"testing"
)

func TestGetDialect(t *testing.T) {
tests := []struct {
name string
want goose.Dialect
}{
{
name: "postgres",
want: goose.DialectPostgres,
},
{
name: "mysql",
want: goose.DialectMySQL,
},
{
name: "MySQL",
want: goose.DialectMySQL,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
dialect, err := goose.GetDialect(test.name)
require.NoError(t, err)
require.Equal(t, test.want, dialect)
})
}
}

func TestGetDialectFail(t *testing.T) {
dialect, err := goose.GetDialect("fail")
require.Empty(t, dialect)
require.ErrorIs(t, err, goose.ErrUnknownDialect)
require.EqualError(t, err, "fail: unknown dialect")
}
57 changes: 54 additions & 3 deletions internal/dialect/dialects.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
package dialect

import (
"errors"
"fmt"
"strings"
)

// Dialect is the type of database dialect.
type Dialect string

var ErrUnknownDialect = errors.New("unknown dialect")

const (
Postgres Dialect = "postgres"
Mysql Dialect = "mysql"
Sqlite3 Dialect = "sqlite3"
Postgres Dialect = "postgres"
Mysql Dialect = "mysql"
Sqlite3 Dialect = "sqlite3"
Mssql Dialect = "mssql"
// Deprecated: use [Mssql]
Sqlserver Dialect = "sqlserver"
Redshift Dialect = "redshift"
Tidb Dialect = "tidb"
Expand All @@ -16,3 +26,44 @@ const (
Turso Dialect = "turso"
Starrocks Dialect = "starrocks"
)

// GetDialect gets the dialect used in the goose package.
func GetDialect(s string) (Dialect, error) {
switch strings.ToLower(s) {
case "postgres", "pgx":
return Postgres, nil
case "mysql":
return Mysql, nil
case "sqlite3", "sqlite":
return Sqlite3, nil
case "mssql", "azuresql", "sqlserver":
return Mssql, nil
case "redshift":
return Redshift, nil
case "tidb":
return Tidb, nil
case "clickhouse":
return Clickhouse, nil
case "vertica":
return Vertica, nil
case "ydb":
return Ydb, nil
case "turso":
return Turso, nil
case "starrocks":
return Starrocks, nil
default:
return "", fmt.Errorf("%s: %w", s, ErrUnknownDialect)
}
}

func (d *Dialect) UnmarshalText(text []byte) error {
dialect, err := GetDialect(string(text))
if err != nil {
return err
}

*d = dialect

return nil
}
58 changes: 58 additions & 0 deletions internal/dialect/dialects_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package dialect_test

import (
"github.com/pressly/goose/v3/internal/dialect"
"github.com/stretchr/testify/require"
"testing"
)

var _testUnmarshalData = []struct {
name string
want dialect.Dialect
}{
{
name: "postgres",
want: dialect.Postgres,
},
{
name: "mysql",
want: dialect.Mysql,
},
{
name: "MySQL",
want: dialect.Mysql,
},
}

func TestDialect_GetDialect(t *testing.T) {
for _, test := range _testUnmarshalData {
t.Run(test.name, func(t *testing.T) {
d, err := dialect.GetDialect(test.name)
require.NoError(t, err)
require.Equal(t, test.want, d)
})
}
}

func TestDialect_GetDialectFail(t *testing.T) {
d, err := dialect.GetDialect("fail")
require.Empty(t, d)
require.ErrorIs(t, err, dialect.ErrUnknownDialect)
require.EqualError(t, err, "fail: unknown dialect")
}

func TestDialect_UnmarshalText(t *testing.T) {
for _, test := range _testUnmarshalData {
t.Run(test.name, func(t *testing.T) {
var d dialect.Dialect
require.NoError(t, d.UnmarshalText([]byte(test.name)))
})
}
}

func TestDialect_UnmarshalTextFail(t *testing.T) {
var d dialect.Dialect
var err = d.UnmarshalText([]byte("fail"))
require.ErrorIs(t, err, dialect.ErrUnknownDialect)
require.EqualError(t, err, "fail: unknown dialect")
}
2 changes: 1 addition & 1 deletion internal/dialect/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func NewStore(d Dialect) (Store, error) {
querier = &dialectquery.Mysql{}
case Sqlite3:
querier = &dialectquery.Sqlite3{}
case Sqlserver:
case Mssql, Sqlserver:
querier = &dialectquery.Sqlserver{}
case Redshift:
querier = &dialectquery.Redshift{}
Expand Down
Loading