From 02218af3df39452a9c21b5677f530bf29541c527 Mon Sep 17 00:00:00 2001 From: David Sanders <> Date: Fri, 17 Mar 2023 23:41:14 +1100 Subject: [PATCH] Abusing constraints --- README.md | 1 + abusing_constraints/README.md | 386 ++++++++++++++++++ abusing_constraints/__init__.py | 0 abusing_constraints/apps.py | 6 + abusing_constraints/constraints.py | 308 ++++++++++++++ .../migrations/0001_initial.py | 172 ++++++++ abusing_constraints/migrations/__init__.py | 0 abusing_constraints/models.py | 108 +++++ abusing_constraints/tests.py | 66 +++ stupid_django_tricks/settings.py | 1 + 10 files changed, 1048 insertions(+) create mode 100644 abusing_constraints/README.md create mode 100644 abusing_constraints/__init__.py create mode 100644 abusing_constraints/apps.py create mode 100644 abusing_constraints/constraints.py create mode 100644 abusing_constraints/migrations/0001_initial.py create mode 100644 abusing_constraints/migrations/__init__.py create mode 100644 abusing_constraints/models.py create mode 100644 abusing_constraints/tests.py diff --git a/README.md b/README.md index 1ecbbdb..8e64637 100644 --- a/README.md +++ b/README.md @@ -11,3 +11,4 @@ Various tricks with Django - some silly, some quite useful. 4. [Singleton Models](./singleton_models) 5. [Generated Columns](./generated_columns) 6. [ALL Subqueries](./all_subqueries) +7. [Having Fun with Constraints](./abusing_constraints) diff --git a/abusing_constraints/README.md b/abusing_constraints/README.md new file mode 100644 index 0000000..f13815e --- /dev/null +++ b/abusing_constraints/README.md @@ -0,0 +1,386 @@ +Having Fun with Constraints +=========================== + +March 2023 + + +Introduction to Constraints +--------------------------- + +[Django constraints](https://docs.djangoproject.com/en/4.1/ref/models/constraints/) allow you to add different types of table-level database +constraints to your models. + +Django currently supports check & unique constraints. Both of these extend from `BaseConstraint` which provide the +necessary hooks for forwards/reverse migrations, makemigrations as well as validation of constraints from model & form instances. + +A Django constraint follows this basic pattern: + +```python +class CustomConstraint(BaseConstraint): + def __init__(self, *, name, custom_param, violation_error_message=None): + # define your custom params as needed, be sure to pass name & violation_error_message to BaseConstraint + super().__init__(name, violation_error_message) + self.custom_param = custom_param + + def constraint_sql(self, model, schema_editor): + # Used by migrations to define the constraint when included within CREATE TABLE + # Eg: CHECK price > 0 + # Note: model is fake + + def create_sql(self, model, schema_editor): + # Used by migrations to define the constraint when included with ALTER TABLE + # Eg: ALTER TABLE table ADD CONSTRAINT name CHECK price > 0 + # Note: model is fake + + def remove_sql(self, model, schema_editor): + # Used for reverse migrations and done within ALTER TABLE + # Eg: ALTER TABLE table DROP CONSTRAINT name + # Note: model is fake + + def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS): + # Validation done as part of model or form validation ie instance.validate_constraints() + # The model passed here is a real model. Use any database query as needed to verify the constraint. + # Validation failure must raise a ValidationError. + # Eg: Q(...filtering_as_required...).check() -> fail then raise ValidationError + + def __eq__(self): + # You *must* define this in order to prevent makemigrations from recreating your constraint migration operations. + if isinstance(other, CustomConstraint): + return ( + self.name == other.name + and self.custom_param == other.custom_param + and self.violation_error_message == other.violation_error_message + ) + return super().__eq__(other) + + def deconstruct(self): + # You must extend this to include additional params passed to __init__() for serialisation purposes. + path, args, kwargs = super().deconstruct() + kwargs["custom_param"] = self.custom_param + return path, args, kwargs +``` + + + +We can extend the supplied `BaseConstraint` in interesting ways: + + +Custom Constraint: Foreign Keys +------------------------------- + +### Defining composite foreign keys to enforce tenancy equality in multi-tenancy databases + +Django's foreign keys are basic single-column keys with primary key references. Sometimes we may want more complex foreign keys, like using +composite keys in multi-tenanted databases to enforce tenancy across relationships. + +Take the following trivial multi-tenancy design where tenanted models `Foo` and `Bar` are related with a foreign key: + +```python +class Tenant(models.Model): + ... + +class Foo(models.Model): + tenant = models.ForeignKey(Tenant, ...) + +class Bar(models.Model): + tenant = models.ForeignKey(Tenant, ...) + foo = models.ForeignKey(Foo, ...) +``` + +Adding an **extra** composite foreign key from `Bar` to `Foo`, with the tenant as part of the key, **enforces** tenant equality. ie +the foreign key now prevents relationships from existing where Foo and Bar belong to different tenants. + +To do this 2 steps are required: + +1. Create a unique index on the referenced model, Foo, for the key to target that will include the primary key + tenant: `(id, tenant_id)` +2. Add a supplementary foreign key on the referencing model, Bar, to reference the newly created index. + +```python +class Tenant(models.Model): + ... + +class Foo(models.Model): + tenant = models.ForeignKey(Tenant) + + class Meta: + constraints = [ + UniqueConstraint(fields=["id", "tenant"], ...) + ] + +class Bar(models.Model): + tenant = models.ForeignKey(Tenant, ...) + foo = models.ForeignKey(Foo, ...) + + class Meta: + constraints = [ + ForeignKeyConstraint( + fields=["foo_id", "tenant_id"], + to_table="abusing_constraints_foo", + to_fields=["id", "tenant_id"], + ... + ), + ] +``` + +This new foreign key constraint is a simple extension of `BaseConstraint`: + +```python +class ForeignKeyConstraint(BaseConstraint): + def __init__( + self, + *, + name, + fields, + to_table, + to_fields, + violation_error_message=None, + ): + super().__init__(name, violation_error_message=violation_error_message) + self.fields = fields + self.to_table = to_table + self.to_fields = to_fields + + def create_sql(self, model, schema_editor): + table = model._meta.db_table + constraint_sql = self.constraint_sql(model, schema_editor) + return f"ALTER TABLE {table} ADD CONSTRAINT {self.name} {constraint_sql}" + + def remove_sql(self, model, schema_editor): + table = model._meta.db_table + return f"ALTER TABLE {table} DROP CONSTRAINT {self.name}" + + def constraint_sql(self, model, schema_editor): + columns = ", ".join(self.fields) + to_columns = ", ".join(self.to_fields) + return f"FOREIGN KEY ({columns}) REFERENCES {self.to_table} ({to_columns})" + + def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS): + with connection.cursor() as cursor: + # to keep things simple assume each field doesn't have a separate column name + where_clause = " AND ".join( + f"{field} = %({field})s" for field in self.to_fields + ) + params = { + field: getattr(instance, self.fields[i]) + for i, field in enumerate(self.to_fields) + } + table = self.to_table + cursor.execute(f"SELECT count(*) FROM {table} WHERE {where_clause}", params) + result = cursor.fetchone() + if result[0] == 0: + raise ValidationError(self.get_violation_error_message()) + + def __eq__(self, other): + if isinstance(other, ForeignKeyConstraint): + return ( + self.name == other.name + and self.violation_error_message == other.violation_error_message + and self.fields == other.fields + and self.to_table == other.to_table + and self.to_fields == other.to_fields + ) + return super().__eq__(other) + + def deconstruct(self): + path, args, kwargs = super().deconstruct() + kwargs["to_table"] = self.to_table + kwargs["fields"] = self.fields + kwargs["to_fields"] = self.to_fields + return path, args, kwargs +``` + + +Creating Arbitrary Database Artifacts +------------------------------------- + +> Constraints provide the ability to inject any database query – DDL or DML – into your migrations + +Things get interesting when we realise that constraints become a point where we can add custom database operations to our migrations as an alternative +to the manual `RunPython` or `RunSQL` migration operations. + + +### A RawSQL Constraint + +```python +class RawSQL(BaseConstraint): + def __init__(self, *, name, sql, reverse_sql): + super().__init__(name) + self.sql = sql + self.reverse_sql = reverse_sql + + def create_sql(self, model, schema_editor): + return self.sql + + def remove_sql(self, model, schema_editor): + return self.reverse_sql + + # These 2 methods don't apply for non-constraints + + def constraint_sql(self, model, schema_editor): + return None + + def validate(self, *args, **kwargs): + return True + + # ...other methods similarly defined as above +``` + +Here's what we can do with this: + +### Example: Stored Procedures + +Define a stored procedure for your model within the model's meta 😆 + +```python +data_stored_procedure = """\ +CREATE OR REPLACE PROCEDURE data_stored_procedure() +LANGUAGE SQL +AS $$ +INSERT INTO data (data) VALUES (99); +$$ +""" + +drop_data_stored_procedure = """\ +DROP PROCEDURE IF EXISTS data_stored_procedure CASCADE +""" + +class Data(models.Model): + data = models.IntegerField() + + class Meta: + db_table = "data" + constraints = [ + RawSQL( + name="data_stored_procedure", + sql=data_stored_procedure, + reverse_sql=drop_data_stored_procedure, + ), + ] +``` + +### A Callback Constraint + +An example with simple forwarding of the fake model & schema editor, noting that we need to **serialise** +the callbacks in order for it to be injected into your migrations. Obviously this limits what you can do but +it is possible for simple functions: + +```python +class Callback(BaseConstraint): + def __init__(self, *, name, callback, reverse_callback): + super().__init__(name) + self.callback = ( + marshal.dumps(callback.__code__) if callable(callback) else callback + ) + self.reverse_callback = ( + marshal.dumps(reverse_callback.__code__) + if callable(reverse_callback) + else reverse_callback + ) + + def create_sql(self, model, schema_editor): + code = marshal.loads(self.callback) + forwards = types.FunctionType(code, globals(), "forwards") + forwards(model, schema_editor) + + def remove_sql(self, model, schema_editor): + code = marshal.loads(self.reverse_callback) + reverse = types.FunctionType(code, globals(), "reverse") + reverse(model, schema_editor) + + # ...other methods similarly defined as above +``` + +### Example: Initial data + +Using the callback constraint to define initial model data: + +```python +def initial_data(model, schema_editor): + # (here model is fake) + queryset = model._default_manager.using(schema_editor.connection.alias) + queryset.bulk_create( + [ + model(data=1), + model(data=2), + model(data=3), + ] + ) + +def reverse_initial_data(model, schema_editor): + ... + +class Data(models.Model): + data = models.IntegerField() + + class Meta: + constraints = [ + Callback( + name="initial_data", + callback=initial_data, + reverse_callback=reverse_initial_data, + ), + ] +``` + + +Something More Useful: Views +---------------------------- + +Database views are a useful abstraction and a common method for using them in Django is to create the view in a migration, +then create an unmanaged model that refers to the view using `Meta.db_table`. + +To avoid the hassle of manual migrations an extension of constraints like so can be used (noting that constraints will only +be applied to managed models): + +```python +class Document(models.Model): + # This is the main document model + name = models.CharField(max_length=255) + is_archived = models.BooleanField(default=False) + + class Meta: + constraints = [] + +Document._meta.constraints += [ + View( + name="active_documents", + # Remember to forward the primary key (or create one) + query=Document.objects.filter(is_archived=False).values("id", "name"), + ), +] + +class ActiveDocument(models.Model): + # Model representing our database view "active_documents" + name = models.CharField(max_length=255) + + class Meta: + db_table = "active_documents" + managed = False +``` + +**Bonus advantage:** Simple PostgreSQL views like this are **automatically updatable** meaning that any save operations from Django will work! + +Here's what the `View` "constraint" could look like: + +```python +class View(BaseConstraint): + def __init__(self, *, name, query): + super().__init__(name) + if isinstance(query, str): + self.query = query + else: + self.query = str(query.query) + + def create_sql(self, model, schema_editor): + return f"CREATE OR REPLACE VIEW {self.name} AS {self.query}" + + def remove_sql(self, model, schema_editor): + return f"DROP VIEW IF EXISTS {self.name} CASCADE" + + # ...other methods similarly defined as above +``` + + +Further Reading +--------------- +See the [tests & code](.) for complete examples. diff --git a/abusing_constraints/__init__.py b/abusing_constraints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/abusing_constraints/apps.py b/abusing_constraints/apps.py new file mode 100644 index 0000000..1b955e4 --- /dev/null +++ b/abusing_constraints/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class AbusingConstraintsConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "abusing_constraints" diff --git a/abusing_constraints/constraints.py b/abusing_constraints/constraints.py new file mode 100644 index 0000000..fe4820b --- /dev/null +++ b/abusing_constraints/constraints.py @@ -0,0 +1,308 @@ +import marshal +import types + +from django.apps import apps +from django.core.exceptions import ValidationError +from django.db import connection +from django.db.backends.ddl_references import Columns, Statement, Table +from django.db.models import Model, Q +from django.db.models.constraints import BaseConstraint +from django.db.models.query import QuerySet +from django.db.utils import DEFAULT_DB_ALIAS + + +class BasicForeignKeyConstraint(BaseConstraint): + def __init__( + self, + *, + name, + fields, + to_table, + to_fields, + violation_error_message=None, + ): + super().__init__(name, violation_error_message=violation_error_message) + self.fields = fields + self.to_table = to_table + self.to_fields = to_fields + + def create_sql(self, model, schema_editor): + table = model._meta.db_table + constraint_sql = self.constraint_sql(model, schema_editor) + return f"ALTER TABLE {table} ADD CONSTRAINT {self.name} {constraint_sql}" + + def remove_sql(self, model, schema_editor): + table = model._meta.db_table + return f"ALTER TABLE {table} DROP CONSTRAINT {self.name}" + + def constraint_sql(self, model, schema_editor): + columns = ", ".join(self.fields) + to_columns = ", ".join(self.to_fields) + return f"FOREIGN KEY ({columns}) REFERENCES {self.to_table} ({to_columns})" + + def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS): + with connection.cursor() as cursor: + # to keep things simple assume each field doesn't have a separate column name + where_clause = " AND ".join( + f"{field} = %({field})s" for field in self.to_fields + ) + params = { + field: getattr(instance, self.fields[i]) + for i, field in enumerate(self.to_fields) + } + table = self.to_table + cursor.execute(f"SELECT count(*) FROM {table} WHERE {where_clause}", params) + result = cursor.fetchone() + if result[0] == 0: + raise ValidationError(self.get_violation_error_message()) + + def __eq__(self, other): + if isinstance(other, BasicForeignKeyConstraint): + return ( + self.name == other.name + and self.violation_error_message == other.violation_error_message + and self.fields == other.fields + and self.to_table == other.to_table + and self.to_fields == other.to_fields + ) + return super().__eq__(other) + + def deconstruct(self): + path, args, kwargs = super().deconstruct() + kwargs["to_table"] = self.to_table + kwargs["fields"] = self.fields + kwargs["to_fields"] = self.to_fields + return path, args, kwargs + + +class ForeignKeyConstraint(BaseConstraint): + def __init__( + self, + *, + name, + fields, + to_model, + to_fields, + deferrable=None, + violation_error_message=None, + ): + super().__init__(name, violation_error_message=violation_error_message) + self.fields = fields + self.to_model = to_model + self.to_fields = to_fields + self.deferrable = deferrable + + def get_to_model(self, from_model): + if isinstance(self.to_model, str): + if "." in self.to_model: + app_label, to_model_name = self.to_model.split(".") + else: + app_label = from_model._meta.app_label + to_model_name = self.to_model + return apps.get_model(app_label, to_model_name) + else: + return self.to_model + + def create_sql(self, model, schema_editor): + sql = schema_editor.sql_create_fk + table = Table(model._meta.db_table, schema_editor.quote_name) + name = schema_editor.quote_name(self.name) + column_names = [ + model._meta.get_field(field_name).get_attname() + for field_name in self.fields + ] + columns = Columns(table, column_names, schema_editor.quote_name) + to_model = self.get_to_model(model) + to_table = Table(to_model._meta.db_table, schema_editor.quote_name) + to_column_names = [ + to_model._meta.get_field(field_name).get_attname() + for field_name in self.to_fields + ] + to_columns = Columns(to_table, to_column_names, schema_editor.quote_name) + deferrable = schema_editor._deferrable_constraint_sql(self.deferrable) + return Statement( + sql, + table=table, + name=name, + column=columns, + to_table=to_table, + to_column=to_columns, + deferrable=deferrable, + ) + + def remove_sql(self, model, schema_editor): + table = schema_editor.connection.ops.quote_name(model._meta.db_table) + name = schema_editor.quote_name(self.name) + return Statement( + schema_editor.sql_delete_fk, + table=table, + name=name, + ) + + def constraint_sql(self, model, schema_editor): + # It's unclear whether this is ever actually used.... it's called from the CreateModel operation however + # constraints are added in a separate operation AddConstraint. + return None + + def get_value(self, value): + # Deal with allowing either field name referring to the related object or the foreign key value + if isinstance(value, Model): + return value.pk + return value + + def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS): + to_model = self.get_to_model(model) + queryset = to_model._default_manager.using(using) + filters = [ + Q( + **{ + model._meta.get_field(field_name).target_field.name: self.get_value( + getattr(instance, field_name) + ) + } + ) + for field_name in self.fields + ] + if not queryset.filter(*filters).exists(): + raise ValidationError(self.get_violation_error_message()) + + def __eq__(self, other): + if isinstance(other, ForeignKeyConstraint): + return ( + self.name == other.name + and self.violation_error_message == other.violation_error_message + and self.fields == other.fields + and self.to_model == other.to_model + and self.to_fields == other.to_fields + and self.deferrable == other.deferrable + ) + return super().__eq__(other) + + def deconstruct(self): + path, args, kwargs = super().deconstruct() + kwargs["to_model"] = self.to_model + kwargs["fields"] = self.fields + kwargs["to_fields"] = self.to_fields + kwargs["deferrable"] = self.deferrable + return path, args, kwargs + + +class RawSQL(BaseConstraint): + def __init__(self, *, name, sql, reverse_sql): + super().__init__(name) + self.sql = sql + self.reverse_sql = reverse_sql + + def create_sql(self, model, schema_editor): + return self.sql + + def remove_sql(self, model, schema_editor): + return self.reverse_sql + + def constraint_sql(self, model, schema_editor): + return None + + def validate(self, *args, **kwargs): + return True + + def __eq__(self, other): + if isinstance(other, RawSQL): + return ( + self.name == other.name + and self.sql == other.sql + and self.reverse_sql == other.reverse_sql + ) + return super().__eq__(other) + + def deconstruct(self): + path, args, kwargs = super().deconstruct() + kwargs["sql"] = self.sql + kwargs["reverse_sql"] = self.reverse_sql + return path, args, kwargs + + +class View(BaseConstraint): + def __init__(self, *, name, query, is_materialized=False): + super().__init__(name) + if isinstance(query, str): + self.query = query + elif isinstance(query, QuerySet): + self.query = str(query.query) + else: + raise TypeError("string or Query expected") + self.is_materialized = is_materialized + + def create_sql(self, model, schema_editor): + if self.is_materialized: + remove_sql = self.remove_sql(model, schema_editor) + return f"{remove_sql}; CREATE MATERIALIZED VIEW {self.name} AS {self.query}" + + return f"CREATE OR REPLACE VIEW {self.name} AS {self.query}" + + def remove_sql(self, model, schema_editor): + qualifier = "MATERIALIZED" if self.is_materialized else "" + return f"DROP {qualifier} VIEW IF EXISTS {self.name} CASCADE" + + def constraint_sql(self, model, schema_editor): + return None + + def validate(self, *args, **kwargs): + return True + + def __eq__(self, other): + if isinstance(other, View): + return ( + self.name == other.name + and self.query == other.query + and self.is_materialized == other.is_materialized + ) + return super().__eq__(other) + + def deconstruct(self): + path, args, kwargs = super().deconstruct() + kwargs["query"] = self.query + kwargs["is_materialized"] = self.is_materialized + return path, args, kwargs + + +class Callback(BaseConstraint): + def __init__(self, *, name, callback, reverse_callback): + super().__init__(name) + self.callback = ( + marshal.dumps(callback.__code__) if callable(callback) else callback + ) + self.reverse_callback = ( + marshal.dumps(reverse_callback.__code__) + if callable(reverse_callback) + else reverse_callback + ) + + def create_sql(self, model, schema_editor): + code = marshal.loads(self.callback) + forwards = types.FunctionType(code, globals(), "forwards") + forwards(model, schema_editor) + + def remove_sql(self, model, schema_editor): + code = marshal.loads(self.reverse_callback) + reverse = types.FunctionType(code, globals(), "reverse") + reverse(model, schema_editor) + + def constraint_sql(self, model, schema_editor): + return None + + def validate(self, *args, **kwargs): + return True + + def __eq__(self, other): + if isinstance(other, Callback): + return ( + self.name == other.name + and self.callback == other.callback + and self.reverse_callback == other.reverse_callback + ) + + def deconstruct(self): + path, args, kwargs = super().deconstruct() + kwargs["callback"] = self.callback + kwargs["reverse_callback"] = self.reverse_callback + return path, args, kwargs diff --git a/abusing_constraints/migrations/0001_initial.py b/abusing_constraints/migrations/0001_initial.py new file mode 100644 index 0000000..251c869 --- /dev/null +++ b/abusing_constraints/migrations/0001_initial.py @@ -0,0 +1,172 @@ +import django.db.models.deletion +from django.db import migrations, models + +import abusing_constraints.constraints + + +class Migration(migrations.Migration): + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="ActiveDocument", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=255)), + ], + options={ + "db_table": "active_documents", + "managed": False, + }, + ), + migrations.CreateModel( + name="Bar", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ], + ), + migrations.CreateModel( + name="Data", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("data", models.IntegerField()), + ], + options={ + "db_table": "data", + }, + ), + migrations.CreateModel( + name="Document", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=255)), + ("is_archived", models.BooleanField(default=False)), + ], + ), + migrations.CreateModel( + name="Tenant", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ], + ), + migrations.CreateModel( + name="Foo", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "tenant", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="abusing_constraints.tenant", + ), + ), + ], + ), + migrations.AddConstraint( + model_name="document", + constraint=abusing_constraints.constraints.View( + is_materialized=False, + name="active_documents", + query='SELECT "abusing_constraints_document"."id", "abusing_constraints_document"."name" FROM "abusing_constraints_document" WHERE NOT "abusing_constraints_document"."is_archived"', + ), + ), + migrations.AddConstraint( + model_name="data", + constraint=abusing_constraints.constraints.RawSQL( + name="data_stored_procedure", + reverse_sql="DROP PROCEDURE IF EXISTS data_stored_procedure CASCADE\n", + sql="CREATE OR REPLACE PROCEDURE data_stored_procedure()\nLANGUAGE SQL\nAS $$\nINSERT INTO data (data) VALUES (99);\n$$\n", + ), + ), + migrations.AddConstraint( + model_name="data", + constraint=abusing_constraints.constraints.Callback( + callback=b"\xe3\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00\x03\x00\x00\x00\xf3\xe2\x00\x00\x00\x97\x00|\x00j\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa0\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00|\x01j\x02\x00\x00\x00\x00\x00\x00\x00\x00j\x03\x00\x00\x00\x00\x00\x00\x00\x00\xa6\x01\x00\x00\xab\x01\x00\x00\x00\x00\x00\x00\x00\x00}\x02|\x02\xa0\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa6\x00\x00\x00\xab\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00|\x02\xa0\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00|\x00d\x01\xac\x02\xa6\x01\x00\x00\xab\x01\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00|\x00d\x03\xac\x02\xa6\x01\x00\x00\xab\x01\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00|\x00d\x04\xac\x02\xa6\x01\x00\x00\xab\x01\x00\x00\x00\x00\x00\x00\x00\x00g\x03\xa6\x01\x00\x00\xab\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00d\x00S\x00)\x05N\xe9\x01\x00\x00\x00)\x01\xda\x04data\xe9\x02\x00\x00\x00\xe9\x03\x00\x00\x00)\x06\xda\x10_default_manager\xda\x05using\xda\nconnection\xda\x05alias\xda\x06delete\xda\x0bbulk_create)\x03\xda\x05model\xda\rschema_editor\xda\x08querysets\x03\x00\x00\x00 \xfaK/Users/dsanders/projects/stupid_django_tricks/abusing_constraints/models.py\xda\x0cinitial_datar\x10\x00\x00\x003\x00\x00\x00s|\x00\x00\x00\x80\x00\xe0\x0f\x14\xd4\x0f%\xd7\x0f+\xd2\x0f+\xa8M\xd4,D\xd4,J\xd1\x0fK\xd4\x0fK\x80H\xd8\x04\x0c\x87O\x82O\xd1\x04\x15\xd4\x04\x15\xd0\x04\x15\xd8\x04\x0c\xd7\x04\x18\xd2\x04\x18\xe0\x0c\x11\x88E\x90q\x88M\x89M\x8cM\xd8\x0c\x11\x88E\x90q\x88M\x89M\x8cM\xd8\x0c\x11\x88E\x90q\x88M\x89M\x8cM\xf0\x07\x04\t\n\xf1\x03\x06\x05\x06\xf4\x00\x06\x05\x06\xf0\x00\x06\x05\x06\xf0\x00\x06\x05\x06\xf0\x00\x06\x05\x06\xf3\x00\x00\x00\x00", + name="initial_data", + reverse_callback=b"\xe3\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\xf3\x06\x00\x00\x00\x97\x00d\x00S\x00)\x01N\xa9\x00)\x02\xda\x05model\xda\rschema_editors\x02\x00\x00\x00 \xfaK/Users/dsanders/projects/stupid_django_tricks/abusing_constraints/models.py\xda\x14reverse_initial_datar\x06\x00\x00\x00@\x00\x00\x00s\x07\x00\x00\x00\x80\x00\xd8\x04\x07\x80C\xf3\x00\x00\x00\x00", + ), + ), + migrations.AddField( + model_name="bar", + name="foo", + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="abusing_constraints.foo", + ), + ), + migrations.AddField( + model_name="bar", + name="tenant", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="abusing_constraints.tenant", + ), + ), + migrations.AddConstraint( + model_name="foo", + constraint=models.UniqueConstraint( + fields=("id", "tenant"), name="tenant_constraint_target" + ), + ), + migrations.AddConstraint( + model_name="bar", + constraint=abusing_constraints.constraints.ForeignKeyConstraint( + deferrable=None, + fields=["foo", "tenant"], + name="tenant_constraint", + to_fields=["id", "tenant"], + to_model="Foo", + ), + ), + ] diff --git a/abusing_constraints/migrations/__init__.py b/abusing_constraints/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/abusing_constraints/models.py b/abusing_constraints/models.py new file mode 100644 index 0000000..04ea526 --- /dev/null +++ b/abusing_constraints/models.py @@ -0,0 +1,108 @@ +from django.db import models +from django.db.models.constraints import UniqueConstraint + +from abusing_constraints.constraints import Callback, ForeignKeyConstraint, RawSQL, View + + +class Tenant(models.Model): + ... + + +class Foo(models.Model): + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) + + class Meta: + constraints = [ + UniqueConstraint( + name="tenant_constraint_target", + fields=["id", "tenant"], + ) + ] + + +class Bar(models.Model): + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) + foo = models.ForeignKey(Foo, null=True, on_delete=models.CASCADE) + + class Meta: + constraints = [ + ForeignKeyConstraint( + name="tenant_constraint", + fields=["foo", "tenant"], + to_model="Foo", + to_fields=["id", "tenant"], + ), + ] + + +data_stored_procedure = """\ +CREATE OR REPLACE PROCEDURE data_stored_procedure() +LANGUAGE SQL +AS $$ +INSERT INTO data (data) VALUES (99); +$$ +""" + +drop_data_stored_procedure = """\ +DROP PROCEDURE IF EXISTS data_stored_procedure CASCADE +""" + + +def initial_data(model, schema_editor): + # (here model is fake) + queryset = model._default_manager.using(schema_editor.connection.alias) + queryset.delete() + queryset.bulk_create( + [ + model(data=1), + model(data=2), + model(data=3), + ] + ) + + +def reverse_initial_data(model, schema_editor): + ... + + +class Data(models.Model): + data = models.IntegerField() + + class Meta: + db_table = "data" + constraints = [ + RawSQL( + name="data_stored_procedure", + sql=data_stored_procedure, + reverse_sql=drop_data_stored_procedure, + ), + Callback( + name="initial_data", + callback=initial_data, + reverse_callback=reverse_initial_data, + ), + ] + + +class Document(models.Model): + name = models.CharField(max_length=255) + is_archived = models.BooleanField(default=False) + + class Meta: + constraints = [] + + +Document._meta.constraints += [ + View( + name="active_documents", + query=Document.objects.filter(is_archived=False).values("id", "name"), + ), +] + + +class ActiveDocument(models.Model): + name = models.CharField(max_length=255) + + class Meta: + db_table = "active_documents" + managed = False diff --git a/abusing_constraints/tests.py b/abusing_constraints/tests.py new file mode 100644 index 0000000..5c3443c --- /dev/null +++ b/abusing_constraints/tests.py @@ -0,0 +1,66 @@ +import pytest +from django.core.exceptions import ValidationError +from django.db import connection +from django.db.utils import IntegrityError + +from abusing_constraints.models import ActiveDocument, Bar, Data, Document, Foo, Tenant + +pytestmark = pytest.mark.django_db + + +def test_constrained_tenant(): + tenant_1 = Tenant.objects.create() + tenant_2 = Tenant.objects.create() + foo_1 = Foo.objects.create(tenant=tenant_1) + foo_2 = Foo.objects.create(tenant=tenant_2) + + # should be allowed + Bar.objects.create(tenant=tenant_1, foo=foo_1) + + # should NOT be allowed + with pytest.raises(IntegrityError): + Bar.objects.create(tenant=tenant_1, foo=foo_2) + + +def test_fk_validate(): + tenant_1 = Tenant.objects.create() + tenant_2 = Tenant.objects.create() + foo = Foo.objects.create(tenant=tenant_1) + bar = Bar.objects.create(tenant=tenant_1, foo=foo) + + # ok + bar.validate_constraints() + + bar.tenant = tenant_2 + + # now should fail validation + with pytest.raises(ValidationError) as error: + bar.validate_constraints() + assert error.value.messages == ["Constraint “tenant_constraint” is violated."] + + +def test_intial_data_and_store_procedure(): + # Initial data provided through Callback should be 1, 2, 3 + assert list(Data.objects.values_list("data", flat=True)) == [1, 2, 3] + + with connection.cursor() as cursor: + cursor.execute("CALL data_stored_procedure()") + + # The stored procedure should append 99 + assert list(Data.objects.values_list("data", flat=True)) == [1, 2, 3, 99] + + +def test_view(): + Document.objects.create(name="Active Document") + Document.objects.create(name="Archived Document", is_archived=True) + active_document = ActiveDocument.objects.first() + + # Assert our view reflects the source + assert active_document.name == "Active Document" + + # Try updating the view! + active_document.name = "Active Document has been updated!" + active_document.save() + active_document.refresh_from_db() + + assert active_document.name == "Active Document has been updated!" diff --git a/stupid_django_tricks/settings.py b/stupid_django_tricks/settings.py index 71f5e5b..86b4ad6 100644 --- a/stupid_django_tricks/settings.py +++ b/stupid_django_tricks/settings.py @@ -33,6 +33,7 @@ # Application definition INSTALLED_APPS = [ + "abusing_constraints", "all_subqueries", "generated_columns", "singleton_models",