diff --git a/provider/postgres/postgres.go b/provider/postgres/postgres.go index 2be4f46..9a40a94 100644 --- a/provider/postgres/postgres.go +++ b/provider/postgres/postgres.go @@ -4,8 +4,9 @@ import ( "context" "database/sql" "fmt" - "github.com/iamsalnikov/mymigrate/provider" "time" + + "github.com/iamsalnikov/mymigrate/provider" ) // Provider - migration provider for postgres db @@ -58,14 +59,14 @@ func (p *Provider) GetApplied(ctx context.Context) ([]string, error) { // MarkApplied - function for mark migration applied func (p *Provider) MarkApplied(ctx context.Context, name string, t time.Time) error { - query := fmt.Sprintf("INSERT INTO %s (name, time) VALUES (?, ?)", provider.DefaultTableName) + query := fmt.Sprintf("INSERT INTO %s (name, time) VALUES ($1, $2)", provider.DefaultTableName) _, err := p.db.ExecContext(ctx, query, name, t) return err } // DeleteApplied - function for delete migration from applied list func (p *Provider) DeleteApplied(ctx context.Context, name string) error { - query := fmt.Sprintf("DELETE FROM %s WHERE name=?", provider.DefaultTableName) + query := fmt.Sprintf("DELETE FROM %s WHERE name=$1", provider.DefaultTableName) _, err := p.db.ExecContext(ctx, query, name) return err } diff --git a/provider/postgres/postgres_test.go b/provider/postgres/postgres_test.go index 7a33ce9..7ddd86b 100644 --- a/provider/postgres/postgres_test.go +++ b/provider/postgres/postgres_test.go @@ -4,12 +4,13 @@ import ( "context" "errors" "fmt" + "testing" + "time" + "github.com/DATA-DOG/go-sqlmock" "github.com/iamsalnikov/mymigrate/provider" "github.com/iamsalnikov/mymigrate/provider/postgres" "github.com/stretchr/testify/assert" - "testing" - "time" ) func TestPsqlProvider_CreateMigrationsTable(t *testing.T) { @@ -126,7 +127,7 @@ func TestPsqlProvider_MarkApplied(t *testing.T) { name: "migration_1", time: now, execError: nil, - expectQuery: fmt.Sprintf("INSERT INTO %s (name, time) VALUES (?, ?)", provider.DefaultTableName), + expectQuery: fmt.Sprintf("INSERT INTO %s (name, time) VALUES ($1, $2)", provider.DefaultTableName), expectArgs: []interface{}{"migration_1", now}, expectErr: nil, }, @@ -134,7 +135,7 @@ func TestPsqlProvider_MarkApplied(t *testing.T) { name: "migration_2", time: now, execError: errors.New("some db error"), - expectQuery: fmt.Sprintf("INSERT INTO %s (name, time) VALUES (?, ?)", provider.DefaultTableName), + expectQuery: fmt.Sprintf("INSERT INTO %s (name, time) VALUES ($1, $2)", provider.DefaultTableName), expectArgs: []interface{}{"migration_2", now}, expectErr: errors.New("some db error"), }, @@ -173,14 +174,14 @@ func TestPsqlProvider_DeleteApplied(t *testing.T) { "all is ok": { name: "migration_1", execError: nil, - expectQuery: fmt.Sprintf("DELETE FROM %s WHERE name=?", provider.DefaultTableName), + expectQuery: fmt.Sprintf("DELETE FROM %s WHERE name=$1", provider.DefaultTableName), expectArgs: []interface{}{"migration_1"}, expectErr: nil, }, "db error": { name: "migration_2", execError: errors.New("some db error"), - expectQuery: fmt.Sprintf("DELETE FROM %s WHERE name=?", provider.DefaultTableName), + expectQuery: fmt.Sprintf("DELETE FROM %s WHERE name=$1", provider.DefaultTableName), expectArgs: []interface{}{"migration_2"}, expectErr: errors.New("some db error"), },