Skip to content

Defer driver search to Open() #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions pgx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,29 @@ func init() {
}

func TestAdapterPgx_specs(t *testing.T) {
driverName = "pgx"
adapter := MustOpen(dsn())
adapter := MustOpen(dsn(), WithDriver("pgx"))
defer adapter.Close()

repo := rel.New(adapter)
AdapterSpecs(t, repo)
}

func TestAdapterPgx_Transaction_commitError(t *testing.T) {
driverName = "pgx"
adapter := MustOpen(dsn())
adapter := MustOpen(dsn(), WithDriver("pgx"))
defer adapter.Close()

assert.NotNil(t, adapter.Commit(ctx))
}

func TestAdapterPgx_Transaction_rollbackError(t *testing.T) {
driverName = "pgx"
adapter := MustOpen(dsn())
adapter := MustOpen(dsn(), WithDriver("pgx"))
defer adapter.Close()

assert.NotNil(t, adapter.Rollback(ctx))
}

func TestAdapterPgx_Exec_error(t *testing.T) {
driverName = "pgx"
adapter := MustOpen(dsn())
adapter := MustOpen(dsn(), WithDriver("pgx"))
defer adapter.Close()

_, _, err := adapter.Exec(ctx, "error", nil)
Expand All @@ -51,7 +47,6 @@ func TestAdapterPgx_Exec_error(t *testing.T) {

func TestAdapterPgx_InvalidDriverPanic(t *testing.T) {
assert.Panics(t, func() {
driverName = "pgx/v4"
MustOpen("postgres://test:test@localhost:1111/test?sslmode=disable&timezone=Asia/Jakarta")
MustOpen("postgres://test:test@localhost:1111/test?sslmode=disable&timezone=Asia/Jakarta", WithDriver("pgx/v4"))
})
}
41 changes: 26 additions & 15 deletions postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package postgres
import (
"context"
db "database/sql"
"slices"
"time"

"github.com/go-rel/rel"
Expand All @@ -31,8 +32,6 @@ type Postgres struct {
// Name of database type this adapter implements.
const Name string = "postgres"

var driverName string = "postgres"

// New postgres adapter using existing connection.
func New(database *db.DB) rel.Adapter {
var (
Expand Down Expand Up @@ -65,15 +64,37 @@ func New(database *db.DB) rel.Adapter {
}
}

type OpenOpt struct {
driver string
}

func WithDriver(driver string) OpenOpt {
return OpenOpt{driver: driver}
}

// Open postgres connection using dsn.
func Open(dsn string) (rel.Adapter, error) {
func Open(dsn string, opts ...OpenOpt) (rel.Adapter, error) {
// Default to postgres driver
driverName := "postgres"

// Identify if pgx driver is available and default to that instead.
if slices.Contains(db.Drivers(), "pgx") {
driverName = "pgx"
}

for _, opts := range opts {
if opts.driver != "" {
driverName = opts.driver
}
}

database, err := db.Open(driverName, dsn)
return New(database), err
}

// MustOpen postgres connection using dsn.
func MustOpen(dsn string) rel.Adapter {
adapter, err := Open(dsn)
func MustOpen(dsn string, opts ...OpenOpt) rel.Adapter {
adapter, err := Open(dsn, opts...)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -195,13 +216,3 @@ func columnMapper(column *rel.Column) (string, int, int) {

return typ, m, n
}

func init() {
// Identify if pgx driver is available and default to that instead.
for _, drv := range db.Drivers() {
if drv == "pgx" {
driverName = "pgx"
break
}
}
}
19 changes: 6 additions & 13 deletions postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ func dsn() string {
}

func TestAdapter_Name(t *testing.T) {
driverName = "postgres"
adapter := MustOpen(dsn())
adapter := MustOpen(dsn(), WithDriver("postgres"))
defer adapter.Close()

assert.Equal(t, Name, adapter.Name())
Expand Down Expand Up @@ -115,8 +114,7 @@ func TestAdapter_specs(t *testing.T) {
return
}

driverName = "postgres"
adapter := MustOpen(dsn())
adapter := MustOpen(dsn(), WithDriver("postgres"))
defer adapter.Close()

repo := rel.New(adapter)
Expand All @@ -129,7 +127,6 @@ func TestAdapter_PrimaryReplica_specs(t *testing.T) {
return
}

driverName = "postgres"
adapter := primaryreplica.New(
MustOpen("postgres://rel:rel@localhost:25432/rel_test?sslmode=disable&timezone=Asia/Jakarta"),
MustOpen("postgres://rel:rel@localhost:25433/rel_test?sslmode=disable&timezone=Asia/Jakarta"),
Expand All @@ -142,24 +139,21 @@ func TestAdapter_PrimaryReplica_specs(t *testing.T) {
}

func TestAdapter_Transaction_commitError(t *testing.T) {
driverName = "postgres"
adapter := MustOpen(dsn())
adapter := MustOpen(dsn(), WithDriver("postgres"))
defer adapter.Close()

assert.NotNil(t, adapter.Commit(ctx))
}

func TestAdapter_Transaction_rollbackError(t *testing.T) {
driverName = "postgres"
adapter := MustOpen(dsn())
adapter := MustOpen(dsn(), WithDriver("postgres"))
defer adapter.Close()

assert.NotNil(t, adapter.Rollback(ctx))
}

func TestAdapter_Exec_error(t *testing.T) {
driverName = "postgres"
adapter, err := Open(dsn())
adapter, err := Open(dsn(), WithDriver("postgres"))
assert.Nil(t, err)
defer adapter.Close()

Expand All @@ -168,8 +162,7 @@ func TestAdapter_Exec_error(t *testing.T) {
}

func TestAdapter_TableBuilder(t *testing.T) {
driverName = "postgres"
adapter := MustOpen(dsn())
adapter := MustOpen(dsn(), WithDriver("postgres"))
defer adapter.Close()

tests := []struct {
Expand Down
Loading