Skip to content

Commit

Permalink
fix(reset): ensure _supabase connections disconnect before reset (#2904)
Browse files Browse the repository at this point in the history
* fix(reset): ensure _supabase connections disconnect before reset

Closes #2903

* fix: switch tests

* fix: refactor reset for multiples db

* chore: refactor terminate backend query

---------

Co-authored-by: Qiao Han <[email protected]>
  • Loading branch information
avallete and sweatybridge authored Nov 25, 2024
1 parent e3f8f34 commit b9e89fc
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 50 deletions.
38 changes: 26 additions & 12 deletions internal/db/branch/switch_/switch__test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package switch_

import (
"context"
"fmt"
"net/http"
"os"
"path/filepath"
Expand All @@ -14,6 +13,7 @@ import (
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/supabase/cli/internal/db/reset"
"github.com/supabase/cli/internal/testing/apitest"
"github.com/supabase/cli/internal/utils"
"github.com/supabase/cli/pkg/pgtest"
Expand Down Expand Up @@ -42,10 +42,14 @@ func TestSwitchCommand(t *testing.T) {
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;").
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query(fmt.Sprintf(utils.TerminateDbSqlFmt, "postgres")).
Reply("DO").
Query("ALTER DATABASE _supabase ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query(reset.TERMINATE_BACKENDS).
Reply("SELECT 1").
Query(reset.COUNT_REPLICATION_SLOTS).
Reply("SELECT 1", []interface{}{0}).
Query("ALTER DATABASE postgres RENAME TO main;").
Reply("ALTER DATABASE").
Query("ALTER DATABASE " + branch + " RENAME TO postgres;").
Expand Down Expand Up @@ -218,8 +222,10 @@ func TestSwitchDatabase(t *testing.T) {
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;").
ReplyError(pgerrcode.InvalidParameterValue, `cannot disallow connections for current database`)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false").
ReplyError(pgerrcode.InvalidParameterValue, `cannot disallow connections for current database`).
Query("ALTER DATABASE _supabase ALLOW_CONNECTIONS false").
Query(reset.TERMINATE_BACKENDS)
// Run test
err := switchDatabase(context.Background(), "main", "target", conn.Intercept)
// Check error
Expand All @@ -234,10 +240,14 @@ func TestSwitchDatabase(t *testing.T) {
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;").
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query("ALTER DATABASE _supabase ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query(fmt.Sprintf(utils.TerminateDbSqlFmt, "postgres")).
Reply("DO").
Query(reset.TERMINATE_BACKENDS).
Reply("SELECT 1").
Query(reset.COUNT_REPLICATION_SLOTS).
Reply("SELECT 1", []interface{}{0}).
Query("ALTER DATABASE postgres RENAME TO main;").
ReplyError(pgerrcode.DuplicateDatabase, `database "main" already exists`)
// Setup mock docker
Expand All @@ -260,10 +270,14 @@ func TestSwitchDatabase(t *testing.T) {
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;").
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query("ALTER DATABASE _supabase ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query(fmt.Sprintf(utils.TerminateDbSqlFmt, "postgres")).
Reply("DO").
Query(reset.TERMINATE_BACKENDS).
Reply("SELECT 1").
Query(reset.COUNT_REPLICATION_SLOTS).
Reply("SELECT 1", []interface{}{0}).
Query("ALTER DATABASE postgres RENAME TO main;").
Reply("ALTER DATABASE").
Query("ALTER DATABASE target RENAME TO postgres;").
Expand Down
4 changes: 2 additions & 2 deletions internal/db/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ func CreateShadowDatabase(ctx context.Context, port uint16) (string, error) {

func ConnectShadowDatabase(ctx context.Context, timeout time.Duration, options ...func(*pgx.ConnConfig)) (conn *pgx.Conn, err error) {
// Retry until connected, cancelled, or timeout
policy := backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Second), uint64(timeout.Seconds()))
policy := start.NewBackoffPolicy(ctx, timeout)
config := pgconn.Config{Port: utils.Config.Db.ShadowPort}
connect := func() (*pgx.Conn, error) {
return utils.ConnectLocalPostgres(ctx, config, options...)
}
return backoff.RetryWithData(connect, backoff.WithContext(policy, ctx))
return backoff.RetryWithData(connect, policy)
}

func MigrateShadowDatabase(ctx context.Context, container string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
Expand Down
35 changes: 28 additions & 7 deletions internal/db/reset/reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/network"
Expand Down Expand Up @@ -164,20 +165,40 @@ func recreateDatabase(ctx context.Context, options ...func(*pgx.ConnConfig)) err
return sql.ExecBatch(ctx, conn)
}

const (
TERMINATE_BACKENDS = "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname IN ('postgres', '_supabase')"
COUNT_REPLICATION_SLOTS = "SELECT COUNT(*) FROM pg_replication_slots WHERE database IN ('postgres', '_supabase')"
)

func DisconnectClients(ctx context.Context, conn *pgx.Conn) error {
// Must be executed separately because running in transaction is unsupported
disconn := "ALTER DATABASE postgres ALLOW_CONNECTIONS false;"
if _, err := conn.Exec(ctx, disconn); err != nil {
// Must be executed separately because looping in transaction is unsupported
// https://dba.stackexchange.com/a/11895
disconn := migration.MigrationFile{
Statements: []string{
"ALTER DATABASE postgres ALLOW_CONNECTIONS false",
"ALTER DATABASE _supabase ALLOW_CONNECTIONS false",
TERMINATE_BACKENDS,
},
}
if err := disconn.ExecBatch(ctx, conn); err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code != pgerrcode.InvalidCatalogName {
return errors.Errorf("failed to disconnect clients: %w", err)
}
}
term := fmt.Sprintf(utils.TerminateDbSqlFmt, "postgres")
if _, err := conn.Exec(ctx, term); err != nil {
return errors.Errorf("failed to terminate backend: %w", err)
// Wait for WAL senders to drop their replication slots
policy := start.NewBackoffPolicy(ctx, 10*time.Second)
waitForDrop := func() error {
var count int
if err := conn.QueryRow(ctx, COUNT_REPLICATION_SLOTS).Scan(&count); err != nil {
err = errors.Errorf("failed to count replication slots: %w", err)
return &backoff.PermanentError{Err: err}
} else if count > 0 {
return errors.Errorf("replication slots still active: %d", count)
}
return nil
}
return nil
return backoff.Retry(waitForDrop, policy)
}

func RestartDatabase(ctx context.Context, w io.Writer) error {
Expand Down
40 changes: 26 additions & 14 deletions internal/db/reset/reset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package reset
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"path/filepath"
Expand Down Expand Up @@ -202,10 +201,14 @@ func TestRecreateDatabase(t *testing.T) {
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;").
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query(fmt.Sprintf(utils.TerminateDbSqlFmt, "postgres")).
Reply("DO").
Query("ALTER DATABASE _supabase ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query(TERMINATE_BACKENDS).
Reply("SELECT 1").
Query(COUNT_REPLICATION_SLOTS).
Reply("SELECT 1", []interface{}{0}).
Query("DROP DATABASE IF EXISTS postgres WITH (FORCE)").
Reply("DROP DATABASE").
Query("CREATE DATABASE postgres WITH OWNER postgres").
Expand All @@ -228,23 +231,28 @@ func TestRecreateDatabase(t *testing.T) {
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;").
ReplyError(pgerrcode.InvalidCatalogName, `database "postgres" does not exist`).
Query(fmt.Sprintf(utils.TerminateDbSqlFmt, "postgres")).
ReplyError(pgerrcode.UndefinedTable, `relation "pg_stat_activity" does not exist`)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query("ALTER DATABASE _supabase ALLOW_CONNECTIONS false").
ReplyError(pgerrcode.InvalidCatalogName, `database "_supabase" does not exist`).
Query(TERMINATE_BACKENDS).
Query(COUNT_REPLICATION_SLOTS).
ReplyError(pgerrcode.UndefinedTable, `relation "pg_replication_slots" does not exist`)
// Run test
err := recreateDatabase(context.Background(), conn.Intercept)
// Check error
assert.ErrorContains(t, err, `ERROR: relation "pg_stat_activity" does not exist (SQLSTATE 42P01)`)
assert.ErrorContains(t, err, `ERROR: relation "pg_replication_slots" does not exist (SQLSTATE 42P01)`)
})

t.Run("throws error on failure to disconnect", func(t *testing.T) {
utils.Config.Db.Port = 54322
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;").
ReplyError(pgerrcode.InvalidParameterValue, `cannot disallow connections for current database`)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false").
ReplyError(pgerrcode.InvalidParameterValue, `cannot disallow connections for current database`).
Query("ALTER DATABASE _supabase ALLOW_CONNECTIONS false").
Query(TERMINATE_BACKENDS)
// Run test
err := recreateDatabase(context.Background(), conn.Intercept)
// Check error
Expand All @@ -256,10 +264,14 @@ func TestRecreateDatabase(t *testing.T) {
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;").
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query("ALTER DATABASE _supabase ALLOW_CONNECTIONS false").
Reply("ALTER DATABASE").
Query(fmt.Sprintf(utils.TerminateDbSqlFmt, "postgres")).
Reply("DO").
Query(TERMINATE_BACKENDS).
Reply("SELECT 1").
Query(COUNT_REPLICATION_SLOTS).
Reply("SELECT 1", []interface{}{0}).
Query("DROP DATABASE IF EXISTS postgres WITH (FORCE)").
ReplyError(pgerrcode.ObjectInUse, `database "postgres" is used by an active logical replication slot`).
Query("CREATE DATABASE postgres WITH OWNER postgres").
Expand Down
13 changes: 9 additions & 4 deletions internal/db/start/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ EOF`}
return initCurrentBranch(fsys)
}

func NewBackoffPolicy(ctx context.Context, timeout time.Duration) backoff.BackOff {
policy := backoff.WithMaxRetries(
backoff.NewConstantBackOff(time.Second),
uint64(timeout.Seconds()),
)
return backoff.WithContext(policy, ctx)
}

func WaitForHealthyService(ctx context.Context, timeout time.Duration, started ...string) error {
probe := func() error {
var errHealth []error
Expand All @@ -173,10 +181,7 @@ func WaitForHealthyService(ctx context.Context, timeout time.Duration, started .
started = unhealthy
return errors.Join(errHealth...)
}
policy := backoff.WithContext(backoff.WithMaxRetries(
backoff.NewConstantBackOff(time.Second),
uint64(timeout.Seconds()),
), ctx)
policy := NewBackoffPolicy(ctx, timeout)
err := backoff.Retry(probe, policy)
if err != nil && !errors.Is(err, context.Canceled) {
// Print container logs for easier debugging
Expand Down
12 changes: 1 addition & 11 deletions internal/utils/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,7 @@ func ShortContainerImageName(imageName string) string {
return matches[1]
}

const (
// https://dba.stackexchange.com/a/11895
// Args: dbname
TerminateDbSqlFmt = `
SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '%[1]s';
-- Wait for WAL sender to drop replication slot.
DO 'BEGIN WHILE (
SELECT COUNT(*) FROM pg_replication_slots WHERE database = ''%[1]s''
) > 0 LOOP END LOOP; END';`
SuggestDebugFlag = "Try rerunning the command with --debug to troubleshoot the error."
)
const SuggestDebugFlag = "Try rerunning the command with --debug to troubleshoot the error."

var (
CmdSuggestion string
Expand Down

0 comments on commit b9e89fc

Please sign in to comment.