diff --git a/docs/src/piccolo/schema/advanced.rst b/docs/src/piccolo/schema/advanced.rst index 0e4f3db0f..9cb1cd1e1 100644 --- a/docs/src/piccolo/schema/advanced.rst +++ b/docs/src/piccolo/schema/advanced.rst @@ -119,6 +119,8 @@ for example with :ref:`table_finder `. ------------------------------------------------------------------------------- +.. _Mixins: + Mixins ------ @@ -127,15 +129,37 @@ use mixins to reduce the amount of repetition. .. code-block:: python - from piccolo.columns import Varchar, Boolean + from piccolo.columns import Date, Varchar from piccolo.table import Table - class FavouriteMixin: - favourite = Boolean(default=False) + class DateOfBirthMixin: + date_of_birth = Date() + + + class Manager(DateOfBirthMixin, Table): + name = Varchar() + +You can also add :ref:`constraints ` to your mixin classes. + +.. code-block:: python + + import datetime + + from piccolo.columns import Varchar, Date + from piccolo.constraints import Check + from piccolo.table import Table + + + class DateOfBirthMixin: + date_of_birth = Date() + + min_date_of_birth = Check( + date_of_birth >= datetime.date(year=1920, month=1, day=1) + ) - class Manager(FavouriteMixin, Table): + class Manager(DateOfBirthMixin, Table): name = Varchar() ------------------------------------------------------------------------------- diff --git a/docs/src/piccolo/schema/constraints.rst b/docs/src/piccolo/schema/constraints.rst new file mode 100644 index 000000000..bdaa1eb3e --- /dev/null +++ b/docs/src/piccolo/schema/constraints.rst @@ -0,0 +1,50 @@ +=========== +Constraints +=========== + +Simple unique constraints +========================= + +Unique constraints can be added to a single column using the ``unique=True`` +argument of ``Column``: + +.. code-block:: python + + class Band(Table): + name = Varchar(unique=True) + +------------------------------------------------------------------------------- + +.. _AdvancedConstraints: + +Advanced constraints +===================== + +You can add you can implement powerful ``UNIQUE`` and ``CHECK`` constraints +on your ``Table``. + +``Unique`` +---------- + +.. currentmodule:: piccolo.constraints + +.. autoclass:: Unique + +``Check`` +---------- + +.. autoclass:: Check + +How are they created? +--------------------- + +If creating a new table using ``await MyTable.create_table()``, then the +constraints will also be created. + +Also, if using auto migrations, they handle the creation and deletion of these +constraints for you. + +Mixins +------ + +Constraints can be added to :ref:`mixin classes ` for reusability. diff --git a/docs/src/piccolo/schema/index.rst b/docs/src/piccolo/schema/index.rst index ec9b887e6..29396d7df 100644 --- a/docs/src/piccolo/schema/index.rst +++ b/docs/src/piccolo/schema/index.rst @@ -8,6 +8,7 @@ The schema is how you define your database tables, columns and relationships. ./defining ./column_types + ./constraints ./m2m ./one_to_one ./advanced diff --git a/piccolo/apps/migrations/auto/diffable_table.py b/piccolo/apps/migrations/auto/diffable_table.py index 522f4f001..07ac959e0 100644 --- a/piccolo/apps/migrations/auto/diffable_table.py +++ b/piccolo/apps/migrations/auto/diffable_table.py @@ -5,14 +5,17 @@ from piccolo.apps.migrations.auto.operations import ( AddColumn, + AddConstraint, AlterColumn, DropColumn, + DropConstraint, ) from piccolo.apps.migrations.auto.serialisation import ( deserialise_params, serialise_params, ) from piccolo.columns.base import Column +from piccolo.constraints import Constraint from piccolo.table import Table, create_table_class @@ -62,6 +65,8 @@ class TableDelta: add_columns: t.List[AddColumn] = field(default_factory=list) drop_columns: t.List[DropColumn] = field(default_factory=list) alter_columns: t.List[AlterColumn] = field(default_factory=list) + add_constraints: t.List[AddConstraint] = field(default_factory=list) + drop_constraints: t.List[DropConstraint] = field(default_factory=list) def __eq__(self, value: TableDelta) -> bool: # type: ignore """ @@ -92,6 +97,19 @@ def __eq__(self, value) -> bool: return False +@dataclass +class ConstraintComparison: + constraint: Constraint + + def __hash__(self) -> int: + return self.constraint.__hash__() + + def __eq__(self, value) -> bool: + if isinstance(value, ConstraintComparison): + return self.constraint._meta.name == value.constraint._meta.name + return False + + @dataclass class DiffableTable: """ @@ -103,6 +121,7 @@ class DiffableTable: tablename: str schema: t.Optional[str] = None columns: t.List[Column] = field(default_factory=list) + constraints: t.List[Constraint] = field(default_factory=list) previous_class_name: t.Optional[str] = None def __post_init__(self) -> None: @@ -196,10 +215,54 @@ def __sub__(self, value: DiffableTable) -> TableDelta: ) ) + add_constraints = [ + AddConstraint( + table_class_name=self.class_name, + constraint_name=i.constraint._meta.name, + constraint_class_name=i.constraint.__class__.__name__, + constraint_class=i.constraint.__class__, + params=i.constraint._meta.params, + schema=self.schema, + ) + for i in sorted( + { + ConstraintComparison(constraint=constraint) + for constraint in self.constraints + } + - { + ConstraintComparison(constraint=constraint) + for constraint in value.constraints + }, + key=lambda x: x.constraint._meta.name, + ) + ] + + drop_constraints = [ + DropConstraint( + table_class_name=self.class_name, + constraint_name=i.constraint._meta.name, + tablename=value.tablename, + schema=self.schema, + ) + for i in sorted( + { + ConstraintComparison(constraint=constraint) + for constraint in value.constraints + } + - { + ConstraintComparison(constraint=constraint) + for constraint in self.constraints + }, + key=lambda x: x.constraint._meta.name, + ) + ] + return TableDelta( add_columns=add_columns, drop_columns=drop_columns, alter_columns=alter_columns, + add_constraints=add_constraints, + drop_constraints=drop_constraints, ) def __hash__(self) -> int: @@ -225,10 +288,14 @@ def to_table_class(self) -> t.Type[Table]: """ Converts the DiffableTable into a Table subclass. """ + class_members: t.Dict[str, t.Any] = {} + for column in self.columns: + class_members[column._meta.name] = column + for constraint in self.constraints: + class_members[constraint._meta.name] = constraint + return create_table_class( class_name=self.class_name, class_kwargs={"tablename": self.tablename, "schema": self.schema}, - class_members={ - column._meta.name: column for column in self.columns - }, + class_members=class_members, ) diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index 31ad5c120..7519e3395 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -10,12 +10,14 @@ AlterColumn, ChangeTableSchema, DropColumn, + DropConstraint, RenameColumn, RenameTable, ) from piccolo.apps.migrations.auto.serialisation import deserialise_params from piccolo.columns import Column, column_types from piccolo.columns.column_types import ForeignKey, Serial +from piccolo.constraints import Constraint from piccolo.engine import engine_finder from piccolo.query import Query from piccolo.query.base import DDL @@ -127,6 +129,65 @@ def table_class_names(self) -> t.List[str]: return list({i.table_class_name for i in self.alter_columns}) +@dataclass +class AddConstraintClass: + constraint: Constraint + table_class_name: str + tablename: str + schema: t.Optional[str] + + +@dataclass +class AddConstraintCollection: + add_constraints: t.List[AddConstraintClass] = field(default_factory=list) + + def append(self, add_constraint: AddConstraintClass): + self.add_constraints.append(add_constraint) + + def for_table_class_name( + self, table_class_name: str + ) -> t.List[AddConstraintClass]: + return [ + i + for i in self.add_constraints + if i.table_class_name == table_class_name + ] + + def constraints_for_table_class_name( + self, table_class_name: str + ) -> t.List[Constraint]: + return [ + i.constraint + for i in self.add_constraints + if i.table_class_name == table_class_name + ] + + @property + def table_class_names(self) -> t.List[str]: + return list({i.table_class_name for i in self.add_constraints}) + + +@dataclass +class DropConstraintCollection: + drop_constraints: t.List[DropConstraint] = field(default_factory=list) + + def append(self, drop_constraint: DropConstraint): + self.drop_constraints.append(drop_constraint) + + def for_table_class_name( + self, table_class_name: str + ) -> t.List[DropConstraint]: + return [ + i + for i in self.drop_constraints + if i.table_class_name == table_class_name + ] + + @property + def table_class_names(self) -> t.List[str]: + return list({i.table_class_name for i in self.drop_constraints}) + + AsyncFunction = t.Callable[[], t.Coroutine] @@ -175,6 +236,12 @@ class MigrationManager: alter_columns: AlterColumnCollection = field( default_factory=AlterColumnCollection ) + add_constraints: AddConstraintCollection = field( + default_factory=AddConstraintCollection + ) + drop_constraints: DropConstraintCollection = field( + default_factory=DropConstraintCollection + ) raw: t.List[t.Union[t.Callable, AsyncFunction]] = field( default_factory=list ) @@ -364,6 +431,43 @@ def alter_column( ) ) + def add_constraint( + self, + table_class_name: str, + tablename: str, + constraint_name: str, + constraint_class: t.Type[Constraint], + params: t.Dict[str, t.Any], + schema: t.Optional[str] = None, + ): + constraint = constraint_class(**params) + constraint._meta.name = constraint_name + + self.add_constraints.append( + AddConstraintClass( + constraint=constraint, + table_class_name=table_class_name, + tablename=tablename, + schema=schema, + ) + ) + + def drop_constraint( + self, + table_class_name: str, + tablename: str, + constraint_name: str, + schema: t.Optional[str] = None, + ): + self.drop_constraints.append( + DropConstraint( + table_class_name=table_class_name, + constraint_name=constraint_name, + tablename=tablename, + schema=schema, + ) + ) + def add_raw(self, raw: t.Union[t.Callable, AsyncFunction]): """ A migration manager can execute arbitrary functions or coroutines when @@ -759,17 +863,28 @@ async def _run_add_tables(self, backwards: bool = False): add_columns: t.List[AddColumnClass] = ( self.add_columns.for_table_class_name(add_table.class_name) ) + class_members: t.Dict[str, t.Any] = {} + for add_column in add_columns: + class_members[add_column.column._meta.name] = add_column.column + _Table: t.Type[Table] = create_table_class( class_name=add_table.class_name, class_kwargs={ "tablename": add_table.tablename, "schema": add_table.schema, }, - class_members={ - add_column.column._meta.name: add_column.column - for add_column in add_columns - }, + class_members=class_members, + ) + + _Table._meta.constraints.extend( + [ + i.constraint + for i in self.add_constraints.for_table_class_name( + add_table.class_name + ) + ] ) + table_classes.append(_Table) # Sort by foreign key, so they're created in the right order. @@ -941,6 +1056,91 @@ async def _run_change_table_schema(self, backwards: bool = False): ) ) + async def _run_add_constraints(self, backwards: bool = False): + if backwards: + for add_constraint in self.add_constraints.add_constraints: + if add_constraint.table_class_name in [ + i.class_name for i in self.add_tables + ]: + # Don't reverse the add constraint as the table is going to + # be deleted. + continue + + _Table = create_table_class( + class_name=add_constraint.table_class_name, + class_kwargs={ + "tablename": add_constraint.tablename, + "schema": add_constraint.schema, + }, + ) + + await self._run_query( + _Table.alter().drop_constraint( + add_constraint.constraint._meta.name + ) + ) + else: + for table_class_name in self.add_constraints.table_class_names: + if table_class_name in [i.class_name for i in self.add_tables]: + continue # No need to add constraints to new tables + + add_constraints: t.List[AddConstraintClass] = ( + self.add_constraints.for_table_class_name(table_class_name) + ) + + _Table = create_table_class( + class_name=add_constraints[0].table_class_name, + class_kwargs={ + "tablename": add_constraints[0].tablename, + "schema": add_constraints[0].schema, + }, + ) + + for add_constraint in add_constraints: + await self._run_query( + _Table.alter().add_constraint( + add_constraint.constraint + ) + ) + + async def _run_drop_constraints(self, backwards: bool = False): + if backwards: + for drop_constraint in self.drop_constraints.drop_constraints: + _Table = await self.get_table_from_snapshot( + table_class_name=drop_constraint.table_class_name, + app_name=self.app_name, + offset=-1, + ) + constraint_to_restore = _Table._meta.get_constraint_by_name( + drop_constraint.constraint_name + ) + await self._run_query( + _Table.alter().add_constraint(constraint_to_restore) + ) + else: + for table_class_name in self.drop_constraints.table_class_names: + constraints = self.drop_constraints.for_table_class_name( + table_class_name + ) + + if not constraints: + continue + + _Table = create_table_class( + class_name=table_class_name, + class_kwargs={ + "tablename": constraints[0].tablename, + "schema": constraints[0].schema, + }, + ) + + for constraint in constraints: + await self._run_query( + _Table.alter().drop_constraint( + constraint_name=constraint.constraint_name + ) + ) + async def run(self, backwards: bool = False): direction = "backwards" if backwards else "forwards" if self.preview: @@ -981,6 +1181,10 @@ async def run(self, backwards: bool = False): # "ALTER COLUMN TYPE is not supported inside a transaction" if engine.engine_type != "cockroach": await self._run_alter_columns(backwards=backwards) + await self._run_add_constraints(backwards=backwards) + await self._run_drop_constraints(backwards=backwards) if engine.engine_type == "cockroach": await self._run_alter_columns(backwards=backwards) + await self._run_add_constraints(backwards=backwards) + await self._run_drop_constraints(backwards=backwards) diff --git a/piccolo/apps/migrations/auto/operations.py b/piccolo/apps/migrations/auto/operations.py index 0676bdbd4..a7d28aa20 100644 --- a/piccolo/apps/migrations/auto/operations.py +++ b/piccolo/apps/migrations/auto/operations.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from piccolo.columns.base import Column +from piccolo.constraints import Constraint @dataclass @@ -63,3 +64,21 @@ class AddColumn: column_class: t.Type[Column] params: t.Dict[str, t.Any] schema: t.Optional[str] = None + + +@dataclass +class AddConstraint: + table_class_name: str + constraint_name: str + constraint_class_name: str + constraint_class: t.Type[Constraint] + params: t.Dict[str, t.Any] + schema: t.Optional[str] = None + + +@dataclass +class DropConstraint: + table_class_name: str + constraint_name: str + tablename: str + schema: t.Optional[str] = None diff --git a/piccolo/apps/migrations/auto/schema_differ.py b/piccolo/apps/migrations/auto/schema_differ.py index 1d095b938..8fe1028d0 100644 --- a/piccolo/apps/migrations/auto/schema_differ.py +++ b/piccolo/apps/migrations/auto/schema_differ.py @@ -613,6 +613,69 @@ def add_columns(self) -> AlterStatements: extra_definitions=extra_definitions, ) + @property + def add_constraints(self) -> AlterStatements: + response: t.List[str] = [] + extra_imports: t.List[Import] = [] + extra_definitions: t.List[Definition] = [] + for table in self.schema: + snapshot_table = self._get_snapshot_table(table.class_name) + if snapshot_table: + delta: TableDelta = table - snapshot_table + else: + continue + + for add_constraint in delta.add_constraints: + constraint_class = add_constraint.constraint_class + extra_imports.append( + Import( + module=constraint_class.__module__, + target=constraint_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{constraint_class.__name__.upper()}", + None, + ), + ) + ) + + schema_str = ( + "None" + if add_constraint.schema is None + else f'"{add_constraint.schema}"' + ) + + response.append( + f"manager.add_constraint(table_class_name='{table.class_name}', tablename='{table.tablename}', constraint_name='{add_constraint.constraint_name}', constraint_class={constraint_class.__name__}, params={add_constraint.params}, schema={schema_str})" # noqa: E501 + ) + return AlterStatements( + statements=response, + extra_imports=extra_imports, + extra_definitions=extra_definitions, + ) + + @property + def drop_constraints(self) -> AlterStatements: + response = [] + for table in self.schema: + snapshot_table = self._get_snapshot_table(table.class_name) + if snapshot_table: + delta: TableDelta = table - snapshot_table + else: + continue + + for constraint in delta.drop_constraints: + schema_str = ( + "None" + if constraint.schema is None + else f'"{constraint.schema}"' + ) + + response.append( + f"manager.drop_constraint(table_class_name='{table.class_name}', tablename='{table.tablename}', constraint_name='{constraint.constraint_name}', schema={schema_str})" # noqa: E501 + ) + return AlterStatements(statements=response) + @property def rename_columns(self) -> AlterStatements: alter_statements = AlterStatements() @@ -679,6 +742,48 @@ def new_table_columns(self) -> AlterStatements: extra_definitions=extra_definitions, ) + @property + def new_table_constraints(self) -> AlterStatements: + new_tables: t.List[DiffableTable] = list( + set(self.schema) - set(self.schema_snapshot) + ) + + response: t.List[str] = [] + extra_imports: t.List[Import] = [] + extra_definitions: t.List[Definition] = [] + for table in new_tables: + if ( + table.class_name + in self.rename_tables_collection.new_class_names + ): + continue + + for constraint in table.constraints: + extra_imports.append( + Import( + module=constraint.__class__.__module__, + target=constraint.__class__.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{constraint.__class__.__name__.upper()}", + None, + ), + ) + ) + + schema_str = ( + "None" if table.schema is None else f'"{table.schema}"' + ) + + response.append( + f"manager.add_constraint(table_class_name='{table.class_name}', tablename='{table.tablename}', constraint_name='{constraint._meta.name}', constraint_class={constraint.__class__.__name__}, params={constraint._meta.params}, schema={schema_str})" # noqa: E501 + ) + return AlterStatements( + statements=response, + extra_imports=extra_imports, + extra_definitions=extra_definitions, + ) + ########################################################################### def get_alter_statements(self) -> t.List[AlterStatements]: @@ -691,10 +796,13 @@ def get_alter_statements(self) -> t.List[AlterStatements]: "Renamed tables": self.rename_tables, "Tables which changed schema": self.change_table_schemas, "Created table columns": self.new_table_columns, + "Created table constraints": self.new_table_constraints, "Dropped columns": self.drop_columns, "Columns added to existing tables": self.add_columns, "Renamed columns": self.rename_columns, "Altered columns": self.alter_columns, + "Dropped constraints": self.drop_constraints, + "Constraints added to existing tables": self.add_constraints, } for message, statements in alter_statements.items(): diff --git a/piccolo/apps/migrations/auto/schema_snapshot.py b/piccolo/apps/migrations/auto/schema_snapshot.py index 45963b717..50d8128d7 100644 --- a/piccolo/apps/migrations/auto/schema_snapshot.py +++ b/piccolo/apps/migrations/auto/schema_snapshot.py @@ -112,4 +112,23 @@ def get_snapshot(self) -> t.List[DiffableTable]: rename_column.new_db_column_name ) + add_constraints = ( + manager.add_constraints.constraints_for_table_class_name( + table.class_name + ) + ) + table.constraints.extend(add_constraints) + + drop_constraints = ( + manager.drop_constraints.for_table_class_name( + table.class_name + ) + ) + for drop_constraint in drop_constraints: + table.constraints = [ + i + for i in table.constraints + if i._meta.name != drop_constraint.constraint_name + ] + return tables diff --git a/piccolo/apps/migrations/commands/new.py b/piccolo/apps/migrations/commands/new.py index 082868435..06ee5462e 100644 --- a/piccolo/apps/migrations/commands/new.py +++ b/piccolo/apps/migrations/commands/new.py @@ -193,6 +193,7 @@ async def get_alter_statements( class_name=i.__name__, tablename=i._meta.tablename, columns=i._meta.non_default_columns, + constraints=i._meta.constraints, schema=i._meta.schema, ) for i in app_config.table_classes diff --git a/piccolo/apps/playground/commands/run.py b/piccolo/apps/playground/commands/run.py index 343c02e38..80c0f7f33 100644 --- a/piccolo/apps/playground/commands/run.py +++ b/piccolo/apps/playground/commands/run.py @@ -24,6 +24,7 @@ Varchar, ) from piccolo.columns.readable import Readable +from piccolo.constraints import Check, Unique from piccolo.engine import PostgresEngine, SQLiteEngine from piccolo.engine.base import Engine from piccolo.table import Table @@ -106,6 +107,9 @@ class TicketType(Enum): price = Numeric(digits=(5, 2)) ticket_type = Varchar(choices=TicketType, default=TicketType.standing) + unique_concert_ticket_type = Unique(columns=[concert, ticket_type]) + check_price = Check((price > 0) & (price < 200)) + @classmethod def get_readable(cls) -> Readable: return Readable( diff --git a/piccolo/columns/combination.py b/piccolo/columns/combination.py index e080cced2..9aab123cb 100644 --- a/piccolo/columns/combination.py +++ b/piccolo/columns/combination.py @@ -52,6 +52,14 @@ def querystring_for_update(self) -> QueryString: self.second.querystring_for_update, ) + @property + def querystring_for_constraint(self) -> QueryString: + return QueryString( + "({} " + self.operator + " {})", + self.first.querystring_for_constraint, + self.second.querystring_for_constraint, + ) + def __str__(self): return self.querystring.__str__() @@ -134,6 +142,10 @@ def __init__(self, sql: str, *args: t.Any) -> None: def querystring_for_update(self) -> QueryString: return self.querystring + @property + def querystring_for_constraint(self) -> QueryString: + return self.querystring + def __str__(self): return self.querystring.__str__() @@ -188,13 +200,7 @@ def clean_value(self, value: t.Any) -> t.Any: """ return convert_to_sql_value(value=value, column=self.column) - @property - def values_querystring(self) -> QueryString: - values = self.values - - if isinstance(values, Undefined): - raise ValueError("values is undefined") - + def get_values_querystring(self, values) -> QueryString: template = ", ".join("{}" for _ in values) return QueryString(template, *values) @@ -205,7 +211,7 @@ def querystring(self) -> QueryString: args.append(self.value) if self.values != UNDEFINED: - args.append(self.values_querystring) + args.append(self.get_values_querystring(self.values)) template = self.operator.template.format( name=self.column.get_where_string( @@ -224,7 +230,7 @@ def querystring_for_update(self) -> QueryString: args.append(self.value) if self.values != UNDEFINED: - args.append(self.values_querystring) + args.append(self.get_values_querystring(self.values)) column = self.column @@ -249,5 +255,47 @@ def querystring_for_update(self) -> QueryString: return QueryString(template, *args) + @property + def querystring_for_constraint(self) -> QueryString: + """ + This is used for check constraints - the main difference is we + don't prefix the column name with the table name. + """ + + from piccolo.columns.base import Column + + def stringify_column(column: Column) -> str: + return f'"{column._meta.db_column_name}"' + + args: t.List[t.Any] = [] + if self.value != UNDEFINED: + args.append( + QueryString(stringify_column(self.value)) + if isinstance(self.value, Column) + else self.value + ) + + if not isinstance(self.values, Undefined): + args.append( + self.get_values_querystring( + values=[ + ( + QueryString(stringify_column(value)) + if isinstance(self.value, Column) + else self.value + ) + for value in self.values + ] + ) + ) + + template = self.operator.template.format( + name=stringify_column(self.column), + value="{}", + values="{}", + ) + + return QueryString(template, *args) + def __str__(self): return self.querystring.__str__() diff --git a/piccolo/constraints.py b/piccolo/constraints.py new file mode 100644 index 000000000..424dbd3ed --- /dev/null +++ b/piccolo/constraints.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import typing as t +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, field + +if t.TYPE_CHECKING: + from piccolo.columns import Column + from piccolo.custom_types import Combinable + + +class ConstraintConfig(metaclass=ABCMeta): + + name: str + + @abstractmethod + def to_constraint(self) -> Constraint: + """ + Override in subclasses. + """ + raise NotImplementedError() + + +class Unique(ConstraintConfig): + """ + Add a unique constraint to one or more columns. For example:: + + from piccolo.constraints import Unique + + class Album(Table): + name = Varchar() + band = ForeignKey(Band) + + unique_name_band = Unique([name, band]) + + + In the above example, the database will enforce that ``name`` and + ``band`` form a unique combination. + + :param columns: + The table columns that should be unique together. + :param nulls_distinct: + See the `Postgres docs `_ + for more information. + + """ # noqa: E501 + + def __init__( + self, + columns: t.List[t.Union[Column, str]], + nulls_distinct: bool = True, + ): + if len(columns) < 1: + raise ValueError("At least 1 column must be specified.") + + self.columns = columns + self.nulls_distinct = nulls_distinct + + def to_constraint(self) -> UniqueConstraint: + """ + You should wait for the ``Table`` metaclass to assign names to all of + the columns before calling this method. + """ + from piccolo.columns import Column + + column_names = [ + ( + column._meta.db_column_name + if isinstance(column, Column) + else column + ) + for column in self.columns + ] + + return UniqueConstraint( + column_names=column_names, + name=self.name, + nulls_distinct=self.nulls_distinct, + ) + + +class Check(ConstraintConfig): + """ + Add a check constraint to the table. For example:: + + from piccolo.constraints import Check + + class Ticket(Table): + price = Decimal() + + check_price_positive = Check(price >= 0) + + You can have more complex conditions. For example:: + + from piccolo.constraints import Check + + class Ticket(Table): + price = Decimal() + + check_price_range = Check( + (price >= 0) & (price < 100) + ) + + :param condition: + The syntax is the same as the ``where`` clause used by most + queries (e.g. ``select``). + + """ + + def __init__( + self, + condition: t.Union[Combinable, str], + ): + self.condition = condition + + def to_constraint(self) -> CheckConstraint: + """ + You should wait for the ``Table`` metaclass to assign names to all of + the columns before calling this method. + """ + from piccolo.columns.combination import CombinableMixin + + if isinstance(self.condition, CombinableMixin): + condition_str = self.condition.querystring_for_constraint.__str__() + else: + condition_str = self.condition + + return CheckConstraint( + condition=condition_str, + name=self.name, + ) + + +############################################################################### + + +class Constraint(metaclass=ABCMeta): + """ + All other constraints inherit from ``Constraint``. Don't use it directly. + """ + + def __init__(self, name: str, **kwargs) -> None: + kwargs.update(name=name) + self._meta = ConstraintMeta(name=name, params=kwargs) + + def __hash__(self): + return hash(self._meta.name) + + @property + @abstractmethod + def ddl(self) -> str: + raise NotImplementedError + + @abstractmethod + def _table_str(self) -> str: + raise NotImplementedError + + +@dataclass +class ConstraintMeta: + """ + This is used to store info about the constraint. + """ + + name: str + + # Used for representing the table in migrations. + params: t.Dict[str, t.Any] = field(default_factory=dict) + + +class UniqueConstraint(Constraint): + """ + Unique constraint on the table columns. + + This is the internal representation that Piccolo uses for constraints - + the user just supplies ``Unique``. + """ + + def __init__( + self, + column_names: t.Sequence[str], + name: str, + nulls_distinct: bool = True, + ) -> None: + """ + :param columns: + The table columns that should be unique together. + :param name: + The name of the constraint in the database. + :param nulls_distinct: + See the `Postgres docs `_ + for more information. + + """ # noqa: E501 + if len(column_names) < 1: + raise ValueError("At least 1 column must be specified.") + + self.column_names = column_names + self.nulls_distinct = nulls_distinct + + super().__init__( + name=name, + column_names=column_names, + nulls_distinct=nulls_distinct, + ) + + @property + def ddl(self) -> str: + nulls_string = ( + "NULLS NOT DISTINCT " if self.nulls_distinct is False else "" + ) + columns_string = ", ".join(f'"{i}"' for i in self.column_names) + return f"UNIQUE {nulls_string}({columns_string})" + + def _table_str(self) -> str: + columns_string = ", ".join( + [f'"{column_name}"' for column_name in self.column_names] + ) + return ( + f"{self._meta.name} = Unique([{columns_string}], " + f"nulls_distinct={self.nulls_distinct})" + ) + + +class CheckConstraint(Constraint): + """ + Check constraint on the table. + + This is the internal representation that Piccolo uses for constraints - + the user just supplies ``Check``. + """ + + def __init__( + self, + condition: str, + name: str, + ) -> None: + """ + :param condition: + The SQL expression used to make sure the data is valid (e.g. + ``"price > 0"``). + :param name: + The name of the constraint in the database. + + """ + self.condition = condition + super().__init__(name=name, condition=condition) + + @property + def ddl(self) -> str: + return f"CHECK ({self.condition})" + + def _table_str(self) -> str: + return f'{self._meta.name} = Check("{self.condition}")' diff --git a/piccolo/query/methods/alter.py b/piccolo/query/methods/alter.py index 040b2f883..fbfdf5755 100644 --- a/piccolo/query/methods/alter.py +++ b/piccolo/query/methods/alter.py @@ -6,6 +6,7 @@ from piccolo.columns.base import Column from piccolo.columns.column_types import ForeignKey, Numeric, Varchar +from piccolo.constraints import Constraint from piccolo.query.base import DDL from piccolo.utils.warnings import Level, colored_warning @@ -177,6 +178,17 @@ def ddl(self) -> str: return f'ALTER COLUMN "{self.column_name}" TYPE VARCHAR({self.length})' +@dataclass +class AddConstraint(AlterStatement): + __slots__ = ("constraint",) + + constraint: Constraint + + @property + def ddl(self) -> str: + return f"ADD CONSTRAINT {self.constraint._meta.name} {self.constraint.ddl}" # noqa: E501 + + @dataclass class DropConstraint(AlterStatement): __slots__ = ("constraint_name",) @@ -275,6 +287,7 @@ class Alter(DDL): __slots__ = ( "_add_foreign_key_constraint", "_add", + "_add_constraint", "_drop_constraint", "_drop_default", "_drop_table", @@ -294,6 +307,7 @@ def __init__(self, table: t.Type[Table], **kwargs): super().__init__(table, **kwargs) self._add_foreign_key_constraint: t.List[AddForeignKeyConstraint] = [] self._add: t.List[AddColumn] = [] + self._add_constraint: t.List[AddConstraint] = [] self._drop_constraint: t.List[DropConstraint] = [] self._drop_default: t.List[DropDefault] = [] self._drop_table: t.Optional[DropTable] = None @@ -490,6 +504,10 @@ def _get_constraint_name(self, column: t.Union[str, ForeignKey]) -> str: tablename = self.table._meta.tablename return f"{tablename}_{column_name}_fk" + def add_constraint(self, constraint: Constraint) -> Alter: + self._add_constraint.append(AddConstraint(constraint=constraint)) + return self + def drop_constraint(self, constraint_name: str) -> Alter: self._drop_constraint.append( DropConstraint(constraint_name=constraint_name) @@ -590,6 +608,8 @@ def default_ddl(self) -> t.Sequence[str]: self._set_default, self._set_digits, self._set_schema, + self._add_constraint, + self._drop_constraint, ) ] diff --git a/piccolo/query/methods/create.py b/piccolo/query/methods/create.py index 68cccf6b2..418c45e5b 100644 --- a/piccolo/query/methods/create.py +++ b/piccolo/query/methods/create.py @@ -3,6 +3,7 @@ import typing as t from piccolo.query.base import DDL +from piccolo.query.methods.alter import AddConstraint from piccolo.query.methods.create_index import CreateIndex if t.TYPE_CHECKING: # pragma: no cover @@ -87,4 +88,9 @@ def default_ddl(self) -> t.Sequence[str]: ).ddl ) + for constraint in self.table._meta.constraints: + ddl.append( + f"ALTER TABLE {self.table._meta.get_formatted_tablename()} {AddConstraint(constraint=constraint).ddl}" # noqa: E501 + ) + return ddl diff --git a/piccolo/querystring.py b/piccolo/querystring.py index 7dec758a8..8bc080dcd 100644 --- a/piccolo/querystring.py +++ b/piccolo/querystring.py @@ -3,7 +3,7 @@ import typing as t from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from datetime import datetime +from datetime import date, datetime from importlib.util import find_spec from string import Formatter @@ -138,6 +138,10 @@ def __str__(self): """ The SQL returned by the ``__str__`` method isn't used directly in queries - it's just a usability feature. + + The only exception to this is CHECK constraints, where we use this to + convert simple querystrings into strings. + """ _, bundled, combined_args = self.bundle( start_index=1, bundled=[], combined_args=[] @@ -153,7 +157,7 @@ def __str__(self): _type = type(arg) if _type == str: converted_args.append(f"'{arg}'") - elif _type == datetime: + elif _type == datetime or _type == date: dt_string = arg.isoformat() converted_args.append(f"'{dt_string}'") elif _type == UUID or _type == apgUUID: diff --git a/piccolo/table.py b/piccolo/table.py index bae9b8a47..f63793de4 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -28,6 +28,7 @@ ) from piccolo.columns.readable import Readable from piccolo.columns.reference import LAZY_COLUMN_REFERENCES +from piccolo.constraints import Constraint, ConstraintConfig from piccolo.custom_types import TableInstance from piccolo.engine import Engine, engine_finder from piccolo.query import ( @@ -84,6 +85,7 @@ class TableMeta: primary_key: Column = field(default_factory=Column) json_columns: t.List[t.Union[JSON, JSONB]] = field(default_factory=list) secret_columns: t.List[Secret] = field(default_factory=list) + constraints: t.List[Constraint] = field(default_factory=list) auto_update_columns: t.List[Column] = field(default_factory=list) tags: t.List[str] = field(default_factory=list) help_text: t.Optional[str] = None @@ -173,6 +175,15 @@ def get_column_by_name(self, name: str) -> Column: return column_object + def get_constraint_by_name(self, name: str) -> Constraint: + """ + Returns a constraint which matches the given name. + """ + for constraint in self.constraints: + if constraint._meta.name == name: + return constraint + raise ValueError(f"No matching constraint found with name == {name}") + def get_auto_update_values(self) -> t.Dict[Column, t.Any]: """ If columns have ``auto_update`` defined, then we retrieve these values. @@ -279,6 +290,7 @@ def __init_subclass__( auto_update_columns: t.List[Column] = [] primary_key: t.Optional[Column] = None m2m_relationships: t.List[M2M] = [] + constraint_configs: t.List[ConstraintConfig] = [] attribute_names = itertools.chain( *[i.__dict__.keys() for i in reversed(cls.__mro__)] @@ -291,11 +303,14 @@ def __init_subclass__( attribute = getattr(cls, attribute_name) if isinstance(attribute, Column): + column = attribute + column._meta._name = attribute_name + # We have to copy, then override the existing column # definition, in case this column is inheritted from a mixin. # Otherwise, when we set attributes on that column, it will # effect all other users of that mixin. - column = attribute.copy() + column = column.copy() setattr(cls, attribute_name, column) if column._meta.primary_key: @@ -304,7 +319,6 @@ def __init_subclass__( non_default_columns.append(column) columns.append(column) - column._meta._name = attribute_name column._meta._table = cls if isinstance(column, Array): @@ -331,6 +345,10 @@ def __init_subclass__( attribute._meta._table = cls m2m_relationships.append(attribute) + if isinstance(attribute, ConstraintConfig): + attribute.name = attribute_name + constraint_configs.append(attribute) + if not primary_key: primary_key = cls._create_serial_primary_key() setattr(cls, "id", primary_key) @@ -368,6 +386,12 @@ def __init_subclass__( foreign_key_column ) + # Now the table and columns are all setup, we can do the constraints. + constraints: t.List[Constraint] = [] + for constraint_config in constraint_configs: + constraints.append(constraint_config.to_constraint()) + cls._meta.constraints = constraints + TABLE_REGISTRY.append(cls) def __init__( @@ -1353,7 +1377,7 @@ def _table_str( if excluded_params is None: excluded_params = [] spacer = "\n " - columns = [] + column_strings: t.List[str] = [] for col in cls._meta.columns: params: t.List[str] = [] for key, value in col._meta.params.items(): @@ -1369,10 +1393,16 @@ def _table_str( if not abbreviated: params.append(f"{key}={_value}") params_string = ", ".join(params) - columns.append( + column_strings.append( f"{col._meta.name} = {col.__class__.__name__}({params_string})" ) - columns_string = spacer.join(columns) + columns_string = spacer.join(column_strings) + + constraint_strings: t.List[str] = [] + for constraint in cls._meta.constraints: + constraint_strings.append(constraint._table_str()) + constraints_string = spacer.join(constraint_strings) + tablename = repr(cls._meta.tablename) parent_class_name = cls.mro()[1].__name__ @@ -1384,7 +1414,9 @@ def _table_str( ) return ( - f"class {cls.__name__}({class_args}):\n" f" {columns_string}\n" + f"class {cls.__name__}({class_args}):\n" + f" {columns_string}\n" + f" {constraints_string}\n" ) diff --git a/tests/apps/migrations/auto/test_migration_manager.py b/tests/apps/migrations/auto/test_migration_manager.py index 6e71846e0..ffbe77b9b 100644 --- a/tests/apps/migrations/auto/test_migration_manager.py +++ b/tests/apps/migrations/auto/test_migration_manager.py @@ -11,6 +11,7 @@ from piccolo.columns.base import OnDelete, OnUpdate from piccolo.columns.column_types import ForeignKey from piccolo.conf.apps import AppConfig +from piccolo.constraints import UniqueConstraint from piccolo.engine import engine_finder from piccolo.table import Table, sort_table_classes from piccolo.utils.lazy_loader import LazyLoader @@ -337,6 +338,270 @@ def test_add_column(self) -> None: if engine_is("cockroach"): self.assertEqual(response, [{"id": row_id, "name": "Dave"}]) + @engines_only("postgres", "cockroach") + def test_add_table_with_unique_constraint(self): + manager = MigrationManager() + manager.add_table(class_name="Musician", tablename="musician") + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager.add_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="unique_name_label", + constraint_class=UniqueConstraint, + params={ + "column_names": ["name", "label"], + }, + ) + asyncio.run(manager.run()) + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + + @engines_only("postgres", "cockroach") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_drop_table_with_unique_constraint( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager_1.add_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="unique_name_label", + constraint_class=UniqueConstraint, + params={ + "column_names": ["name", "label"], + }, + ) + asyncio.run(manager_1.run()) + + # Drop table + manager_2 = MigrationManager() + manager_2.drop_table( + class_name="Musician", + tablename="musician", + ) + asyncio.run(manager_2.run()) + self.assertTrue(not self.table_exists("musician")) + + # Reverse + get_migration_managers.return_value = [manager_1] + app_config = AppConfig(app_name="music", migrations_folder_path="") + get_app_config.return_value = app_config + asyncio.run(manager_2.run(backwards=True)) + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager_1.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + + @engines_only("postgres", "cockroach") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_rename_table_with_unique_constraint( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager_1.add_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="unique_name_label", + constraint_class=UniqueConstraint, + params={ + "column_names": ["name", "label"], + }, + ) + asyncio.run(manager_1.run()) + + # Rename table + manager_2 = MigrationManager() + manager_2.rename_table( + old_class_name="Musician", + old_tablename="musician", + new_class_name="Musician2", + new_tablename="musician2", + ) + asyncio.run(manager_2.run()) + self.assertTrue(not self.table_exists("musician")) + self.run_sync("INSERT INTO musician2 VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician2 VALUES (default, 'a', 'a');") + + # Reverse + get_migration_managers.return_value = [manager_1] + app_config = AppConfig(app_name="music", migrations_folder_path="") + get_app_config.return_value = app_config + asyncio.run(manager_2.run(backwards=True)) + self.assertTrue(not self.table_exists("musician2")) + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager_1.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + self.assertTrue(not self.table_exists("musician2")) + + @engines_only("postgres") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_add_unique_constraint( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + """ + Test adding a unique constraint to a MigrationManager. + Cockroach DB doesn't support dropping unique constraints with ALTER TABLE DROP CONSTRAINT. + https://github.com/cockroachdb/cockroach/issues/42840 + """ # noqa: E501 + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + asyncio.run(manager_1.run()) + + manager_2 = MigrationManager() + manager_2.add_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="musician_unique", + constraint_class=UniqueConstraint, + params={ + "column_names": ["name", "label"], + }, + ) + asyncio.run(manager_2.run()) + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + get_migration_managers.return_value = [manager_1] + app_config = AppConfig(app_name="music", migrations_folder_path="") + get_app_config.return_value = app_config + asyncio.run(manager_2.run(backwards=True)) + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager_1.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + + @engines_only("postgres") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_drop_unique_constraint( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + """ + Test dropping a unique constraint with a MigrationManager. + Cockroach DB doesn't support dropping unique constraints with ALTER TABLE DROP CONSTRAINT. + https://github.com/cockroachdb/cockroach/issues/42840 + """ # noqa: E501 + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager_1.add_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="musician_unique", + constraint_class=UniqueConstraint, + params={ + "column_names": ["name", "label"], + }, + ) + asyncio.run(manager_1.run()) + + manager_2 = MigrationManager() + manager_2.drop_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="musician_unique", + ) + asyncio.run(manager_2.run()) + + # Reverse + get_migration_managers.return_value = [manager_1] + app_config = AppConfig(app_name="music", migrations_folder_path="") + get_app_config.return_value = app_config + asyncio.run(manager_2.run(backwards=True)) + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager_1.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + @engines_only("postgres", "cockroach") def test_add_column_with_index(self): """ diff --git a/tests/apps/migrations/auto/test_schema_differ.py b/tests/apps/migrations/auto/test_schema_differ.py index 9cf6d26f2..376ddc975 100644 --- a/tests/apps/migrations/auto/test_schema_differ.py +++ b/tests/apps/migrations/auto/test_schema_differ.py @@ -13,6 +13,7 @@ SchemaDiffer, ) from piccolo.columns.column_types import Numeric, Varchar +from piccolo.constraints import UniqueConstraint class TestSchemaDiffer(TestCase): @@ -488,6 +489,160 @@ def test_db_column_name(self) -> None: "manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', db_column_name='custom', params={'digits': (4, 2)}, old_params={'digits': (5, 2)}, column_class=Numeric, old_column_class=Numeric, schema=None)", # noqa ) + def test_add_table_with_constraint(self) -> None: + """ + Test adding a new table with a constraint. + """ + name_column = Varchar() + name_column._meta.name = "name" + + genre_column = Varchar() + genre_column._meta.name = "genre" + + name_genre_unique_constraint = UniqueConstraint( + column_names=["name", "genre"], + name="unique_name_genre", + ) + + schema: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[name_column, genre_column], + constraints=[name_genre_unique_constraint], + ) + ] + schema_snapshot: t.List[DiffableTable] = [] + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + create_tables = schema_differ.create_tables + self.assertTrue(len(create_tables.statements) == 1) + self.assertEqual( + create_tables.statements[0], + "manager.add_table(class_name='Band', tablename='band', schema=None, columns=None)", # noqa: E501 + ) + + new_table_columns = schema_differ.new_table_columns + self.assertTrue(len(new_table_columns.statements) == 2) + self.assertEqual( + new_table_columns.statements[0], + "manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False}, schema=None)", # noqa + ) + self.assertEqual( + new_table_columns.statements[1], + "manager.add_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False}, schema=None)", # noqa + ) + + new_table_constraints = schema_differ.new_table_constraints + self.assertTrue(len(new_table_constraints.statements) == 1) + self.assertEqual( + new_table_constraints.statements[0], + "manager.add_constraint(table_class_name='Band', tablename='band', constraint_name='unique_name_genre', constraint_class=UniqueConstraint, params={'column_names': ['name', 'genre']}, schema=None)", # noqa + ) + + def test_add_constraint(self) -> None: + """ + Test adding a constraint to an existing table. + """ + name_column = Varchar() + name_column._meta.name = "name" + + genre_column = Varchar() + genre_column._meta.name = "genre" + + name_unique_constraint = UniqueConstraint( + column_names=["name"], + name="unique_name", + ) + + name_genre_unique_constraint = UniqueConstraint( + column_names=["name", "genre"], + name="unique_name_genre", + ) + + schema: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[name_column, genre_column], + constraints=[ + name_unique_constraint, + name_genre_unique_constraint, + ], + ) + ] + schema_snapshot: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[name_column, genre_column], + constraints=[name_unique_constraint], + ) + ] + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + self.assertTrue(len(schema_differ.add_constraints.statements) == 1) + self.assertEqual( + schema_differ.add_constraints.statements[0], + "manager.add_constraint(table_class_name='Band', tablename='band', constraint_name='unique_name_genre', constraint_class=UniqueConstraint, params={'column_names': ['name', 'genre']}, schema=None)", # noqa: E501 + ) + + def test_drop_constraint(self) -> None: + """ + Test dropping a constraint from an existing table. + """ + name_column = Varchar() + name_column._meta.name = "name" + + genre_column = Varchar() + genre_column._meta.name = "genre" + + name_unique_constraint = UniqueConstraint( + column_names=["name"], + name="unique_name", + ) + + name_genre_unique_constraint = UniqueConstraint( + column_names=["name", "genre"], + name="unique_name_genre", + ) + + schema: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[name_column, genre_column], + constraints=[name_unique_constraint], + ) + ] + schema_snapshot: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[name_column, genre_column], + constraints=[ + name_unique_constraint, + name_genre_unique_constraint, + ], + ) + ] + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + self.assertTrue(len(schema_differ.drop_constraints.statements) == 1) + self.assertEqual( + schema_differ.drop_constraints.statements[0], + "manager.drop_constraint(table_class_name='Band', tablename='band', constraint_name='unique_name_genre', schema=None)", # noqa: E501 + ) + def test_alter_default(self): pass diff --git a/tests/apps/migrations/auto/test_schema_snapshot.py b/tests/apps/migrations/auto/test_schema_snapshot.py index 834551f8a..dbb0a761d 100644 --- a/tests/apps/migrations/auto/test_schema_snapshot.py +++ b/tests/apps/migrations/auto/test_schema_snapshot.py @@ -1,6 +1,7 @@ from unittest import TestCase from piccolo.apps.migrations.auto import MigrationManager, SchemaSnapshot +from piccolo.constraints import UniqueConstraint class TestSchemaSnaphot(TestCase): @@ -187,3 +188,67 @@ def test_get_table_from_snapshot(self): with self.assertRaises(ValueError): schema_snapshot.get_table_from_snapshot("Foo") + + def test_add_constraint(self): + """ + Test adding constraints. + """ + manager = MigrationManager() + manager.add_table(class_name="Manager", tablename="manager") + manager.add_column( + table_class_name="Manager", + tablename="manager", + column_name="name", + column_class_name="Varchar", + params={"length": 100}, + ) + manager.add_constraint( + table_class_name="Manager", + tablename="manager", + constraint_name="unique_name", + constraint_class=UniqueConstraint, + params={ + "column_names": ["name"], + }, + ) + + schema_snapshot = SchemaSnapshot(managers=[manager]) + snapshot = schema_snapshot.get_snapshot() + + self.assertTrue(len(snapshot) == 1) + self.assertTrue(len(snapshot[0].columns) == 1) + self.assertTrue(len(snapshot[0].constraints) == 1) + + def test_drop_constraint(self): + """ + Test dropping constraints. + """ + manager_1 = MigrationManager() + manager_1.add_table(class_name="Manager", tablename="manager") + manager_1.add_column( + table_class_name="Manager", + tablename="manager", + column_name="name", + column_class_name="Varchar", + params={"length": 100}, + ) + manager_1.add_constraint( + table_class_name="Manager", + tablename="manager", + constraint_name="unique_name", + constraint_class=UniqueConstraint, + params={ + "column_names": ["name"], + }, + ) + + manager_2 = MigrationManager() + manager_2.drop_constraint( + table_class_name="Manager", + tablename="manager", + constraint_name="unique_name", + ) + + schema_snapshot = SchemaSnapshot(managers=[manager_1, manager_2]) + snapshot = schema_snapshot.get_snapshot() + self.assertEqual(len(snapshot[0].constraints), 0)