diff --git a/pgx_test.go b/pgx_test.go index b9a85a8..68f6cf2 100644 --- a/pgx_test.go +++ b/pgx_test.go @@ -16,8 +16,7 @@ func init() { } func TestAdapterPgx_specs(t *testing.T) { - driverName = "pgx" - adapter := MustOpen(dsn()) + adapter := MustOpen(dsn(), WithDriver("pgx")) defer adapter.Close() repo := rel.New(adapter) @@ -25,24 +24,21 @@ func TestAdapterPgx_specs(t *testing.T) { } func TestAdapterPgx_Transaction_commitError(t *testing.T) { - driverName = "pgx" - adapter := MustOpen(dsn()) + adapter := MustOpen(dsn(), WithDriver("pgx")) defer adapter.Close() assert.NotNil(t, adapter.Commit(ctx)) } func TestAdapterPgx_Transaction_rollbackError(t *testing.T) { - driverName = "pgx" - adapter := MustOpen(dsn()) + adapter := MustOpen(dsn(), WithDriver("pgx")) defer adapter.Close() assert.NotNil(t, adapter.Rollback(ctx)) } func TestAdapterPgx_Exec_error(t *testing.T) { - driverName = "pgx" - adapter := MustOpen(dsn()) + adapter := MustOpen(dsn(), WithDriver("pgx")) defer adapter.Close() _, _, err := adapter.Exec(ctx, "error", nil) @@ -51,7 +47,6 @@ func TestAdapterPgx_Exec_error(t *testing.T) { func TestAdapterPgx_InvalidDriverPanic(t *testing.T) { assert.Panics(t, func() { - driverName = "pgx/v4" - MustOpen("postgres://test:test@localhost:1111/test?sslmode=disable&timezone=Asia/Jakarta") + MustOpen("postgres://test:test@localhost:1111/test?sslmode=disable&timezone=Asia/Jakarta", WithDriver("pgx/v4")) }) } diff --git a/postgres.go b/postgres.go index bbb28f9..f59565f 100644 --- a/postgres.go +++ b/postgres.go @@ -16,6 +16,7 @@ package postgres import ( "context" db "database/sql" + "slices" "time" "github.com/go-rel/rel" @@ -31,8 +32,6 @@ type Postgres struct { // Name of database type this adapter implements. const Name string = "postgres" -var driverName string = "postgres" - // New postgres adapter using existing connection. func New(database *db.DB) rel.Adapter { var ( @@ -65,15 +64,37 @@ func New(database *db.DB) rel.Adapter { } } +type OpenOpt struct { + driver string +} + +func WithDriver(driver string) OpenOpt { + return OpenOpt{driver: driver} +} + // Open postgres connection using dsn. -func Open(dsn string) (rel.Adapter, error) { +func Open(dsn string, opts ...OpenOpt) (rel.Adapter, error) { + // Default to postgres driver + driverName := "postgres" + + // Identify if pgx driver is available and default to that instead. + if slices.Contains(db.Drivers(), "pgx") { + driverName = "pgx" + } + + for _, opts := range opts { + if opts.driver != "" { + driverName = opts.driver + } + } + database, err := db.Open(driverName, dsn) return New(database), err } // MustOpen postgres connection using dsn. -func MustOpen(dsn string) rel.Adapter { - adapter, err := Open(dsn) +func MustOpen(dsn string, opts ...OpenOpt) rel.Adapter { + adapter, err := Open(dsn, opts...) if err != nil { panic(err) } @@ -195,13 +216,3 @@ func columnMapper(column *rel.Column) (string, int, int) { return typ, m, n } - -func init() { - // Identify if pgx driver is available and default to that instead. - for _, drv := range db.Drivers() { - if drv == "pgx" { - driverName = "pgx" - break - } - } -} diff --git a/postgres_test.go b/postgres_test.go index 371e250..d15768b 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -29,8 +29,7 @@ func dsn() string { } func TestAdapter_Name(t *testing.T) { - driverName = "postgres" - adapter := MustOpen(dsn()) + adapter := MustOpen(dsn(), WithDriver("postgres")) defer adapter.Close() assert.Equal(t, Name, adapter.Name()) @@ -115,8 +114,7 @@ func TestAdapter_specs(t *testing.T) { return } - driverName = "postgres" - adapter := MustOpen(dsn()) + adapter := MustOpen(dsn(), WithDriver("postgres")) defer adapter.Close() repo := rel.New(adapter) @@ -129,7 +127,6 @@ func TestAdapter_PrimaryReplica_specs(t *testing.T) { return } - driverName = "postgres" adapter := primaryreplica.New( MustOpen("postgres://rel:rel@localhost:25432/rel_test?sslmode=disable&timezone=Asia/Jakarta"), MustOpen("postgres://rel:rel@localhost:25433/rel_test?sslmode=disable&timezone=Asia/Jakarta"), @@ -142,24 +139,21 @@ func TestAdapter_PrimaryReplica_specs(t *testing.T) { } func TestAdapter_Transaction_commitError(t *testing.T) { - driverName = "postgres" - adapter := MustOpen(dsn()) + adapter := MustOpen(dsn(), WithDriver("postgres")) defer adapter.Close() assert.NotNil(t, adapter.Commit(ctx)) } func TestAdapter_Transaction_rollbackError(t *testing.T) { - driverName = "postgres" - adapter := MustOpen(dsn()) + adapter := MustOpen(dsn(), WithDriver("postgres")) defer adapter.Close() assert.NotNil(t, adapter.Rollback(ctx)) } func TestAdapter_Exec_error(t *testing.T) { - driverName = "postgres" - adapter, err := Open(dsn()) + adapter, err := Open(dsn(), WithDriver("postgres")) assert.Nil(t, err) defer adapter.Close() @@ -168,8 +162,7 @@ func TestAdapter_Exec_error(t *testing.T) { } func TestAdapter_TableBuilder(t *testing.T) { - driverName = "postgres" - adapter := MustOpen(dsn()) + adapter := MustOpen(dsn(), WithDriver("postgres")) defer adapter.Close() tests := []struct {