Skip to content

Commit

Permalink
Update CLI with multi-schema support (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
bplunkett-stripe authored Feb 8, 2024
1 parent 97fffc8 commit 7fc1d07
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 55 deletions.
4 changes: 2 additions & 2 deletions cmd/pg-schema-diff/apply_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ func buildApplyCmd() *cobra.Command {
" (example: --allowed-hazards DELETES_DATA,INDEX_BUILD)")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
logger := log.SimpleLogger()
connConfig, err := connFlags.parseConnConfig(logger)
connConfig, err := parseConnConfig(*connFlags, logger)
if err != nil {
return err
}

planConfig, err := planFlags.parsePlanConfig()
planConfig, err := parsePlanConfig(*planFlags)
if err != nil {
return err
}
Expand Down
25 changes: 10 additions & 15 deletions cmd/pg-schema-diff/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,23 @@ import (
)

type connFlags struct {
dsn *string
dsn string
}

func createConnFlags(cmd *cobra.Command) connFlags {
dsn := cmd.Flags().String("dsn", "", "Connection string for the database (DB password can be specified through PGPASSWORD environment variable)")
func createConnFlags(cmd *cobra.Command) *connFlags {
flags := &connFlags{}

cmd.Flags().StringVar(&flags.dsn, "dsn", "", "Connection string for the database (DB password can be specified through PGPASSWORD environment variable)")
// Don't mark dsn as a required flag.
// Allow users to use the "PGHOST" etc environment variables like `psql`.
return connFlags{
dsn: dsn,
}

return flags
}

func (c connFlags) parseConnConfig(logger log.Logger) (*pgx.ConnConfig, error) {
if c.dsn == nil || *c.dsn == "" {
func parseConnConfig(c connFlags, logger log.Logger) (*pgx.ConnConfig, error) {
if c.dsn == "" {
logger.Warnf("DSN flag not set. Using libpq environment variables and default values.")
}

return pgx.ParseConfig(*c.dsn)
}

func mustMarkFlagAsRequired(cmd *cobra.Command, flagName string) {
if err := cmd.MarkFlagRequired(flagName); err != nil {
panic(err)
}
return pgx.ParseConfig(c.dsn)
}
148 changes: 115 additions & 33 deletions cmd/pg-schema-diff/plan_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
Expand All @@ -18,6 +19,10 @@ import (
"github.com/stripe/pg-schema-diff/pkg/tempdb"
)

const (
defaultMaxConnections = 5
)

var (
// Match arguments in the format "regex=duration" where duration is any duration valid in time.ParseDuration
// We'll let time.ParseDuration handle the complexity of parsing invalid duration, so the regex we're extracting is
Expand Down Expand Up @@ -47,12 +52,12 @@ func buildPlanCmd() *cobra.Command {
planFlags := createPlanFlags(cmd)
cmd.RunE = func(cmd *cobra.Command, args []string) error {
logger := log.SimpleLogger()
connConfig, err := connFlags.parseConnConfig(logger)
connConfig, err := parseConnConfig(*connFlags, logger)
if err != nil {
return err
}

planConfig, err := planFlags.parsePlanConfig()
planConfig, err := parsePlanConfig(*planFlags)
if err != nil {
return err
}
Expand All @@ -75,11 +80,24 @@ func buildPlanCmd() *cobra.Command {
}

type (
schemaFlags struct {
includeSchemas []string
excludeSchemas []string
}

schemaSourceFlags struct {
schemaDir string
targetDatabaseDSN string
}

planFlags struct {
schemaDir *string
statementTimeoutModifiers *[]string
lockTimeoutModifiers *[]string
insertStatements *[]string
dbSchemaSourceFlags schemaSourceFlags

schemaFlags schemaFlags

statementTimeoutModifiers []string
lockTimeoutModifiers []string
insertStatements []string
}

timeoutModifiers struct {
Expand All @@ -93,43 +111,65 @@ type (
timeout time.Duration
}

schemaSourceFactory func() (diff.SchemaSource, io.Closer, error)

planConfig struct {
schemaDir string
schemaSourceFactory schemaSourceFactory
opts []diff.PlanOpt

statementTimeoutModifiers []timeoutModifiers
lockTimeoutModifiers []timeoutModifiers
insertStatements []insertStatement
}
)

func createPlanFlags(cmd *cobra.Command) planFlags {
schemaDir := cmd.Flags().String("schema-dir", "", "Directory containing schema files")
mustMarkFlagAsRequired(cmd, "schema-dir")
func createPlanFlags(cmd *cobra.Command) *planFlags {
flags := &planFlags{}

schemaSourceFlagsVar(cmd, &flags.dbSchemaSourceFlags)

statementTimeoutModifiers := timeoutModifierFlag(cmd, "statement", "t")
lockTimeoutModifiers := timeoutModifierFlag(cmd, "lock", "l")
insertStatements := cmd.Flags().StringArrayP("insert-statement", "s", nil,
schemaFlagsVar(cmd, &flags.schemaFlags)

timeoutModifierFlagVar(cmd, &flags.statementTimeoutModifiers, "statement", "t")
timeoutModifierFlagVar(cmd, &flags.lockTimeoutModifiers, "lock", "l")
cmd.Flags().StringArrayVarP(&flags.insertStatements, "insert-statement", "s", nil,
"<index>_<timeout>:<statement> values. Will insert the statement at the index in the "+
"generated plan with the specified timeout. This follows normal insert semantics. Example: -s '0 5s:SELECT 1''")

return planFlags{
schemaDir: schemaDir,
statementTimeoutModifiers: statementTimeoutModifiers,
lockTimeoutModifiers: lockTimeoutModifiers,
insertStatements: insertStatements,
return flags
}

func schemaSourceFlagsVar(cmd *cobra.Command, p *schemaSourceFlags) {
cmd.Flags().StringVar(&p.schemaDir, "schema-dir", "", "Directory of .SQL files to use as the schema source. Use to generate a diff between the target database and the schema in this directory.")
if err := cmd.MarkFlagDirname("schema-dir"); err != nil {
panic(err)
}
cmd.Flags().StringVar(&p.targetDatabaseDSN, "schema-source-dsn", "", "DSN for the database to use as the schema source. Use to generate a diff between the target database and the schema in this database.")

cmd.MarkFlagsMutuallyExclusive("schema-dir", "schema-source-dsn")
}

func schemaFlagsVar(cmd *cobra.Command, p *schemaFlags) {
cmd.Flags().StringArrayVar(&p.includeSchemas, "include-schema", nil, "Include the specified schema in the plan")
cmd.Flags().StringArrayVar(&p.excludeSchemas, "exclude-schema", nil, "Exclude the specified schema in the plan")
}

func timeoutModifierFlag(cmd *cobra.Command, timeoutType string, shorthand string) *[]string {
func timeoutModifierFlagVar(cmd *cobra.Command, p *[]string, timeoutType string, shorthand string) {
flagName := fmt.Sprintf("%s-timeout-modifier", timeoutType)
desc := fmt.Sprintf("regex=timeout key-value pairs, where if a statement matches the regex, the statement "+
"will be modified to have the %s timeout. If multiple regexes match, the latest regex will take priority. "+
"Example: -t 'CREATE TABLE=5m' -t 'CONCURRENTLY=10s'", timeoutType)
return cmd.Flags().StringArrayP(flagName, shorthand, nil, desc)
cmd.Flags().StringArrayVarP(p, flagName, shorthand, nil, desc)
}

func (p planFlags) parsePlanConfig() (planConfig, error) {
func parsePlanConfig(p planFlags) (planConfig, error) {
schemaSourceFactory, err := parseSchemaSource(p.dbSchemaSourceFlags)
if err != nil {
return planConfig{}, err
}

var statementTimeoutModifiers []timeoutModifiers
for _, s := range *p.statementTimeoutModifiers {
for _, s := range p.statementTimeoutModifiers {
stm, err := parseTimeoutModifier(s)
if err != nil {
return planConfig{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err)
Expand All @@ -138,7 +178,7 @@ func (p planFlags) parsePlanConfig() (planConfig, error) {
}

var lockTimeoutModifiers []timeoutModifiers
for _, s := range *p.lockTimeoutModifiers {
for _, s := range p.lockTimeoutModifiers {
ltm, err := parseTimeoutModifier(s)
if err != nil {
return planConfig{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err)
Expand All @@ -147,7 +187,7 @@ func (p planFlags) parsePlanConfig() (planConfig, error) {
}

var insertStatements []insertStatement
for _, i := range *p.insertStatements {
for _, i := range p.insertStatements {
is, err := parseInsertStatementStr(i)
if err != nil {
return planConfig{}, fmt.Errorf("parsing insert statement from %q: %w", i, err)
Expand All @@ -156,13 +196,49 @@ func (p planFlags) parsePlanConfig() (planConfig, error) {
}

return planConfig{
schemaDir: *p.schemaDir,
schemaSourceFactory: schemaSourceFactory,
opts: parseSchemaConfig(p.schemaFlags),
statementTimeoutModifiers: statementTimeoutModifiers,
lockTimeoutModifiers: lockTimeoutModifiers,
insertStatements: insertStatements,
}, nil
}

func parseSchemaSource(p schemaSourceFlags) (schemaSourceFactory, error) {
if p.schemaDir != "" {
ddl, err := getDDLFromPath(p.schemaDir)
if err != nil {
return nil, err
}
return func() (diff.SchemaSource, io.Closer, error) {
return diff.DDLSchemaSource(ddl), nil, nil
}, nil
}

if p.targetDatabaseDSN != "" {
connConfig, err := pgx.ParseConfig(p.targetDatabaseDSN)
if err != nil {
return nil, fmt.Errorf("parsing DSN %q: %w", p.targetDatabaseDSN, err)
}
return func() (diff.SchemaSource, io.Closer, error) {
connPool, err := openDbWithPgxConfig(connConfig)
if err != nil {
return nil, nil, fmt.Errorf("opening db with pgx config: %w", err)
}
return diff.DBSchemaSource(connPool), connPool, nil
}, nil
}

return nil, fmt.Errorf("either --schema-dir or --schema-source-dsn must be set")
}

func parseSchemaConfig(p schemaFlags) []diff.PlanOpt {
return []diff.PlanOpt{
diff.WithIncludeSchemas(p.includeSchemas...),
diff.WithExcludeSchemas(p.excludeSchemas...),
}
}

func parseTimeoutModifier(val string) (timeoutModifiers, error) {
submatches := statementTimeoutModifierRegex.FindStringSubmatch(val)
if len(submatches) <= regexSTMRegexIndex || len(submatches) <= durationSTMRegexIndex {
Expand Down Expand Up @@ -216,11 +292,6 @@ func parseInsertStatementStr(val string) (insertStatement, error) {
}

func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnConfig, planConfig planConfig) (diff.Plan, error) {
ddl, err := getDDLFromPath(planConfig.schemaDir)
if err != nil {
return diff.Plan{}, nil
}

tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) {
copiedConfig := connConfig.Copy()
copiedConfig.Database = dbName
Expand All @@ -241,11 +312,22 @@ func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnCo
return diff.Plan{}, err
}
defer connPool.Close()
connPool.SetMaxOpenConns(defaultMaxConnections)

connPool.SetMaxOpenConns(5)
schemaSource, schemaSourceCloser, err := planConfig.schemaSourceFactory()
if err != nil {
return diff.Plan{}, fmt.Errorf("creating schema source: %w", err)
}
if schemaSourceCloser != nil {
defer schemaSourceCloser.Close()
}

plan, err := diff.GeneratePlan(ctx, connPool, tempDbFactory, ddl,
diff.WithDataPackNewTables(),
plan, err := diff.Generate(ctx, connPool, schemaSource,
append(
planConfig.opts,
diff.WithTempDbFactory(tempDbFactory),
diff.WithDataPackNewTables(),
)...,
)
if err != nil {
return diff.Plan{}, fmt.Errorf("generating plan: %w", err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/diff/plan_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func WithLogger(logger log.Logger) PlanOpt {
}
}

func WithSchemas(schemas ...string) PlanOpt {
func WithIncludeSchemas(schemas ...string) PlanOpt {
return func(opts *planOptions) {
opts.getSchemaOpts = append(opts.getSchemaOpts, schema.WithIncludeSchemas(schemas...))
}
Expand All @@ -96,7 +96,7 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt {
// deprecated: GeneratePlan generates a migration plan to migrate the database to the target schema. This function only
// diffs the public schemas.
//
// Use Generate instead with the DDLSchemaSource(newDDL) and WithSchemas("public") and WithTempDbFactory options.
// Use Generate instead with the DDLSchemaSource(newDDL) and WithIncludeSchemas("public") and WithTempDbFactory options.
//
// Parameters:
// queryable: The target database to generate the diff for. It is recommended to pass in *sql.DB of the db you
Expand All @@ -106,7 +106,7 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt {
// newDDL: DDL encoding the new schema
// opts: Additional options to configure the plan generation
func GeneratePlan(ctx context.Context, queryable sqldb.Queryable, tempdbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) {
return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithSchemas("public"))...)
return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...)
}

// Generate generates a migration plan to migrate the database to the target schema
Expand Down
4 changes: 2 additions & 2 deletions pkg/diff/plan_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWit
pool := suite.mustGetTestDBPool()
defer pool.Close()
_, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}),
WithSchemas("public"),
WithIncludeSchemas("public"),
WithDoNotValidatePlan(),
)
suite.ErrorContains(err, "tempDbFactory is required")
Expand All @@ -185,7 +185,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotValidateWithoutTempDbFac
pool := suite.mustGetTestDBPool()
defer pool.Close()
_, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}),
WithSchemas("public"),
WithIncludeSchemas("public"),
WithDoNotValidatePlan(),
)
suite.ErrorContains(err, "tempDbFactory is required")
Expand Down

0 comments on commit 7fc1d07

Please sign in to comment.