March 2023
Django constraints allow you to add different types of table-level database constraints to your models.
Django currently supports check & unique constraints. Both of these extend BaseConstraint,
which provides the
necessary hooks for forwards/reverse migrations and makemigrations, as well as validation of constraints from model & form instances.
A Django constraint follows this basic pattern:
class CustomConstraint(BaseConstraint):
    """Skeleton showing the hooks a custom Django constraint implements."""

    def __init__(self, *, name, custom_param, violation_error_message=None):
        # Define your custom params as needed; be sure to pass name &
        # violation_error_message to BaseConstraint. They are passed by keyword
        # because newer Django versions deprecate positional arguments here.
        super().__init__(name=name, violation_error_message=violation_error_message)
        self.custom_param = custom_param

    def constraint_sql(self, model, schema_editor):
        # Used by migrations to define the constraint when included within CREATE TABLE
        # Eg: CHECK price > 0
        # Note: model is fake
        ...

    def create_sql(self, model, schema_editor):
        # Used by migrations to define the constraint when included with ALTER TABLE
        # Eg: ALTER TABLE table ADD CONSTRAINT name CHECK price > 0
        # Note: model is fake
        ...

    def remove_sql(self, model, schema_editor):
        # Used for reverse migrations and done within ALTER TABLE
        # Eg: ALTER TABLE table DROP CONSTRAINT name
        # Note: model is fake
        ...

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        # Validation done as part of model or form validation ie instance.validate_constraints()
        # The model passed here is a real model. Use any database query as needed to verify the constraint.
        # Validation failure must raise a ValidationError.
        # Eg: Q(...filtering_as_required...).check() -> fail then raise ValidationError
        ...

    def __eq__(self, other):
        # You *must* define this in order to prevent makemigrations from recreating your constraint migration operations.
        if isinstance(other, CustomConstraint):
            return (
                self.name == other.name
                and self.custom_param == other.custom_param
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        # You must extend this to include additional params passed to __init__() for serialisation purposes.
        path, args, kwargs = super().deconstruct()
        kwargs["custom_param"] = self.custom_param
        return path, args, kwargs
We can extend the supplied BaseConstraint
in interesting ways:
Django's foreign keys are basic single-column keys with primary key references. Sometimes we may want more complex foreign keys, like using composite keys in multi-tenanted databases to enforce tenancy across relationships.
Take the following trivial multi-tenancy design where tenanted models Foo
and Bar
are related with a foreign key:
# Root of the tenancy design; its fields are elided in this example.
class Tenant(models.Model):
    ...


# Tenanted model: each row carries a foreign key to a Tenant.
class Foo(models.Model):
    tenant = models.ForeignKey(Tenant, ...)


# Tenanted model that also references a Foo. With only these two independent
# single-column keys, nothing ties Bar.tenant to the tenant of the Foo it
# points at.
class Bar(models.Model):
    tenant = models.ForeignKey(Tenant, ...)
    foo = models.ForeignKey(Foo, ...)
Adding an extra composite foreign key from Bar
to Foo
, with the tenant as part of the key, enforces tenant equality, i.e.
the foreign key now prevents relationships from existing where Foo and Bar belong to different tenants.
To do this 2 steps are required:
- Create a unique index on the referenced model, Foo, for the key to target that will include the primary key + tenant:
(id, tenant_id)
- Add a supplementary foreign key on the referencing model, Bar, to reference the newly created index.
class Tenant(models.Model):
    ...


class Foo(models.Model):
    tenant = models.ForeignKey(Tenant, ...)

    class Meta:
        constraints = [
            # Step 1: unique index over (id, tenant_id) so a composite foreign
            # key has something to target. UniqueConstraint takes model field
            # names via ``fields``.
            UniqueConstraint(fields=["id", "tenant"], name="foo_id_tenant_uniq"),
        ]


class Bar(models.Model):
    tenant = models.ForeignKey(Tenant, ...)
    foo = models.ForeignKey(Foo, ...)

    class Meta:
        constraints = [
            # Step 2: supplementary composite key referencing the index above,
            # so a Bar can only point at a Foo with the same tenant_id.
            ForeignKeyConstraint(
                name="bar_foo_tenant_fk",
                columns=["foo_id", "tenant_id"],
                to_table="abusing_constraints_foo",
                to_columns=["id", "tenant_id"],
            ),
        ]
This new foreign key constraint is a simple extension of BaseConstraint
(note this is a simple implementation without
any proper quoting of names, etc for demonstration purposes):
class ForeignKeyConstraint(BaseConstraint):
    """Supplementary, optionally composite, foreign key expressed as a constraint.

    This is a demonstration-quality implementation: table, column and
    constraint names are interpolated into the SQL as-is, without quoting.

    Parameters (all keyword-only):
        name: constraint name used in ADD/DROP CONSTRAINT.
        columns: referencing column names on the model's own table.
        to_table: name of the referenced table.
        to_columns: referenced column names, positionally paired with ``columns``.
        on_delete / on_update: optional referential action text (e.g. "CASCADE")
            appended as ``ON DELETE ...`` / ``ON UPDATE ...``.
        violation_error_message: forwarded to BaseConstraint.
    """

    def __init__(
        self,
        *,
        name,
        columns,
        to_table,
        to_columns,
        on_delete=None,
        on_update=None,
        violation_error_message=None,
    ):
        # Keyword arguments: newer Django versions deprecate passing these
        # positionally to BaseConstraint.
        super().__init__(name=name, violation_error_message=violation_error_message)
        self.columns = columns
        self.to_table = to_table
        self.to_columns = to_columns
        self.on_delete = on_delete
        self.on_update = on_update

    def create_sql(self, model, schema_editor):
        """Return the ALTER TABLE ... ADD CONSTRAINT statement for this key."""
        table = model._meta.db_table
        constraint_sql = self.constraint_sql(model, schema_editor)
        return f"ALTER TABLE {table} ADD CONSTRAINT {self.name} {constraint_sql}"

    def remove_sql(self, model, schema_editor):
        """Return the ALTER TABLE ... DROP CONSTRAINT statement for this key."""
        table = model._meta.db_table
        return f"ALTER TABLE {table} DROP CONSTRAINT {self.name}"

    def constraint_sql(self, model, schema_editor):
        """Return the FOREIGN KEY ... REFERENCES ... clause, with any actions."""
        columns = ", ".join(self.columns)
        to_columns = ", ".join(self.to_columns)
        clauses = [f"FOREIGN KEY ({columns}) REFERENCES {self.to_table} ({to_columns})"]
        # Only append the action clauses that were requested, so the result
        # carries no stray whitespace when they are omitted.
        if self.on_delete:
            clauses.append(f"ON DELETE {self.on_delete}")
        if self.on_update:
            clauses.append(f"ON UPDATE {self.on_update}")
        return " ".join(clauses)

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Raise ValidationError when no referenced row matches ``instance``.

        ``exclude`` is accepted for interface compatibility and is not consulted.
        """
        # Local import keeps this block self-contained; ``connections`` lets
        # the check run against the database selected by ``using`` rather than
        # always the default connection.
        from django.db import connections

        # To keep things simple assume each field doesn't have a separate column name.
        values = [getattr(instance, column) for column in self.columns]
        if any(value is None for value in values):
            # A foreign key with a NULL referencing column is not enforced by
            # the database (default MATCH SIMPLE), so there is nothing to
            # verify; comparing "= NULL" would wrongly report a violation.
            return

        where_clause = " AND ".join(
            f"{field} = %({field})s" for field in self.to_columns
        )
        params = dict(zip(self.to_columns, values))
        table = self.to_table
        with connections[using].cursor() as cursor:
            cursor.execute(f"SELECT count(*) FROM {table} WHERE {where_clause}", params)
            result = cursor.fetchone()
        if result[0] == 0:
            raise ValidationError(self.get_violation_error_message())

    def __eq__(self, other):
        # Needed so makemigrations does not keep recreating the constraint.
        if isinstance(other, ForeignKeyConstraint):
            return (
                self.name == other.name
                and self.violation_error_message == other.violation_error_message
                and self.columns == other.columns
                and self.to_table == other.to_table
                and self.to_columns == other.to_columns
                and self.on_delete == other.on_delete
                and self.on_update == other.on_update
            )
        return super().__eq__(other)

    def deconstruct(self):
        """Add this class's own __init__ kwargs to the serialised form."""
        path, args, kwargs = super().deconstruct()
        kwargs["to_table"] = self.to_table
        kwargs["columns"] = self.columns
        kwargs["to_columns"] = self.to_columns
        kwargs["on_delete"] = self.on_delete
        kwargs["on_update"] = self.on_update
        return path, args, kwargs
Notice in the ForeignKeyConstraint
class above I made provision for cascading updates & deletes - if you wish to have
database-level cascades implemented you can use this supplementary foreign key to do that:
- Specify
DO_NOTHING
for the main ForeignKey
- Add the supplementary foreign key specifying database-level cascades
# Referenced model; its fields are elided in this example.
class Parent(models.Model):
    ...


class Child(models.Model):
    # The main ForeignKey is declared with DO_NOTHING; the delete behaviour is
    # supplied by the supplementary constraint below instead.
    parent = models.ForeignKey(Parent, on_delete=models.DO_NOTHING)

    class Meta:
        constraints = [
            # Supplementary key over the same column; on_delete="CASCADE" is
            # emitted by ForeignKeyConstraint as "ON DELETE CASCADE".
            ForeignKeyConstraint(
                name="native_on_delete",
                columns=["parent_id"],
                to_table="abusing_constraints_parent",
                to_columns=["id"],
                on_delete="CASCADE",
            ),
        ]
Constraints provide the ability to inject any database query – DDL or DML – into your migrations
Things get interesting when we realise that constraints become a point where we can add custom database operations to our migrations as an alternative
to the manual RunPython
or RunSQL
migration operations.
class RawSQL(BaseConstraint):
    """Pseudo-constraint that hands caller-supplied SQL to migrations.

    ``sql`` is returned from ``create_sql`` and ``reverse_sql`` from
    ``remove_sql``; neither string is inspected or modified here.
    """

    def __init__(self, *, name, sql, reverse_sql):
        # Pass name by keyword: newer Django versions deprecate positional
        # arguments to BaseConstraint.
        super().__init__(name=name)
        self.sql = sql
        self.reverse_sql = reverse_sql

    def create_sql(self, model, schema_editor):
        """Return the forward SQL unchanged."""
        return self.sql

    def remove_sql(self, model, schema_editor):
        """Return the reverse SQL unchanged."""
        return self.reverse_sql

    # These 2 methods don't apply for non-constraints
    def constraint_sql(self, model, schema_editor):
        return None

    def validate(self, *args, **kwargs):
        return True

    # ...other methods similarly defined as above
Here's what we can do with this:
Define a stored procedure for your model within the model's meta 😆
# Forward SQL: (re)creates a procedure that inserts one row into the "data"
# table. The backslash after the opening quotes suppresses the leading newline.
# NOTE(review): the dollar-quoted body looks like PostgreSQL syntax — confirm
# the target database.
data_stored_procedure = """\
CREATE OR REPLACE PROCEDURE data_stored_procedure()
LANGUAGE SQL
AS $$
INSERT INTO data (data) VALUES (99);
$$
"""
# Reverse SQL: drops the procedure if it exists.
drop_data_stored_procedure = """\
DROP PROCEDURE IF EXISTS data_stored_procedure CASCADE
"""
class Data(models.Model):
    data = models.IntegerField()

    class Meta:
        # Fixed table name: the stored procedure's SQL inserts into "data".
        db_table = "data"
        constraints = [
            # Not a real constraint: carries the procedure's create/drop SQL.
            RawSQL(
                name="data_stored_procedure",
                sql=data_stored_procedure,
                reverse_sql=drop_data_stored_procedure,
            ),
        ]
An example with simple forwarding of the fake model & schema editor, noting that we need to serialise the callbacks in order for them to be injected into your migrations. Obviously this limits what you can do, but it is possible for simple functions:
class Callback(BaseConstraint):
    """Pseudo-constraint that runs Python callbacks instead of returning SQL.

    Each callback is stored as the marshalled bytes of its code object; a
    value that is not callable is stored as given (assumed to be bytes that
    were marshalled earlier).
    """

    def __init__(self, *, name, callback, reverse_callback):
        # Pass name by keyword: newer Django versions deprecate positional
        # arguments to BaseConstraint.
        super().__init__(name=name)
        self.callback = self._dump(callback)
        self.reverse_callback = self._dump(reverse_callback)

    @staticmethod
    def _dump(func):
        """Marshal a callable's code object; pass any other value through."""
        # NOTE: only the code object is kept, so defaults and closures of the
        # original function are lost. The marshal format is also specific to
        # the Python version that produced it.
        return marshal.dumps(func.__code__) if callable(func) else func

    @staticmethod
    def _load(payload, func_name):
        """Rebuild a function named ``func_name`` from marshalled code bytes."""
        # The rebuilt function resolves global names in this module's globals.
        return types.FunctionType(marshal.loads(payload), globals(), func_name)

    def create_sql(self, model, schema_editor):
        # Runs the forward callback for its side effects; returns None.
        forwards = self._load(self.callback, "forwards")
        forwards(model, schema_editor)

    def remove_sql(self, model, schema_editor):
        # Runs the reverse callback for its side effects; returns None.
        reverse = self._load(self.reverse_callback, "reverse")
        reverse(model, schema_editor)

    # ...other methods similarly defined as above
Using the callback constraint to define initial model data:
def initial_data(model, schema_editor):
    """Bulk-create three rows with ``data`` set to 1, 2 and 3.

    The rows are written through the model's default manager, bound to the
    database alias of ``schema_editor.connection``.
    """
    # (here model is fake)
    manager = model._default_manager.using(schema_editor.connection.alias)
    manager.bulk_create([model(data=value) for value in (1, 2, 3)])
def reverse_initial_data(model, schema_editor):
    """Reverse counterpart of the initial-data callback.

    The body is elided in this example; as written it does nothing.
    """
    ...
class Data(models.Model):
    data = models.IntegerField()

    class Meta:
        constraints = [
            # Not a real constraint: Callback runs initial_data from its
            # create_sql hook and reverse_initial_data from remove_sql.
            Callback(
                name="initial_data",
                callback=initial_data,
                reverse_callback=reverse_initial_data,
            ),
        ]
Database views are a useful abstraction and a common method for using them in Django is to create the view in a migration,
then create an unmanaged model that refers to the view using Meta.db_table.
To avoid the hassle of manual migrations an extension of constraints like so can be used (noting that constraints will only be applied to managed models):
class Document(models.Model):
    # This is the main document model
    name = models.CharField(max_length=255)
    is_archived = models.BooleanField(default=False)

    class Meta:
        # NOTE(review): presumably declared (even though empty) so Django
        # records "constraints" as a Meta option and the entries appended
        # below are seen by migrations — confirm.
        constraints = []


# The view is registered after the class statement because its query is built
# from Document.objects, which only exists once Document has been created.
Document._meta.constraints += [
    View(
        name="active_documents",
        # Remember to forward the primary key (or create one)
        query=Document.objects.filter(is_archived=False).values("id", "name"),
    ),
]
class ActiveDocument(models.Model):
    # Model representing our database view "active_documents"
    name = models.CharField(max_length=255)

    class Meta:
        # Same name as the View(name="active_documents") created via Document.
        db_table = "active_documents"
        # Unmanaged: this model only reads from/writes to the view.
        managed = False
Bonus advantage: Simple PostgreSQL views like this are automatically updatable meaning that any save operations from Django will work!
Here's what the View
"constraint" could look like:
class View(BaseConstraint):
    """Pseudo-constraint that creates and drops a database view.

    ``name`` is used as the view name. ``query`` is either a SQL string, or an
    object whose ``.query`` attribute renders to SQL via ``str()`` (e.g. a
    queryset).
    """

    def __init__(self, *, name, query):
        # Pass name by keyword: newer Django versions deprecate positional
        # arguments to BaseConstraint.
        super().__init__(name=name)
        # Better to parameterise correctly as __str__() just fills in placeholders
        # without any correct quoting for strings, etc.
        self.query = query if isinstance(query, str) else str(query.query)

    def create_sql(self, model, schema_editor):
        """Return the CREATE OR REPLACE VIEW statement for the stored query."""
        return f"CREATE OR REPLACE VIEW {self.name} AS {self.query}"

    def remove_sql(self, model, schema_editor):
        """Return the DROP VIEW statement (IF EXISTS, CASCADE)."""
        return f"DROP VIEW IF EXISTS {self.name} CASCADE"

    # ...other methods similarly defined as above
See the tests & code for complete examples.