diff --git a/cmd/pg-schema-diff/apply_cmd.go b/cmd/pg-schema-diff/apply_cmd.go index 3a49408..d5dc4e7 100644 --- a/cmd/pg-schema-diff/apply_cmd.go +++ b/cmd/pg-schema-diff/apply_cmd.go @@ -28,12 +28,12 @@ func buildApplyCmd() *cobra.Command { " (example: --allowed-hazards DELETES_DATA,INDEX_BUILD)") cmd.RunE = func(cmd *cobra.Command, args []string) error { logger := log.SimpleLogger() - connConfig, err := connFlags.parseConnConfig(logger) + connConfig, err := parseConnConfig(*connFlags, logger) if err != nil { return err } - planConfig, err := planFlags.parsePlanConfig() + planConfig, err := parsePlanConfig(*planFlags) if err != nil { return err } diff --git a/cmd/pg-schema-diff/flags.go b/cmd/pg-schema-diff/flags.go index b08ac34..128acfc 100644 --- a/cmd/pg-schema-diff/flags.go +++ b/cmd/pg-schema-diff/flags.go @@ -7,28 +7,23 @@ import ( ) type connFlags struct { - dsn *string + dsn string } -func createConnFlags(cmd *cobra.Command) connFlags { - dsn := cmd.Flags().String("dsn", "", "Connection string for the database (DB password can be specified through PGPASSWORD environment variable)") +func createConnFlags(cmd *cobra.Command) *connFlags { + flags := &connFlags{} + + cmd.Flags().StringVar(&flags.dsn, "dsn", "", "Connection string for the database (DB password can be specified through PGPASSWORD environment variable)") // Don't mark dsn as a required flag. // Allow users to use the "PGHOST" etc environment variables like `psql`. - return connFlags{ - dsn: dsn, - } + + return flags } -func (c connFlags) parseConnConfig(logger log.Logger) (*pgx.ConnConfig, error) { - if c.dsn == nil || *c.dsn == "" { +func parseConnConfig(c connFlags, logger log.Logger) (*pgx.ConnConfig, error) { + if c.dsn == "" { logger.Warnf("DSN flag not set. Using libpq environment variables and default values.") } - return pgx.ParseConfig(*c.dsn) -} - -func mustMarkFlagAsRequired(cmd *cobra.Command, flagName string) { - if err := cmd.MarkFlagRequired(flagName); err != nil { - panic(err) - } + return pgx.ParseConfig(c.dsn) } diff --git a/cmd/pg-schema-diff/plan_cmd.go b/cmd/pg-schema-diff/plan_cmd.go index 9ae9c0a..939c3fa 100644 --- a/cmd/pg-schema-diff/plan_cmd.go +++ b/cmd/pg-schema-diff/plan_cmd.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "io" "os" "path/filepath" "regexp" @@ -18,6 +19,10 @@ import ( "github.com/stripe/pg-schema-diff/pkg/tempdb" ) +const ( + defaultMaxConnections = 5 +) + var ( // Match arguments in the format "regex=duration" where duration is any duration valid in time.ParseDuration // We'll let time.ParseDuration handle the complexity of parsing invalid duration, so the regex we're extracting is @@ -47,12 +52,12 @@ func buildPlanCmd() *cobra.Command { planFlags := createPlanFlags(cmd) cmd.RunE = func(cmd *cobra.Command, args []string) error { logger := log.SimpleLogger() - connConfig, err := connFlags.parseConnConfig(logger) + connConfig, err := parseConnConfig(*connFlags, logger) if err != nil { return err } - planConfig, err := planFlags.parsePlanConfig() + planConfig, err := parsePlanConfig(*planFlags) if err != nil { return err } @@ -75,11 +80,24 @@ func buildPlanCmd() *cobra.Command { } type ( + schemaFlags struct { + includeSchemas []string + excludeSchemas []string + } + + schemaSourceFlags struct { + schemaDir string + targetDatabaseDSN string + } + planFlags struct { - schemaDir *string - statementTimeoutModifiers *[]string - lockTimeoutModifiers *[]string - insertStatements *[]string + dbSchemaSourceFlags schemaSourceFlags + + schemaFlags schemaFlags + + statementTimeoutModifiers []string + lockTimeoutModifiers []string + insertStatements []string } timeoutModifiers struct { @@ -93,43 +111,65 @@ type ( timeout time.Duration } + schemaSourceFactory func() (diff.SchemaSource, io.Closer, error) + planConfig struct { - schemaDir string + schemaSourceFactory schemaSourceFactory + opts []diff.PlanOpt + statementTimeoutModifiers []timeoutModifiers lockTimeoutModifiers []timeoutModifiers insertStatements []insertStatement } ) -func createPlanFlags(cmd *cobra.Command) planFlags { - schemaDir := cmd.Flags().String("schema-dir", "", "Directory containing schema files") - mustMarkFlagAsRequired(cmd, "schema-dir") +func createPlanFlags(cmd *cobra.Command) *planFlags { + flags := &planFlags{} + + schemaSourceFlagsVar(cmd, &flags.dbSchemaSourceFlags) - statementTimeoutModifiers := timeoutModifierFlag(cmd, "statement", "t") - lockTimeoutModifiers := timeoutModifierFlag(cmd, "lock", "l") - insertStatements := cmd.Flags().StringArrayP("insert-statement", "s", nil, + schemaFlagsVar(cmd, &flags.schemaFlags) + + timeoutModifierFlagVar(cmd, &flags.statementTimeoutModifiers, "statement", "t") + timeoutModifierFlagVar(cmd, &flags.lockTimeoutModifiers, "lock", "l") + cmd.Flags().StringArrayVarP(&flags.insertStatements, "insert-statement", "s", nil, "_: values. Will insert the statement at the index in the "+ "generated plan with the specified timeout. This follows normal insert semantics. Example: -s '0 5s:SELECT 1''") - return planFlags{ - schemaDir: schemaDir, - statementTimeoutModifiers: statementTimeoutModifiers, - lockTimeoutModifiers: lockTimeoutModifiers, - insertStatements: insertStatements, + return flags +} + +func schemaSourceFlagsVar(cmd *cobra.Command, p *schemaSourceFlags) { + cmd.Flags().StringVar(&p.schemaDir, "schema-dir", "", "Directory of .SQL files to use as the schema source. Use to generate a diff between the target database and the schema in this directory.") + if err := cmd.MarkFlagDirname("schema-dir"); err != nil { + panic(err) } + cmd.Flags().StringVar(&p.targetDatabaseDSN, "schema-source-dsn", "", "DSN for the database to use as the schema source. Use to generate a diff between the target database and the schema in this database.") + + cmd.MarkFlagsMutuallyExclusive("schema-dir", "schema-source-dsn") +} + +func schemaFlagsVar(cmd *cobra.Command, p *schemaFlags) { + cmd.Flags().StringArrayVar(&p.includeSchemas, "include-schema", nil, "Include the specified schema in the plan") + cmd.Flags().StringArrayVar(&p.excludeSchemas, "exclude-schema", nil, "Exclude the specified schema in the plan") } -func timeoutModifierFlag(cmd *cobra.Command, timeoutType string, shorthand string) *[]string { +func timeoutModifierFlagVar(cmd *cobra.Command, p *[]string, timeoutType string, shorthand string) { flagName := fmt.Sprintf("%s-timeout-modifier", timeoutType) desc := fmt.Sprintf("regex=timeout key-value pairs, where if a statement matches the regex, the statement "+ "will be modified to have the %s timeout. If multiple regexes match, the latest regex will take priority. "+ "Example: -t 'CREATE TABLE=5m' -t 'CONCURRENTLY=10s'", timeoutType) - return cmd.Flags().StringArrayP(flagName, shorthand, nil, desc) + cmd.Flags().StringArrayVarP(p, flagName, shorthand, nil, desc) } -func (p planFlags) parsePlanConfig() (planConfig, error) { +func parsePlanConfig(p planFlags) (planConfig, error) { + schemaSourceFactory, err := parseSchemaSource(p.dbSchemaSourceFlags) + if err != nil { + return planConfig{}, err + } + var statementTimeoutModifiers []timeoutModifiers - for _, s := range *p.statementTimeoutModifiers { + for _, s := range p.statementTimeoutModifiers { stm, err := parseTimeoutModifier(s) if err != nil { return planConfig{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err) @@ -138,7 +178,7 @@ func (p planFlags) parsePlanConfig() (planConfig, error) { } var lockTimeoutModifiers []timeoutModifiers - for _, s := range *p.lockTimeoutModifiers { + for _, s := range p.lockTimeoutModifiers { ltm, err := parseTimeoutModifier(s) if err != nil { return planConfig{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err) @@ -147,7 +187,7 @@ func (p planFlags) parsePlanConfig() (planConfig, error) { } var insertStatements []insertStatement - for _, i := range *p.insertStatements { + for _, i := range p.insertStatements { is, err := parseInsertStatementStr(i) if err != nil { return planConfig{}, fmt.Errorf("parsing insert statement from %q: %w", i, err) @@ -156,13 +196,49 @@ func (p planFlags) parsePlanConfig() (planConfig, error) { } return planConfig{ - schemaDir: *p.schemaDir, + schemaSourceFactory: schemaSourceFactory, + opts: parseSchemaConfig(p.schemaFlags), statementTimeoutModifiers: statementTimeoutModifiers, lockTimeoutModifiers: lockTimeoutModifiers, insertStatements: insertStatements, }, nil } +func parseSchemaSource(p schemaSourceFlags) (schemaSourceFactory, error) { + if p.schemaDir != "" { + ddl, err := getDDLFromPath(p.schemaDir) + if err != nil { + return nil, err + } + return func() (diff.SchemaSource, io.Closer, error) { + return diff.DDLSchemaSource(ddl), nil, nil + }, nil + } + + if p.targetDatabaseDSN != "" { + connConfig, err := pgx.ParseConfig(p.targetDatabaseDSN) + if err != nil { + return nil, fmt.Errorf("parsing DSN %q: %w", p.targetDatabaseDSN, err) + } + return func() (diff.SchemaSource, io.Closer, error) { + connPool, err := openDbWithPgxConfig(connConfig) + if err != nil { + return nil, nil, fmt.Errorf("opening db with pgx config: %w", err) + } + return diff.DBSchemaSource(connPool), connPool, nil + }, nil + } + + return nil, fmt.Errorf("either --schema-dir or --schema-source-dsn must be set") +} + +func parseSchemaConfig(p schemaFlags) []diff.PlanOpt { + return []diff.PlanOpt{ + diff.WithIncludeSchemas(p.includeSchemas...), + diff.WithExcludeSchemas(p.excludeSchemas...), + } +} + func parseTimeoutModifier(val string) (timeoutModifiers, error) { submatches := statementTimeoutModifierRegex.FindStringSubmatch(val) if len(submatches) <= regexSTMRegexIndex || len(submatches) <= durationSTMRegexIndex { @@ -216,11 +292,6 @@ func parseInsertStatementStr(val string) (insertStatement, error) { } func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnConfig, planConfig planConfig) (diff.Plan, error) { - ddl, err := getDDLFromPath(planConfig.schemaDir) - if err != nil { - return diff.Plan{}, nil - } - tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) { copiedConfig := connConfig.Copy() copiedConfig.Database = dbName @@ -241,11 +312,22 @@ func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnCo return diff.Plan{}, err } defer connPool.Close() + connPool.SetMaxOpenConns(defaultMaxConnections) - connPool.SetMaxOpenConns(5) + schemaSource, schemaSourceCloser, err := planConfig.schemaSourceFactory() + if err != nil { + return diff.Plan{}, fmt.Errorf("creating schema source: %w", err) + } + if schemaSourceCloser != nil { + defer schemaSourceCloser.Close() + } - plan, err := diff.GeneratePlan(ctx, connPool, tempDbFactory, ddl, - diff.WithDataPackNewTables(), + plan, err := diff.Generate(ctx, connPool, schemaSource, + append( + planConfig.opts, + diff.WithTempDbFactory(tempDbFactory), + diff.WithDataPackNewTables(), + )..., ) if err != nil { return diff.Plan{}, fmt.Errorf("generating plan: %w", err) diff --git a/pkg/diff/plan_generator.go b/pkg/diff/plan_generator.go index 5b8b0ba..510b536 100644 --- a/pkg/diff/plan_generator.go +++ b/pkg/diff/plan_generator.go @@ -75,7 +75,7 @@ func WithLogger(logger log.Logger) PlanOpt { } } -func WithSchemas(schemas ...string) PlanOpt { +func WithIncludeSchemas(schemas ...string) PlanOpt { return func(opts *planOptions) { opts.getSchemaOpts = append(opts.getSchemaOpts, schema.WithIncludeSchemas(schemas...)) } @@ -96,7 +96,7 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt { // deprecated: GeneratePlan generates a migration plan to migrate the database to the target schema. This function only // diffs the public schemas. // -// Use Generate instead with the DDLSchemaSource(newDDL) and WithSchemas("public") and WithTempDbFactory options. +// Use Generate instead with the DDLSchemaSource(newDDL) and WithIncludeSchemas("public") and WithTempDbFactory options. // // Parameters: // queryable: The target database to generate the diff for. It is recommended to pass in *sql.DB of the db you @@ -106,7 +106,7 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt { // newDDL: DDL encoding the new schema // opts: Additional options to configure the plan generation func GeneratePlan(ctx context.Context, queryable sqldb.Queryable, tempdbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) { - return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithSchemas("public"))...) + return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...) } // Generate generates a migration plan to migrate the database to the target schema diff --git a/pkg/diff/plan_generator_test.go b/pkg/diff/plan_generator_test.go index 7936a2b..a615bcf 100644 --- a/pkg/diff/plan_generator_test.go +++ b/pkg/diff/plan_generator_test.go @@ -175,7 +175,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWit pool := suite.mustGetTestDBPool() defer pool.Close() _, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}), - WithSchemas("public"), + WithIncludeSchemas("public"), WithDoNotValidatePlan(), ) suite.ErrorContains(err, "tempDbFactory is required") @@ -185,7 +185,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotValidateWithoutTempDbFac pool := suite.mustGetTestDBPool() defer pool.Close() _, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}), - WithSchemas("public"), + WithIncludeSchemas("public"), WithDoNotValidatePlan(), ) suite.ErrorContains(err, "tempDbFactory is required")