Skip to content

Commit

Permalink
Adds many to many model tenant save (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
gurkanindibay authored Feb 28, 2023
1 parent bbb4810 commit 9ac16f3
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 3 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ release:
twine check dist/*
twine upload --skip-existing dist/*

test-model:
py.test -s django_multitenant/tests/test_models.py -k 'not concurrency'

49 changes: 49 additions & 0 deletions django_multitenant/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from django.db.utils import NotSupportedError
from django.conf import settings

import django
from django.db.models.fields.related_descriptors import (
create_forward_many_to_many_manager,
)


from .deletion import related_objects
from .exceptions import EmptyTenant
Expand All @@ -17,12 +22,51 @@
get_tenant_filters,
get_object_tenant,
set_object_tenant,
get_tenant_column,
)


logger = logging.getLogger(__name__)


def wrap_many_related_manager_add(many_related_manager_add):

"""
Wraps the add method of many to many field to set tenant_id in through_defaults
parameter of the add method.
"""

def add(self, *objs, through_defaults=None):
if get_current_tenant():
through_defaults[
get_tenant_column(self.through)
] = get_current_tenant_value()
return many_related_manager_add(self, *objs, through_defaults=through_defaults)

return add


def wrap_forward_many_to_many_manager(create_forward_many_to_many_manager_method):

"""
Wraps the create_forward_many_to_many_manager method of the related_descriptors module
and changes the add method of the ManyRelatedManagerClass to set tenant_id in through_defaults
"""

def create_forward_many_to_many_manager_wrapper(superclass, rel, reverse):
ManyRelatedManagerClass = create_forward_many_to_many_manager_method(
superclass, rel, reverse
)
ManyRelatedManagerClass.add = wrap_many_related_manager_add(
ManyRelatedManagerClass.add
)
return ManyRelatedManagerClass

# pylint: disable=protected-access
create_forward_many_to_many_manager_wrapper._sign = "add django-multitenant"
return create_forward_many_to_many_manager_wrapper


class TenantManagerMixin:
# Below is the manager related to the above class.
# Overrides the get_queryset method of to inject tenant_id filters in the get_queryset.
Expand Down Expand Up @@ -69,6 +113,11 @@ def __init__(self, *args, **kwargs):
# if distributed tables are being related to the model.
Collector.delete = wrap_delete(Collector.delete)

if not hasattr(create_forward_many_to_many_manager, "_sign"):
django.db.models.fields.related_descriptors.create_forward_many_to_many_manager = wrap_forward_many_to_many_manager(
create_forward_many_to_many_manager
)

# Decorates the update_batch method of UpdateQuery to add tenant_id filters.
if not hasattr(UpdateQuery.get_compiler, "_sign"):
UpdateQuery.update_batch = wrap_update_batch(UpdateQuery.update_batch)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Generated by Django 4.0 on 2023-02-20 17:41

from django.db import migrations, models
import django.db.models.deletion
import django_multitenant.fields
import django_multitenant.mixins


class Migration(migrations.Migration):

dependencies = [
("tests", "0023_auto_20200412_0603"),
]

operations = [
migrations.CreateModel(
name="Product",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=255)),
("description", models.TextField()),
],
bases=(django_multitenant.mixins.TenantModelMixin, models.Model),
),
migrations.CreateModel(
name="Purchase",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
],
bases=(django_multitenant.mixins.TenantModelMixin, models.Model),
),
migrations.CreateModel(
name="Store",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=50)),
("address", models.CharField(max_length=255)),
("email", models.CharField(max_length=50)),
],
options={
"abstract": False,
},
bases=(django_multitenant.mixins.TenantModelMixin, models.Model),
),
migrations.CreateModel(
name="Transaction",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("date", models.DateField()),
(
"product",
django_multitenant.fields.TenantForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="tests.product"
),
),
(
"purchase",
django_multitenant.fields.TenantForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="tests.purchase",
),
),
(
"store",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="tests.store"
),
),
],
options={
"abstract": False,
},
bases=(django_multitenant.mixins.TenantModelMixin, models.Model),
),
migrations.AddField(
model_name="purchase",
name="product_purchased",
field=models.ManyToManyField(
through="tests.Transaction", to="tests.Product"
),
),
migrations.AddField(
model_name="purchase",
name="store",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="tests.store"
),
),
migrations.AddField(
model_name="product",
name="store",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="tests.store"
),
),
migrations.AlterUniqueTogether(
name="purchase",
unique_together={("id", "store")},
),
migrations.AlterUniqueTogether(
name="product",
unique_together={("id", "store")},
),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Generated by Django 4.1 on 2023-02-23 17:24

from django.db import migrations
from django_multitenant.db import migrations as tenant_migrations


class Migration(migrations.Migration):

dependencies = [
("tests", "0024_product_purchase_store_alter_account_id_and_more"),
]
operations = []

operations += [
# Drop constraints
migrations.RunSQL(
"ALTER TABLE tests_store DROP CONSTRAINT tests_store_pkey CASCADE;"
),
migrations.RunSQL(
"ALTER TABLE tests_product DROP CONSTRAINT tests_product_pkey CASCADE;"
),
migrations.RunSQL(
"ALTER TABLE tests_purchase DROP CONSTRAINT tests_purchase_pkey CASCADE;"
),
migrations.RunSQL(
"ALTER TABLE tests_transaction DROP CONSTRAINT tests_transaction_pkey CASCADE;"
),
]

operations += [
tenant_migrations.Distribute("Store"),
tenant_migrations.Distribute("Product"),
tenant_migrations.Distribute("Purchase"),
tenant_migrations.Distribute("Transaction"),
]

operations += [
migrations.RunSQL(
"ALTER TABLE tests_store ADD CONSTRAINT tests_store_pkey PRIMARY KEY (id);"
),
migrations.RunSQL(
"ALTER TABLE tests_product ADD CONSTRAINT tests_product_pkey PRIMARY KEY (store_id, id);"
),
migrations.RunSQL(
"ALTER TABLE tests_purchase ADD CONSTRAINT tests_purchase_pkey PRIMARY KEY (store_id, id);"
),
migrations.RunSQL(
"ALTER TABLE tests_transaction ADD CONSTRAINT tests_transaction_pkey PRIMARY KEY (store_id, id);"
),
]
38 changes: 38 additions & 0 deletions django_multitenant/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,41 @@ class MigrationTestModel(TenantModel):

class MigrationTestReferenceModel(models.Model):
name = models.CharField(max_length=255)


class Store(TenantModel):
tenant_id = "id"
name = models.CharField(max_length=50)
address = models.CharField(max_length=255)
email = models.CharField(max_length=50)


class Product(TenantModel):
store = models.ForeignKey(Store, on_delete=models.CASCADE)
tenant_id = "store_id"
name = models.CharField(max_length=255)
description = models.TextField()

class Meta:
unique_together = ["id", "store"]


class Purchase(TenantModel):
store = models.ForeignKey(Store, on_delete=models.CASCADE)
tenant_id = "store_id"
product_purchased = models.ManyToManyField(
Product, through="Transaction", through_fields=("purchase", "product")
)

class Meta:
unique_together = ["id", "store"]


class Transaction(TenantModel):
store = models.ForeignKey(Store, on_delete=models.CASCADE)
tenant_id = "store_id"
purchase = TenantForeignKey(
Purchase, on_delete=models.CASCADE, blank=True, null=True
)
product = TenantForeignKey(Product, on_delete=models.CASCADE)
date = models.DateField()
17 changes: 17 additions & 0 deletions django_multitenant/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from datetime import date
import re

import django
import pytest
from django.conf import settings
from django.db.models import Count
from django.db.utils import NotSupportedError, DataError
from .models import Store, Product, Purchase


from django_multitenant.utils import (
set_current_tenant,
Expand Down Expand Up @@ -812,3 +815,17 @@ def test_aggregate(self):
unset_current_tenant()
projects_per_manager = ProjectManager.objects.annotate(Count("project_id"))
list(projects_per_manager)

def test_many_to_many_through_saves(self):

store = Store.objects.create(name="store1")
store.save()

set_current_tenant(tenant=store)

product = Product.objects.create(name="product1", store=store)
product.save()

purchase = Purchase.objects.create(store=store)
purchase.save()
purchase.product_purchased.add(product, through_defaults={"date": date.today()})
1 change: 0 additions & 1 deletion django_multitenant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def set_current_tenant(tenant):
get_current_tenant(my_class_object)
```
"""

setattr(_thread_locals, "tenant", tenant)


Expand Down
2 changes: 1 addition & 1 deletion manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ENV = ENV.lower()

if ENV not in SUPPORTED_ENVS:
raise Exception(f"Unsupported environment: {ENV}")
raise EnvironmentError(f"Unsupported environment: {ENV}")

if __name__ == "__main__":
os.environ.setdefault("DJANGO_SETTINGS_MODULE", SETTINGS_MODULES[ENV])
Expand Down
5 changes: 4 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
[pytest]
addopts = --reuse-db
addopts = --reuse-db
filterwarnings =
ignore::pytest.PytestCacheWarning

0 comments on commit 9ac16f3

Please sign in to comment.