diff --git a/internal/migration_acceptance_tests/acceptance_test.go b/internal/migration_acceptance_tests/acceptance_test.go index 0814d13..18ecdfc 100644 --- a/internal/migration_acceptance_tests/acceptance_test.go +++ b/internal/migration_acceptance_tests/acceptance_test.go @@ -39,6 +39,8 @@ type ( empty bool } + planFactory func(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (diff.Plan, error) + acceptanceTestCase struct { name string oldSchemaDDL []string @@ -54,11 +56,13 @@ type ( // vanillaExpectations refers to the expectations of the migration if no additional opts are used vanillaExpectations expectations - // dataPackingExpectations refers to the expectations of the migration if table packing is used + // dataPackingExpectations refers to the expectations of the migration if table packing is used. We should + // aim to deprecate this and just split out a separate set of individual tests for data packing. dataPackingExpectations expectations - // use old generate plan func - useOldGeneratePlan bool + // planFactory is used to generate the actual plan. This is useful for testing different plan generation paths + // outside of the normal path. If not specified, a plan will be generated using a default. + planFactory planFactory } acceptanceTestSuite struct { @@ -124,8 +128,8 @@ func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expe suite.Require().NoError(tempDbFactory.Close()) }(tempDbFactory) - generatePlanFn := diff.GeneratePlan - if !tc.useOldGeneratePlan { + generatePlanFn := tc.planFactory + if generatePlanFn == nil { generatePlanFn = func(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (diff.Plan, error) { return diff.Generate(ctx, connPool, diff.DDLSchemaSource(newSchemaDDL), append(planOpts, diff --git a/internal/migration_acceptance_tests/backwards_compat_test.go b/internal/migration_acceptance_tests/backwards_compat_cases_test.go similarity index 99% rename from internal/migration_acceptance_tests/backwards_compat_test.go rename to internal/migration_acceptance_tests/backwards_compat_cases_test.go index 2741a52..b5d8511 100644 --- a/internal/migration_acceptance_tests/backwards_compat_test.go +++ b/internal/migration_acceptance_tests/backwards_compat_cases_test.go @@ -173,7 +173,7 @@ var backCompatAcceptanceTestCases = []acceptanceTestCase{ }, // Ensure that we're maintaining backwards compatibility with the old generate plan func - useOldGeneratePlan: true, + planFactory: diff.GeneratePlan, }, } diff --git a/internal/migration_acceptance_tests/database_schema_source_cases_test.go b/internal/migration_acceptance_tests/database_schema_source_cases_test.go new file mode 100644 index 0000000..7346d5c --- /dev/null +++ b/internal/migration_acceptance_tests/database_schema_source_cases_test.go @@ -0,0 +1,130 @@ +package migration_acceptance_tests + +import ( + "context" + "fmt" + + "github.com/stripe/pg-schema-diff/pkg/diff" + "github.com/stripe/pg-schema-diff/pkg/sqldb" + "github.com/stripe/pg-schema-diff/pkg/tempdb" +) + +func databaseSchemaSourcePlanFactory(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (_ diff.Plan, retErr error) { + newSchemaDb, err := tempDbFactory.Create(ctx) + if err != nil { + return diff.Plan{}, fmt.Errorf("creating temp database: %w", err) + } + + defer func() { + tempDbErr := newSchemaDb.Close(ctx) + if retErr == nil { + retErr = tempDbErr + } + }() + + for _, stmt := range newSchemaDDL { + if _, err := newSchemaDb.ConnPool.ExecContext(ctx, stmt); err != nil { + return diff.Plan{}, fmt.Errorf("running DDL: %w", err) + } + } + + // Clone the opts so we don't modify the original. + opts = append([]diff.PlanOpt(nil), opts...) + opts = append(opts, diff.WithTempDbFactory(tempDbFactory)) + for _, o := range newSchemaDb.ExcludeMetadatOptions { + opts = append(opts, diff.WithGetSchemaOpts(o)) + } + + return diff.Generate(ctx, connPool, diff.DBSchemaSource(newSchemaDb.ConnPool), opts...) +} + +var databaseSchemaSourceTestCases = []acceptanceTestCase{ + { + name: "Drop partitioned table, Add partitioned table with local keys", + oldSchemaDDL: []string{ + ` + CREATE TABLE fizz(); + + CREATE TABLE foobar( + id INT, + bar SERIAL NOT NULL, + foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0), + fizz TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (foo, id), + UNIQUE (foo, bar) + ) PARTITION BY LIST(foo); + + CREATE TABLE foobar_1 PARTITION of foobar( + fizz NOT NULL + ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2'); + + -- partitioned indexes + CREATE INDEX foobar_normal_idx ON foobar(foo, bar); + CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, fizz); + -- local indexes + CREATE INDEX foobar_1_local_idx ON foobar_1(foo, bar); + + CREATE table bar( + id VARCHAR(255) PRIMARY KEY, + foo VARCHAR(255), + bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, + fizz TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + buzz REAL NOT NULL CHECK (buzz IS NOT NULL), + FOREIGN KEY (foo, fizz) REFERENCES foobar (foo, fizz) + ); + CREATE INDEX bar_normal_idx ON bar(bar); + CREATE INDEX bar_another_normal_id ON bar(bar, fizz); + CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE fizz(); + + CREATE SCHEMA schema_1; + CREATE TABLE schema_1.foobar( + bar TIMESTAMPTZ NOT NULL, + fizz TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + id INT, + foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0), + UNIQUE (foo, bar) + ) PARTITION BY LIST(foo); + + CREATE TABLE schema_1.foobar_1 PARTITION of schema_1.foobar( + fizz NOT NULL, + PRIMARY KEY (foo, bar) + ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2'); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON schema_1.foobar_1(foo, bar); + -- partitioned indexes + CREATE INDEX foobar_normal_idx ON schema_1.foobar(foo, bar); + CREATE UNIQUE INDEX foobar_unique_idx ON schema_1.foobar(foo, fizz); + + CREATE table bar( + id VARCHAR(255) PRIMARY KEY, + foo VARCHAR(255), + bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, + fizz TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + buzz REAL NOT NULL CHECK (buzz IS NOT NULL), + FOREIGN KEY (foo, fizz) REFERENCES schema_1.foobar (foo, fizz) + ); + CREATE INDEX bar_normal_idx ON bar(bar); + CREATE INDEX bar_another_normal_id ON bar(bar, fizz); + CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresShareRowExclusiveLock, + diff.MigrationHazardTypeDeletesData, + }, + + planFactory: databaseSchemaSourcePlanFactory, + }, +} + +func (suite *acceptanceTestSuite) TestDatabaseSchemaSourceTestCases() { + suite.runTestCases(databaseSchemaSourceTestCases) +} diff --git a/pkg/diff/plan_generator.go b/pkg/diff/plan_generator.go index e3232eb..5b8b0ba 100644 --- a/pkg/diff/plan_generator.go +++ b/pkg/diff/plan_generator.go @@ -10,6 +10,7 @@ import ( _ "github.com/jackc/pgx/v4/stdlib" "github.com/kr/pretty" "github.com/stripe/pg-schema-diff/internal/schema" + externalschema "github.com/stripe/pg-schema-diff/pkg/schema" "github.com/stripe/pg-schema-diff/pkg/log" "github.com/stripe/pg-schema-diff/pkg/sqldb" @@ -80,6 +81,18 @@ func WithSchemas(schemas ...string) PlanOpt { } } +func WithExcludeSchemas(schemas ...string) PlanOpt { + return func(opts *planOptions) { + opts.getSchemaOpts = append(opts.getSchemaOpts, schema.WithExcludeSchemas(schemas...)) + } +} + +func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt { + return func(opts *planOptions) { + opts.getSchemaOpts = append(opts.getSchemaOpts, getSchemaOpts...) + } +} + // deprecated: GeneratePlan generates a migration plan to migrate the database to the target schema. This function only // diffs the public schemas. // diff --git a/pkg/diff/plan_generator_test.go b/pkg/diff/plan_generator_test.go index 058bbd0..7936a2b 100644 --- a/pkg/diff/plan_generator_test.go +++ b/pkg/diff/plan_generator_test.go @@ -1,40 +1,52 @@ -package diff_test +package diff import ( "context" "database/sql" - "io" + "fmt" "testing" _ "github.com/jackc/pgx/v4/stdlib" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "github.com/stripe/pg-schema-diff/internal/pgengine" - "github.com/stripe/pg-schema-diff/pkg/diff" + "github.com/stripe/pg-schema-diff/internal/schema" + "github.com/stripe/pg-schema-diff/pkg/log" + externalschema "github.com/stripe/pg-schema-diff/pkg/schema" "github.com/stripe/pg-schema-diff/pkg/tempdb" ) -type simpleMigratorTestSuite struct { +type fakeSchemaSource struct { + t *testing.T + + expectedDeps schemaSourcePlanDeps + schema schema.Schema + err error +} + +func (f fakeSchemaSource) GetSchema(_ context.Context, deps schemaSourcePlanDeps) (schema.Schema, error) { + assert.Equal(f.t, f.expectedDeps.logger, deps.logger) + assert.Equal(f.t, f.expectedDeps.tempDBFactory, deps.tempDBFactory) + // We can't easily compare the function pointers, so we'll just assert the length of the slices. + assert.Len(f.t, f.expectedDeps.getSchemaOpts, len(deps.getSchemaOpts)) + return f.schema, f.err +} + +type planGeneratorTestSuite struct { suite.Suite pgEngine *pgengine.Engine db *pgengine.DB } -func (suite *simpleMigratorTestSuite) mustGetTestDBPool() *sql.DB { +func (suite *planGeneratorTestSuite) mustGetTestDBPool() *sql.DB { pool, err := sql.Open("pgx", suite.db.GetDSN()) suite.NoError(err) return pool } -func (suite *simpleMigratorTestSuite) mustGetTestDBConn() (conn *sql.Conn, poolCloser io.Closer) { - pool := suite.mustGetTestDBPool() - conn, err := pool.Conn(context.Background()) - suite.Require().NoError(err) - return conn, pool -} - -func (suite *simpleMigratorTestSuite) mustBuildTempDbFactory(ctx context.Context) tempdb.Factory { +func (suite *planGeneratorTestSuite) mustBuildTempDbFactory(ctx context.Context) tempdb.Factory { tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) { return sql.Open("pgx", suite.pgEngine.GetPostgresDatabaseConnOpts().With("dbname", dbName).ToDSN()) }) @@ -42,7 +54,7 @@ func (suite *simpleMigratorTestSuite) mustBuildTempDbFactory(ctx context.Context return tempDbFactory } -func (suite *simpleMigratorTestSuite) mustApplyDDLToTestDb(ddl []string) { +func (suite *planGeneratorTestSuite) mustApplyDDLToTestDb(ddl []string) { conn := suite.mustGetTestDBPool() defer conn.Close() @@ -52,27 +64,27 @@ func (suite *simpleMigratorTestSuite) mustApplyDDLToTestDb(ddl []string) { } } -func (suite *simpleMigratorTestSuite) SetupSuite() { +func (suite *planGeneratorTestSuite) SetupSuite() { engine, err := pgengine.StartEngine() suite.Require().NoError(err) suite.pgEngine = engine } -func (suite *simpleMigratorTestSuite) TearDownSuite() { +func (suite *planGeneratorTestSuite) TearDownSuite() { suite.pgEngine.Close() } -func (suite *simpleMigratorTestSuite) SetupTest() { +func (suite *planGeneratorTestSuite) SetupTest() { db, err := suite.pgEngine.CreateDatabase() suite.NoError(err) suite.db = db } -func (suite *simpleMigratorTestSuite) TearDownTest() { +func (suite *planGeneratorTestSuite) TearDownTest() { suite.db.DropDB() } -func (suite *simpleMigratorTestSuite) TestGeneratePlan_GenerateAndApply() { +func (suite *planGeneratorTestSuite) TestGenerate() { initialDDL := ` CREATE TABLE foobar( id CHAR(16) PRIMARY KEY @@ -92,14 +104,10 @@ func (suite *simpleMigratorTestSuite) TestGeneratePlan_GenerateAndApply() { tempDbFactory := suite.mustBuildTempDbFactory(context.Background()) defer tempDbFactory.Close() - plan, err := diff.GeneratePlan(context.Background(), connPool, tempDbFactory, []string{newSchemaDDL}) + plan, err := Generate(context.Background(), connPool, DDLSchemaSource([]string{newSchemaDDL}), WithTempDbFactory(tempDbFactory)) suite.NoError(err) - // Run the migration - for _, stmt := range plan.Statements { - _, err = connPool.ExecContext(context.Background(), stmt.ToSQL()) - suite.Require().NoError(err) - } + suite.mustApplyMigrationPlan(connPool, plan) // Ensure that some sort of migration ran. we're really not testing the correctness of the // migration in this test suite _, err = connPool.ExecContext(context.Background(), @@ -107,41 +115,82 @@ func (suite *simpleMigratorTestSuite) TestGeneratePlan_GenerateAndApply() { suite.NoError(err) } -func (suite *simpleMigratorTestSuite) TestGeneratePlan_CannotPackNewTablesWithoutIgnoringChangesToColumnOrder() { +func (suite *planGeneratorTestSuite) TestGeneratePlan_SchemaSourceErr() { tempDbFactory := suite.mustBuildTempDbFactory(context.Background()) defer tempDbFactory.Close() - conn, poolCloser := suite.mustGetTestDBConn() - defer poolCloser.Close() - defer conn.Close() + logger := log.SimpleLogger() + + getSchemaOpts := []externalschema.GetSchemaOpt{ + externalschema.WithIncludeSchemas("schema_1"), + externalschema.WithIncludeSchemas("schema_2"), + } + + expectedErr := fmt.Errorf("some error") + fakeSchemaSource := fakeSchemaSource{ + t: suite.T(), + expectedDeps: schemaSourcePlanDeps{ + tempDBFactory: tempDbFactory, + logger: logger, + getSchemaOpts: getSchemaOpts, + }, + err: expectedErr, + } + + connPool := suite.mustGetTestDBPool() + defer connPool.Close() + + _, err := Generate(context.Background(), connPool, fakeSchemaSource, + WithTempDbFactory(tempDbFactory), + WithGetSchemaOpts(getSchemaOpts...), + WithLogger(logger), + ) + suite.ErrorIs(err, expectedErr) +} + +func (suite *planGeneratorTestSuite) mustApplyMigrationPlan(db *sql.DB, plan Plan) { + // Run the migration + for _, stmt := range plan.Statements { + _, err := db.ExecContext(context.Background(), stmt.ToSQL()) + suite.Require().NoError(err) + } +} + +func (suite *planGeneratorTestSuite) TestGenerate_CannotPackNewTablesWithoutIgnoringChangesToColumnOrder() { + tempDbFactory := suite.mustBuildTempDbFactory(context.Background()) + defer tempDbFactory.Close() + + connPool := suite.mustGetTestDBPool() + defer connPool.Close() - _, err := diff.GeneratePlan(context.Background(), conn, tempDbFactory, []string{``}, - diff.WithDataPackNewTables(), - diff.WithRespectColumnOrder(), + _, err := Generate(context.Background(), connPool, DDLSchemaSource([]string{``}), + WithTempDbFactory(tempDbFactory), + WithDataPackNewTables(), + WithRespectColumnOrder(), ) suite.ErrorContains(err, "cannot data pack new tables without also ignoring changes to column order") } -func (suite *simpleMigratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWithoutTempDbFactory() { +func (suite *planGeneratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWithoutTempDbFactory() { pool := suite.mustGetTestDBPool() defer pool.Close() - _, err := diff.Generate(context.Background(), pool, diff.DDLSchemaSource([]string{``}), - diff.WithSchemas("public"), - diff.WithDoNotValidatePlan(), + _, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}), + WithSchemas("public"), + WithDoNotValidatePlan(), ) suite.ErrorContains(err, "tempDbFactory is required") } -func (suite *simpleMigratorTestSuite) TestGenerate_CannotValidateWithoutTempDbFactory() { +func (suite *planGeneratorTestSuite) TestGenerate_CannotValidateWithoutTempDbFactory() { pool := suite.mustGetTestDBPool() defer pool.Close() - _, err := diff.Generate(context.Background(), pool, diff.DDLSchemaSource([]string{``}), - diff.WithSchemas("public"), - diff.WithDoNotValidatePlan(), + _, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}), + WithSchemas("public"), + WithDoNotValidatePlan(), ) suite.ErrorContains(err, "tempDbFactory is required") } func TestSimpleMigratorTestSuite(t *testing.T) { - suite.Run(t, new(simpleMigratorTestSuite)) + suite.Run(t, new(planGeneratorTestSuite)) } diff --git a/pkg/diff/schema_source.go b/pkg/diff/schema_source.go index 0c6943d..364e0df 100644 --- a/pkg/diff/schema_source.go +++ b/pkg/diff/schema_source.go @@ -6,6 +6,7 @@ import ( "github.com/stripe/pg-schema-diff/internal/schema" "github.com/stripe/pg-schema-diff/pkg/log" + "github.com/stripe/pg-schema-diff/pkg/sqldb" "github.com/stripe/pg-schema-diff/pkg/tempdb" ) @@ -52,3 +53,17 @@ func (s *ddlSchemaSource) GetSchema(ctx context.Context, deps schemaSourcePlanDe return schema.GetSchema(ctx, tempDb.ConnPool, append(deps.getSchemaOpts, tempDb.ExcludeMetadatOptions...)...) } + +type dbSchemaSource struct { + queryable sqldb.Queryable +} + +// DBSchemaSource returns a SchemaSource that returns a schema based on the provided queryable. It is recommended +// that the sqldb.Queryable is a *sql.DB with a max # of connections set. +func DBSchemaSource(queryable sqldb.Queryable) SchemaSource { + return &dbSchemaSource{queryable: queryable} +} + +func (s *dbSchemaSource) GetSchema(ctx context.Context, deps schemaSourcePlanDeps) (schema.Schema, error) { + return schema.GetSchema(ctx, s.queryable, deps.getSchemaOpts...) +} diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 8b462ae..896141e 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -8,7 +8,7 @@ import ( "github.com/stripe/pg-schema-diff/pkg/sqldb" ) -type GetSchemaOptions = internalschema.GetSchemaOpt +type GetSchemaOpt = internalschema.GetSchemaOpt var ( WithIncludeSchemas = internalschema.WithIncludeSchemas diff --git a/pkg/tempdb/factory.go b/pkg/tempdb/factory.go index 8d71dba..7740a47 100644 --- a/pkg/tempdb/factory.go +++ b/pkg/tempdb/factory.go @@ -37,7 +37,7 @@ type ( // ConnPool is the connection pool to the temporary database ConnPool *sql.DB // ExcludeMetadataOptions are the options used to exclude any internal metadata from plan generation - ExcludeMetadatOptions []schema.GetSchemaOptions + ExcludeMetadatOptions []schema.GetSchemaOpt // ContextualCloser should be called to clean up the temporary database ContextualCloser } @@ -207,7 +207,7 @@ func (o *onInstanceFactory) Create(ctx context.Context) (_ *Database, retErr err return &Database{ ConnPool: tempDbConn, - ExcludeMetadatOptions: []schema.GetSchemaOptions{ + ExcludeMetadatOptions: []schema.GetSchemaOpt{ schema.WithExcludeSchemas(o.options.metadataSchema), }, ContextualCloser: fnContextualCloser(func(ctx context.Context) error {