diff --git a/Makefile b/Makefile index 5b3cf91b..e9d2cbd2 100644 --- a/Makefile +++ b/Makefile @@ -36,3 +36,6 @@ release: twine check dist/* twine upload --skip-existing dist/* +test-model: + py.test -s django_multitenant/tests/test_models.py -k 'not concurrency' + diff --git a/django_multitenant/mixins.py b/django_multitenant/mixins.py index f27c2dd3..18832dff 100644 --- a/django_multitenant/mixins.py +++ b/django_multitenant/mixins.py @@ -5,6 +5,11 @@ from django.db.utils import NotSupportedError from django.conf import settings +import django +from django.db.models.fields.related_descriptors import ( + create_forward_many_to_many_manager, +) + from .deletion import related_objects from .exceptions import EmptyTenant @@ -17,12 +22,51 @@ get_tenant_filters, get_object_tenant, set_object_tenant, + get_tenant_column, ) logger = logging.getLogger(__name__) +def wrap_many_related_manager_add(many_related_manager_add): + + """ + Wraps the add method of many to many field to set tenant_id in through_defaults + parameter of the add method. + """ + + def add(self, *objs, through_defaults=None): + if get_current_tenant(): + through_defaults[ + get_tenant_column(self.through) + ] = get_current_tenant_value() + return many_related_manager_add(self, *objs, through_defaults=through_defaults) + + return add + + +def wrap_forward_many_to_many_manager(create_forward_many_to_many_manager_method): + + """ + Wraps the create_forward_many_to_many_manager method of the related_descriptors module + and changes the add method of the ManyRelatedManagerClass to set tenant_id in through_defaults + """ + + def create_forward_many_to_many_manager_wrapper(superclass, rel, reverse): + ManyRelatedManagerClass = create_forward_many_to_many_manager_method( + superclass, rel, reverse + ) + ManyRelatedManagerClass.add = wrap_many_related_manager_add( + ManyRelatedManagerClass.add + ) + return ManyRelatedManagerClass + + # pylint: disable=protected-access + create_forward_many_to_many_manager_wrapper._sign = "add django-multitenant" + return create_forward_many_to_many_manager_wrapper + + class TenantManagerMixin: # Below is the manager related to the above class. # Overrides the get_queryset method of to inject tenant_id filters in the get_queryset. @@ -69,6 +113,11 @@ def __init__(self, *args, **kwargs): # if distributed tables are being related to the model. Collector.delete = wrap_delete(Collector.delete) + if not hasattr(create_forward_many_to_many_manager, "_sign"): + django.db.models.fields.related_descriptors.create_forward_many_to_many_manager = wrap_forward_many_to_many_manager( + create_forward_many_to_many_manager + ) + # Decorates the update_batch method of UpdateQuery to add tenant_id filters. if not hasattr(UpdateQuery.get_compiler, "_sign"): UpdateQuery.update_batch = wrap_update_batch(UpdateQuery.update_batch) diff --git a/django_multitenant/tests/migrations/0024_product_purchase_store_alter_account_id_and_more.py b/django_multitenant/tests/migrations/0024_product_purchase_store_alter_account_id_and_more.py new file mode 100644 index 00000000..c5cc86e0 --- /dev/null +++ b/django_multitenant/tests/migrations/0024_product_purchase_store_alter_account_id_and_more.py @@ -0,0 +1,138 @@ +# Generated by Django 4.0 on 2023-02-20 17:41 + +from django.db import migrations, models +import django.db.models.deletion +import django_multitenant.fields +import django_multitenant.mixins + + +class Migration(migrations.Migration): + + dependencies = [ + ("tests", "0023_auto_20200412_0603"), + ] + + operations = [ + migrations.CreateModel( + name="Product", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=255)), + ("description", models.TextField()), + ], + bases=(django_multitenant.mixins.TenantModelMixin, models.Model), + ), + migrations.CreateModel( + name="Purchase", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ], + bases=(django_multitenant.mixins.TenantModelMixin, models.Model), + ), + migrations.CreateModel( + name="Store", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=50)), + ("address", models.CharField(max_length=255)), + ("email", models.CharField(max_length=50)), + ], + options={ + "abstract": False, + }, + bases=(django_multitenant.mixins.TenantModelMixin, models.Model), + ), + migrations.CreateModel( + name="Transaction", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("date", models.DateField()), + ( + "product", + django_multitenant.fields.TenantForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="tests.product" + ), + ), + ( + "purchase", + django_multitenant.fields.TenantForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="tests.purchase", + ), + ), + ( + "store", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="tests.store" + ), + ), + ], + options={ + "abstract": False, + }, + bases=(django_multitenant.mixins.TenantModelMixin, models.Model), + ), + migrations.AddField( + model_name="purchase", + name="product_purchased", + field=models.ManyToManyField( + through="tests.Transaction", to="tests.Product" + ), + ), + migrations.AddField( + model_name="purchase", + name="store", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="tests.store" + ), + ), + migrations.AddField( + model_name="product", + name="store", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="tests.store" + ), + ), + migrations.AlterUniqueTogether( + name="purchase", + unique_together={("id", "store")}, + ), + migrations.AlterUniqueTogether( + name="product", + unique_together={("id", "store")}, + ), + ] diff --git a/django_multitenant/tests/migrations/0025_many_to_many_distribute.py b/django_multitenant/tests/migrations/0025_many_to_many_distribute.py new file mode 100644 index 00000000..6b673fbd --- /dev/null +++ b/django_multitenant/tests/migrations/0025_many_to_many_distribute.py @@ -0,0 +1,50 @@ +# Generated by Django 4.1 on 2023-02-23 17:24 + +from django.db import migrations +from django_multitenant.db import migrations as tenant_migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("tests", "0024_product_purchase_store_alter_account_id_and_more"), + ] + operations = [] + + operations += [ + # Drop constraints + migrations.RunSQL( + "ALTER TABLE tests_store DROP CONSTRAINT tests_store_pkey CASCADE;" + ), + migrations.RunSQL( + "ALTER TABLE tests_product DROP CONSTRAINT tests_product_pkey CASCADE;" + ), + migrations.RunSQL( + "ALTER TABLE tests_purchase DROP CONSTRAINT tests_purchase_pkey CASCADE;" + ), + migrations.RunSQL( + "ALTER TABLE tests_transaction DROP CONSTRAINT tests_transaction_pkey CASCADE;" + ), + ] + + operations += [ + tenant_migrations.Distribute("Store"), + tenant_migrations.Distribute("Product"), + tenant_migrations.Distribute("Purchase"), + tenant_migrations.Distribute("Transaction"), + ] + + operations += [ + migrations.RunSQL( + "ALTER TABLE tests_store ADD CONSTRAINT tests_store_pkey PRIMARY KEY (id);" + ), + migrations.RunSQL( + "ALTER TABLE tests_product ADD CONSTRAINT tests_product_pkey PRIMARY KEY (store_id, id);" + ), + migrations.RunSQL( + "ALTER TABLE tests_purchase ADD CONSTRAINT tests_purchase_pkey PRIMARY KEY (store_id, id);" + ), + migrations.RunSQL( + "ALTER TABLE tests_transaction ADD CONSTRAINT tests_transaction_pkey PRIMARY KEY (store_id, id);" + ), + ] diff --git a/django_multitenant/tests/models.py b/django_multitenant/tests/models.py index 8ffd74d9..bd2829af 100644 --- a/django_multitenant/tests/models.py +++ b/django_multitenant/tests/models.py @@ -209,3 +209,41 @@ class MigrationTestModel(TenantModel): class MigrationTestReferenceModel(models.Model): name = models.CharField(max_length=255) + + +class Store(TenantModel): + tenant_id = "id" + name = models.CharField(max_length=50) + address = models.CharField(max_length=255) + email = models.CharField(max_length=50) + + +class Product(TenantModel): + store = models.ForeignKey(Store, on_delete=models.CASCADE) + tenant_id = "store_id" + name = models.CharField(max_length=255) + description = models.TextField() + + class Meta: + unique_together = ["id", "store"] + + +class Purchase(TenantModel): + store = models.ForeignKey(Store, on_delete=models.CASCADE) + tenant_id = "store_id" + product_purchased = models.ManyToManyField( + Product, through="Transaction", through_fields=("purchase", "product") + ) + + class Meta: + unique_together = ["id", "store"] + + +class Transaction(TenantModel): + store = models.ForeignKey(Store, on_delete=models.CASCADE) + tenant_id = "store_id" + purchase = TenantForeignKey( + Purchase, on_delete=models.CASCADE, blank=True, null=True + ) + product = TenantForeignKey(Product, on_delete=models.CASCADE) + date = models.DateField() diff --git a/django_multitenant/tests/test_models.py b/django_multitenant/tests/test_models.py index 433a6a12..8dfbf872 100644 --- a/django_multitenant/tests/test_models.py +++ b/django_multitenant/tests/test_models.py @@ -1,3 +1,4 @@ +from datetime import date import re import django @@ -5,6 +6,8 @@ from django.conf import settings from django.db.models import Count from django.db.utils import NotSupportedError, DataError +from .models import Store, Product, Purchase + from django_multitenant.utils import ( set_current_tenant, @@ -812,3 +815,17 @@ def test_aggregate(self): unset_current_tenant() projects_per_manager = ProjectManager.objects.annotate(Count("project_id")) list(projects_per_manager) + + def test_many_to_many_through_saves(self): + + store = Store.objects.create(name="store1") + store.save() + + set_current_tenant(tenant=store) + + product = Product.objects.create(name="product1", store=store) + product.save() + + purchase = Purchase.objects.create(store=store) + purchase.save() + purchase.product_purchased.add(product, through_defaults={"date": date.today()}) diff --git a/django_multitenant/utils.py b/django_multitenant/utils.py index e5de4f81..00531683 100644 --- a/django_multitenant/utils.py +++ b/django_multitenant/utils.py @@ -133,7 +133,6 @@ def set_current_tenant(tenant): get_current_tenant(my_class_object) ``` """ - setattr(_thread_locals, "tenant", tenant) diff --git a/manage.py b/manage.py index aa8f7e9e..0d85e6b7 100755 --- a/manage.py +++ b/manage.py @@ -13,7 +13,7 @@ ENV = ENV.lower() if ENV not in SUPPORTED_ENVS: - raise Exception(f"Unsupported environment: {ENV}") + raise EnvironmentError(f"Unsupported environment: {ENV}") if __name__ == "__main__": os.environ.setdefault("DJANGO_SETTINGS_MODULE", SETTINGS_MODULES[ENV]) diff --git a/pytest.ini b/pytest.ini index e846e4e0..3cd2e805 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,5 @@ [pytest] -addopts = --reuse-db \ No newline at end of file +addopts = --reuse-db +filterwarnings = + ignore::pytest.PytestCacheWarning +