From 55da032174fbbf59d68f752cbf2314292f88fb64 Mon Sep 17 00:00:00 2001 From: Shai Berger Date: Wed, 16 Jun 2021 17:45:57 +0300 Subject: [PATCH] Detect cases of n+1 caused by Django deferred fields --- nplusone/__init__.py | 2 +- nplusone/ext/django/patch.py | 35 +++++++++++++++++++++++++++++++++ tests/testapp/testapp/models.py | 4 ++++ tests/testapp/testapp/tests.py | 29 +++++++++++++++++++++++++-- tests/testapp/testapp/urls.py | 1 + tests/testapp/testapp/views.py | 5 +++++ 6 files changed, 73 insertions(+), 3 deletions(-) diff --git a/nplusone/__init__.py b/nplusone/__init__.py index 1f356cc..46a2620 100644 --- a/nplusone/__init__.py +++ b/nplusone/__init__.py @@ -1 +1 @@ -__version__ = '1.0.0' +__version__ = '1.1.0a1' diff --git a/nplusone/ext/django/patch.py b/nplusone/ext/django/patch.py index cd7db13..49ec782 100644 --- a/nplusone/ext/django/patch.py +++ b/nplusone/ext/django/patch.py @@ -17,6 +17,7 @@ create_reverse_many_to_one_manager, create_forward_many_to_many_manager, ) +from django.db.models.query_utils import DeferredAttribute def get_worker(): @@ -345,3 +346,37 @@ def getitem_queryset(self, index): ) return original_getitem_queryset(self, index) query.QuerySet.__getitem__ = getitem_queryset + + +def parse_refresh_from_db(instance, args, kwargs, context): + # Instance passed via partial + fields = kwargs.get('fields') or args[0] + model = type(instance) + return model, to_key(instance), fields[0] + + +original_deferred_attribute_get = DeferredAttribute.__get__ +def deferred_attribute_get(self, instance, cls=None): + """ + DeferredAttribute.__get__() is called when a deferred + field is accessed. It may or may not trigger a db query; + if it does, it's going to be a refresh_from_db() call + So we'll emit a `touch` from there + """ + if instance is None: + return self + # Refresh-from-db, intenally, calls QuerySet.get() on our + # instance. Normally, this would make our instance immune + # to further notifications. We don't want that to happen, + # so we disable the ignore_load signal within refresh_from_db + orig_refresh_from_db = instance.refresh_from_db + def refresh_from_db(*args, **kwargs): + with signals.ignore(signals.ignore_load): + return orig_refresh_from_db(*args, **kwargs) + instance.refresh_from_db = signals.signalify( + signals.lazy_load, + refresh_from_db, + parser=functools.partial(parse_refresh_from_db, instance), + ) + return original_deferred_attribute_get(self, instance, cls) +DeferredAttribute.__get__ = deferred_attribute_get diff --git a/tests/testapp/testapp/models.py b/tests/testapp/testapp/models.py index 6ffc15a..5daf07d 100644 --- a/tests/testapp/testapp/models.py +++ b/tests/testapp/testapp/models.py @@ -26,3 +26,7 @@ class Address(models.Model): class Hobby(models.Model): pass + + +class Medicine(models.Model): + name = models.CharField(max_length=20) diff --git a/tests/testapp/testapp/tests.py b/tests/testapp/testapp/tests.py index 3db4d9a..0c7d03e 100644 --- a/tests/testapp/testapp/tests.py +++ b/tests/testapp/testapp/tests.py @@ -8,6 +8,7 @@ from django.conf import settings from django.http.request import HttpRequest from django.http.response import HttpResponse +from django.test import override_settings from nplusone.ext.django.patch import setup_state from nplusone.ext.django.middleware import NPlusOneMiddleware @@ -32,6 +33,7 @@ def objects(): address = models.Address.objects.create(user=user) hobby = models.Hobby.objects.create() user.hobbies.add(hobby) + medicine = models.Medicine.objects.create(name="Allergix") return locals() @@ -133,6 +135,23 @@ def test_many_to_many_reverse_prefetch(self, objects, calls): assert len(calls) == 0 +@pytest.mark.django_db +class TestDeferred: + + def test_deferred(self, objects, calls): + medicine = list(models.Medicine.objects.defer('name'))[0] + medicine.name + assert len(calls) == 1 + call = calls[0] + assert call.objects == (models.Medicine, 'Medicine:1', 'name') + assert 'medicine.name' in ''.join(call.frame[4]) + + def test_non_deferred(self, objects, calls): + medicine = list(models.Medicine.objects.all())[0] + medicine.name + assert len(calls) == 0 + + @pytest.fixture def logger(monkeypatch): mock_logger = mock.Mock() @@ -272,16 +291,22 @@ def test_select_nested_unused(self, objects, client, logger): assert any('Pet.user' in call[1] for call in calls) assert any('User.occupation' in call[1] for call in calls) + @override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.User'}]) def test_many_to_many_whitelist(self, objects, client, logger): - settings.NPLUSONE_WHITELIST = [{'model': 'testapp.User'}] client.get('/many_to_many/') assert not logger.log.called + @override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.*'}]) def test_many_to_many_whitelist_wildcard(self, objects, client, logger): - settings.NPLUSONE_WHITELIST = [{'model': 'testapp.*'}] client.get('/many_to_many/') assert not logger.log.called + def test_deferred(self, objects, client, logger): + client.get('/deferred/') + assert len(logger.log.call_args_list) == 1 + args = logger.log.call_args[0] + assert 'Medicine.name' in args[1] + @pytest.mark.django_db def test_values(objects, lazy_listener): diff --git a/tests/testapp/testapp/urls.py b/tests/testapp/testapp/urls.py index 510183b..be50425 100644 --- a/tests/testapp/testapp/urls.py +++ b/tests/testapp/testapp/urls.py @@ -25,4 +25,5 @@ url(r'^prefetch_nested_unused/$', views.prefetch_nested_unused), url(r'^select_nested/$', views.select_nested), url(r'^select_nested_unused/$', views.select_nested_unused), + url(r'^deferred/$', views.deferred), ] diff --git a/tests/testapp/testapp/views.py b/tests/testapp/testapp/views.py index 048b6fd..fe262f6 100644 --- a/tests/testapp/testapp/views.py +++ b/tests/testapp/testapp/views.py @@ -127,3 +127,8 @@ def select_nested(request): def select_nested_unused(request): pets = list(models.Pet.objects.all().select_related('user__occupation')) return HttpResponse(pets[0]) + + +def deferred(request): + meds = list(models.Medicine.objects.defer('name')) + return HttpResponse("; ".join(med.name for med in meds))