From d231ee09209479099b22eeccdadb59d7a6275f76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?No=C3=A9mi=20V=C3=A1nyi?= Date: Mon, 18 Nov 2024 23:24:10 +0100 Subject: [PATCH] Add support for creating `CHECK` constraints with `create_constraint` (#464) This PR introduces a new constraint `type` to `create_constraint` operation called `check`. Now it is possible to create check constraints on multiple columns. ### Example ```json { "name": "45_add_table_check_constraint", "operations": [ { "create_constraint": { "type": "check", "table": "tickets", "name": "check_zip_name", "columns": [ "sellers_name", "sellers_zip" ], "check": "sellers_name ~ 'Alice' AND sellers_zip IS NOT NULL", "up": { "sellers_name": "Alice", "sellers_zip": "(SELECT CASE WHEN sellers_zip IS NOT NULL THEN sellers_zip ELSE '00000' END)" }, "down": { "sellers_name": "sellers_name", "sellers_zip": "sellers_zip" } } } ] } ``` --------- Co-authored-by: Andrew Farries --- docs/README.md | 3 +- examples/.ledger | 1 + examples/45_add_table_check_constraint.json | 25 +++ pkg/migrations/op_create_constraint.go | 33 ++- pkg/migrations/op_create_constraint_test.go | 220 ++++++++++++++++++++ pkg/migrations/types.go | 4 + schema.json | 6 +- 7 files changed, 287 insertions(+), 5 deletions(-) create mode 100644 examples/45_add_table_check_constraint.json diff --git a/docs/README.md b/docs/README.md index 508ff681..86551ef2 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1101,7 +1101,7 @@ Example **create table** migrations: A create constraint operation adds a new constraint to an existing table. -Only `UNIQUE` constraints are supported. +Only `UNIQUE` and `CHECK` constraints are supported. Required fields: `name`, `table`, `type`, `up`, `down`. @@ -1129,6 +1129,7 @@ Required fields: `name`, `table`, `type`, `up`, `down`. Example **create constraint** migrations: * [44_add_table_unique_constraint.json](../examples/44_add_table_unique_constraint.json) +* [45_add_table_check_constraint.json](../examples/45_add_table_check_constraint.json) ### Drop column diff --git a/examples/.ledger b/examples/.ledger index a8067c2a..28503df2 100644 --- a/examples/.ledger +++ b/examples/.ledger @@ -42,3 +42,4 @@ 42_create_unique_index.json 43_create_tickets_table.json 44_add_table_unique_constraint.json +45_add_table_check_constraint.json diff --git a/examples/45_add_table_check_constraint.json b/examples/45_add_table_check_constraint.json new file mode 100644 index 00000000..294220dd --- /dev/null +++ b/examples/45_add_table_check_constraint.json @@ -0,0 +1,25 @@ +{ + "name": "45_add_table_check_constraint", + "operations": [ + { + "create_constraint": { + "type": "check", + "table": "tickets", + "name": "check_zip_name", + "columns": [ + "sellers_name", + "sellers_zip" + ], + "check": "sellers_name = 'alice' OR sellers_zip > 0", + "up": { + "sellers_name": "sellers_name", + "sellers_zip": "(SELECT CASE WHEN sellers_name != 'alice' AND sellers_zip <= 0 THEN 123 WHEN sellers_name != 'alice' THEN sellers_zip ELSE sellers_zip END)" + }, + "down": { + "sellers_name": "sellers_name", + "sellers_zip": "sellers_zip" + } + } + } + ] +} diff --git a/pkg/migrations/op_create_constraint.go b/pkg/migrations/op_create_constraint.go index 148eaf47..846215e0 100644 --- a/pkg/migrations/op_create_constraint.go +++ b/pkg/migrations/op_create_constraint.go @@ -65,16 +65,18 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema } } - switch o.Type { //nolint:gocritic // more cases will be added + switch o.Type { case OpCreateConstraintTypeUnique: return table, o.addUniqueIndex(ctx, conn) + case OpCreateConstraintTypeCheck: + return table, o.addCheckConstraint(ctx, conn) } return table, nil } func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { - switch o.Type { //nolint:gocritic // more cases will be added + switch o.Type { case OpCreateConstraintTypeUnique: uniqueOp := &OpSetUnique{ Table: o.Table, @@ -84,6 +86,17 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra if err != nil { return err } + case OpCreateConstraintTypeCheck: + checkOp := &OpSetCheckConstraint{ + Table: o.Table, + Check: CheckConstraint{ + Name: o.Name, + }, + } + err := checkOp.Complete(ctx, conn, tr, s) + if err != nil { + return err + } } // remove old columns @@ -176,11 +189,15 @@ func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) err } } - switch o.Type { //nolint:gocritic // more cases will be added + switch o.Type { case OpCreateConstraintTypeUnique: if len(o.Columns) == 0 { return FieldRequiredError{Name: "columns"} } + case OpCreateConstraintTypeCheck: + if o.Check == nil || *o.Check == "" { + return FieldRequiredError{Name: "check"} + } } return nil @@ -196,6 +213,16 @@ func (o *OpCreateConstraint) addUniqueIndex(ctx context.Context, conn db.DB) err return err } +func (o *OpCreateConstraint) addCheckConstraint(ctx context.Context, conn db.DB) error { + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(o.Name), + rewriteCheckExpression(*o.Check, o.Columns...), + )) + + return err +} + func quotedTemporaryNames(columns []string) []string { names := make([]string, len(columns)) for i, col := range columns { diff --git a/pkg/migrations/op_create_constraint_test.go b/pkg/migrations/op_create_constraint_test.go index 8b2ca8e6..9a372a5d 100644 --- a/pkg/migrations/op_create_constraint_test.go +++ b/pkg/migrations/op_create_constraint_test.go @@ -7,6 +7,8 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/xataio/pgroll/internal/testutils" "github.com/xataio/pgroll/pkg/migrations" ) @@ -97,6 +99,80 @@ func TestCreateConstraint(t *testing.T) { }, testutils.UniqueViolationErrorCode) }, }, + { + name: "create check constraint on single column", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: ptr(false), + }, + }, + }, + }, + }, + { + Name: "02_create_constraint", + Operations: migrations.Operations{ + &migrations.OpCreateConstraint{ + Name: "name_letters", + Table: "users", + Type: "check", + Check: ptr("name ~ '^[a-zA-Z]+$'"), + Columns: []string{"name"}, + Up: migrations.OpCreateConstraintUp(map[string]string{ + "name": "regexp_replace(name, '\\d+', '', 'g')", + }), + Down: migrations.OpCreateConstraintDown(map[string]string{ + "name": "name", + }), + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // The new (temporary) column should exist on the underlying table. + ColumnMustExist(t, db, schema, "users", migrations.TemporaryName("name")) + // The check constraint exists on the new table. + CheckConstraintMustExist(t, db, schema, "users", "name_letters") + // Inserting values into the old schema that violate the check constraint must succeed. + MustInsert(t, db, schema, "01_add_table", "users", map[string]string{ + "name": "alice11", + }) + + // Inserting values into the new schema that violate the check constraint should fail. + MustInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "bob", + }) + MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "bob2", + }, testutils.CheckViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + // Functions, triggers and temporary columns are dropped. + tableCleanedUp(t, db, schema, "users", "name") + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // Functions, triggers and temporary columns are dropped. + tableCleanedUp(t, db, schema, "users", "name") + + // Inserting values into the new schema that violate the check constraint should fail. + MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "carol0", + }, testutils.CheckViolationErrorCode) + }, + }, { name: "create unique constraint on multiple columns", migrations: []migrations.Migration{ @@ -181,6 +257,104 @@ func TestCreateConstraint(t *testing.T) { // Complete is a no-op. }, }, + { + name: "create check constraint on multiple columns", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: ptr(false), + }, + { + Name: "email", + Type: "varchar(255)", + Nullable: ptr(false), + }, + }, + }, + }, + }, + { + Name: "02_create_constraint", + Operations: migrations.Operations{ + &migrations.OpCreateConstraint{ + Name: "check_name_email", + Table: "users", + Type: "check", + Check: ptr("name != email"), + Columns: []string{"name", "email"}, + Up: migrations.OpCreateConstraintUp(map[string]string{ + "name": "name", + "email": "(SELECT CASE WHEN email ~ '@' THEN email ELSE email || '@example.com' END)", + }), + Down: migrations.OpCreateConstraintDown(map[string]string{ + "name": "name", + "email": "email", + }), + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // The new (temporary) column should exist on the underlying table. + ColumnMustExist(t, db, schema, "users", migrations.TemporaryName("name")) + // The new (temporary) column should exist on the underlying table. + ColumnMustExist(t, db, schema, "users", migrations.TemporaryName("email")) + // The check constraint exists on the new table. + CheckConstraintMustExist(t, db, schema, "users", "check_name_email") + + // Inserting values into the old schema that the violate the check constraint must succeed. + MustInsert(t, db, schema, "01_add_table", "users", map[string]string{ + "name": "alice", + "email": "alice", + }) + + // Inserting values into the new schema that meet the check constraint should succeed. + MustInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "bob", + "email": "bob@bob.me", + }) + MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "bob", + "email": "bob", + }, testutils.CheckViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + // The check constraint must not exists on the table. + CheckConstraintMustNotExist(t, db, schema, "users", "check_name_email") + // Functions, triggers and temporary columns are dropped. + tableCleanedUp(t, db, schema, "users", "name") + tableCleanedUp(t, db, schema, "users", "email") + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // Functions, triggers and temporary columns are dropped. + tableCleanedUp(t, db, schema, "users", "name") + tableCleanedUp(t, db, schema, "users", "email") + + // Inserting values into the new schema that the violate the check constraint must fail. + MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{ + "name": "carol", + "email": "carol", + }, testutils.CheckViolationErrorCode) + + rows := MustSelect(t, db, schema, "02_create_constraint", "users") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "alice", "email": "alice@example.com"}, + {"id": 2, "name": "bob", "email": "bob@bob.me"}, + }, rows) + }, + }, { name: "invalid constraint name", migrations: []migrations.Migration{ @@ -270,6 +444,52 @@ func TestCreateConstraint(t *testing.T) { afterRollback: func(t *testing.T, db *sql.DB, schema string) {}, afterComplete: func(t *testing.T, db *sql.DB, schema string) {}, }, + { + name: "expression of check constraint is missing", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Nullable: ptr(false), + }, + }, + }, + }, + }, + { + Name: "02_create_constraint_with_missing_migration", + Operations: migrations.Operations{ + &migrations.OpCreateConstraint{ + Name: "check_name", + Table: "users", + Columns: []string{"name"}, + Type: "check", + Up: migrations.OpCreateConstraintUp(map[string]string{ + "name": "name", + }), + Down: migrations.OpCreateConstraintDown(map[string]string{ + "name": "name", + }), + }, + }, + }, + }, + wantStartErr: migrations.FieldRequiredError{Name: "check"}, + afterStart: func(t *testing.T, db *sql.DB, schema string) {}, + afterRollback: func(t *testing.T, db *sql.DB, schema string) {}, + afterComplete: func(t *testing.T, db *sql.DB, schema string) {}, + }, }) } diff --git a/pkg/migrations/types.go b/pkg/migrations/types.go index bd11428b..2f30d554 100644 --- a/pkg/migrations/types.go +++ b/pkg/migrations/types.go @@ -121,6 +121,9 @@ type OpAlterColumn struct { // Add constraint to table operation type OpCreateConstraint struct { + // Check constraint expression + Check *string `json:"check,omitempty"` + // Columns to add constraint to Columns []string `json:"columns,omitempty"` @@ -145,6 +148,7 @@ type OpCreateConstraintDown map[string]string type OpCreateConstraintType string +const OpCreateConstraintTypeCheck OpCreateConstraintType = "check" const OpCreateConstraintTypeUnique OpCreateConstraintType = "unique" // SQL expression of up migration by column diff --git a/schema.json b/schema.json index 428c11f1..44aa8188 100644 --- a/schema.json +++ b/schema.json @@ -454,7 +454,11 @@ "type": { "description": "Type of the constraint", "type": "string", - "enum": ["unique"] + "enum": ["unique", "check"] + }, + "check": { + "description": "Check constraint expression", + "type": "string" }, "up": { "type": "object",