Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use asgiref when available instead of thread locals (#176) #198

Merged
merged 5 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Use asgiref when available instead of thread locals (#176)
add asgiref as test dependency

add tests for asgiref

add setting TENANT_USE_ASGIREF, updated tests
  • Loading branch information
darwing1210 committed Jul 12, 2023
commit 3465bfc0cb4b5b519f4c047c242573b5a39a27a8
1 change: 1 addition & 0 deletions django_multitenant/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
TENANT_MODEL_NAME = getattr(settings, "TENANT_MODEL_NAME", None)
CITUS_EXTENSION_INSTALLED = getattr(settings, "CITUS_EXTENSION_INSTALLED", False)
TENANT_STRICT_MODE = getattr(settings, "TENANT_STRICT_MODE", False)
TENANT_USE_ASGIREF = getattr(settings, "TENANT_USE_ASGIREF", False)
3 changes: 3 additions & 0 deletions django_multitenant/tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,6 @@

DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
USE_TZ = True

TENANT_USE_ASGIREF = False

55 changes: 55 additions & 0 deletions django_multitenant/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import asyncio
import sys, importlib
from asgiref.sync import async_to_sync, sync_to_async


from django_multitenant.utils import (
set_current_tenant,
get_current_tenant,
Expand All @@ -11,6 +16,12 @@


class UtilsTest(BaseTestCase):
async def async_get_current_tenant(self):
return get_current_tenant()

async def async_set_current_tenant(self, tenant):
return set_current_tenant(tenant)

def test_set_current_tenant(self):
projects = self.projects
account = projects[0].account
Expand All @@ -19,6 +30,50 @@ def test_set_current_tenant(self):
self.assertEqual(get_current_tenant(), account)
unset_current_tenant()

def test_tenant_persists_from_thread_to_async_task(self):
projects = self.projects
account = projects[0].account

# Set the tenant in main thread
set_current_tenant(account)

with self.settings(TENANT_USE_ASGIREF=True):
importlib.reload(sys.modules['django_multitenant.utils'])
from django_multitenant.utils import get_current_tenant
# Check the tenant within an async task when asgiref enabled
tenant = async_to_sync(self.async_get_current_tenant)()
self.assertEqual(get_current_tenant(), tenant)
unset_current_tenant()

with self.settings(TENANT_USE_ASGIREF=False):
importlib.reload(sys.modules['django_multitenant.utils'])
from django_multitenant.utils import get_current_tenant
# Check the tenant within an async task when asgiref is disabled
tenant = async_to_sync(self.async_get_current_tenant)()
self.assertIsNone(get_current_tenant())
unset_current_tenant()

def test_tenant_persists_from_async_task_to_thread(self):
projects = self.projects
account = projects[0].account

with self.settings(TENANT_USE_ASGIREF=True):
importlib.reload(sys.modules['django_multitenant.utils'])
from django_multitenant.utils import get_current_tenant
# Set the tenant in task
async_to_sync(self.async_set_current_tenant)(account)
self.assertEqual(get_current_tenant(), account)
unset_current_tenant()

with self.settings(TENANT_USE_ASGIREF=False):
importlib.reload(sys.modules['django_multitenant.utils'])
from django_multitenant.utils import get_current_tenant
# Set the tenant in task
async_to_sync(self.async_set_current_tenant)(account)
self.assertIsNone(get_current_tenant())
unset_current_tenant()


def test_get_tenant_column(self):
from .models import Project

Expand Down
26 changes: 16 additions & 10 deletions django_multitenant/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import inspect

from django.apps import apps
from django.conf import settings

try:
from threading import local
except ImportError:
from django.utils._threading_local import local

if settings.TENANT_USE_ASGIREF:
# asgiref must be installed, its included with Django >= 3.0
from asgiref.local import Local as local
else:
try:
from threading import local
except ImportError:
from django.utils._threading_local import local


_thread_locals = local()
_thread_locals = _context = local()


def get_model_by_db_table(db_table):
Expand All @@ -26,14 +32,14 @@ def get_model_by_db_table(db_table):

def get_current_tenant():
"""
Utils to get the tenant that hass been set in the current thread using `set_current_tenant`.
Utils to get the tenant that hass been set in the current thread/context using `set_current_tenant`.
Can be used by doing:
```
my_class_object = get_current_tenant()
```
Will return None if the tenant is not set
"""
return getattr(_thread_locals, "tenant", None)
return getattr(_context, "tenant", None)


def get_tenant_column(model_class_or_instance):
Expand Down Expand Up @@ -125,19 +131,19 @@ def get_tenant_filters(table, filters=None):

def set_current_tenant(tenant):
"""
Utils to set a tenant in the current thread.
Utils to set a tenant in the current thread/context.
Often used in a middleware once a user is logged in to make sure all db
calls are sharded to the current tenant.
Can be used by doing:
```
get_current_tenant(my_class_object)
```
"""
setattr(_thread_locals, "tenant", tenant)
setattr(_context, "tenant", tenant)


def unset_current_tenant():
setattr(_thread_locals, "tenant", None)
setattr(_context, "tenant", None)


def is_distributed_model(model):
Expand Down
10 changes: 3 additions & 7 deletions requirements/test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# This file is autogenerated by pip-compile with Python 3.11
# by the following command:
#
# pip-compile --output-file=requirements/test-requirements.txt --resolver=backtracking requirements/test.in
#
asgiref==3.7.2
# via -r requirements/test.in
coverage[toml]==7.2.7
# via pytest-cov
exam==0.10.6
# via -r requirements/test.in
exceptiongroup==1.1.2
# via pytest
iniconfig==2.0.0
# via pytest
mock==5.0.2
Expand All @@ -29,7 +29,3 @@ pytest-cov==4.1.0
# via -r requirements/test.in
pytest-django==4.5.2
# via -r requirements/test.in
tomli==2.0.1
# via
# coverage
# pytest
1 change: 1 addition & 0 deletions requirements/test.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pytest
pytest-cov
pytest-django
exam
asgiref>= 3.5.2