From 3a14cc6fd2cf73c2ef6a549ebc9ca822cf1ecd2c Mon Sep 17 00:00:00 2001 From: Shai Berger Date: Wed, 16 Jun 2021 17:45:57 +0300 Subject: [PATCH] Detect cases of n+1 caused by Django deferred fields --- nplusone/__init__.py | 2 +- nplusone/ext/django/patch.py | 53 +++++++++++++++++++++++++++++++++ tests/testapp/testapp/models.py | 5 ++++ tests/testapp/testapp/tests.py | 36 ++++++++++++++++++++-- tests/testapp/testapp/urls.py | 2 ++ tests/testapp/testapp/views.py | 11 +++++++ 6 files changed, 106 insertions(+), 3 deletions(-) diff --git a/nplusone/__init__.py b/nplusone/__init__.py index 1f356cc..46a2620 100644 --- a/nplusone/__init__.py +++ b/nplusone/__init__.py @@ -1 +1 @@ -__version__ = '1.0.0' +__version__ = '1.1.0a1' diff --git a/nplusone/ext/django/patch.py b/nplusone/ext/django/patch.py index cd7db13..3903680 100644 --- a/nplusone/ext/django/patch.py +++ b/nplusone/ext/django/patch.py @@ -17,6 +17,9 @@ create_reverse_many_to_one_manager, create_forward_many_to_many_manager, ) +from django.db.models.query_utils import DeferredAttribute + +NPLUSONE_WRAPPED = 'nplusone_wrapped' def get_worker(): @@ -345,3 +348,53 @@ def getitem_queryset(self, index): ) return original_getitem_queryset(self, index) query.QuerySet.__getitem__ = getitem_queryset + + +def parse_refresh_from_db(instance, fields, args, kwargs, context): + # Instance & fields passed via partial + model = type(instance) + return model, to_key(instance), fields[0] + + +original_deferred_attribute_get = DeferredAttribute.__get__ +def deferred_attribute_get(self, instance, cls=None): + """ + DeferredAttribute.__get__() is called when a deferred + field is accessed. It may or may not trigger a db query; + if it does, it's going to be a refresh_from_db() call + So we'll emit a `touch` from there + """ + if instance is None: + return self + # Refresh-from-db, intenally, calls QuerySet.get() on our + # instance. Normally, this would make our instance immune + # to further notifications. We don't want that to happen, + # so we disable the ignore_load signal within refresh_from_db + ensure_wrapped_refresh_from_db(instance) + return original_deferred_attribute_get(self, instance, cls) +DeferredAttribute.__get__ = deferred_attribute_get + + +def ensure_wrapped_refresh_from_db(instance): + orig_refresh_from_db = instance.refresh_from_db + if getattr(orig_refresh_from_db, NPLUSONE_WRAPPED, False): + return + @functools.wraps(orig_refresh_from_db) + def refresh_from_db(fields=None, *args, **kwargs): + with signals.ignore(signals.ignore_load): + ret = orig_refresh_from_db(fields=fields, **kwargs) + # and now, if the refresh_from_db was called for specific fields, + # then it's a lazy load + if fields: + parser = functools.partial(parse_refresh_from_db, instance, fields) + signals.lazy_load.send( + get_worker(), + args=args, + kwargs=kwargs, + ret=ret, + context={}, + parser=parser, + ) + return ret + setattr(refresh_from_db, NPLUSONE_WRAPPED, True) + instance.refresh_from_db = refresh_from_db diff --git a/tests/testapp/testapp/models.py b/tests/testapp/testapp/models.py index 6ffc15a..088330b 100644 --- a/tests/testapp/testapp/models.py +++ b/tests/testapp/testapp/models.py @@ -26,3 +26,8 @@ class Address(models.Model): class Hobby(models.Model): pass + + +class Medicine(models.Model): + name = models.CharField(max_length=20) + prescription = models.BooleanField(default=False) diff --git a/tests/testapp/testapp/tests.py b/tests/testapp/testapp/tests.py index 3db4d9a..a799c8d 100644 --- a/tests/testapp/testapp/tests.py +++ b/tests/testapp/testapp/tests.py @@ -8,6 +8,7 @@ from django.conf import settings from django.http.request import HttpRequest from django.http.response import HttpResponse +from django.test import override_settings from nplusone.ext.django.patch import setup_state from nplusone.ext.django.middleware import NPlusOneMiddleware @@ -32,6 +33,7 @@ def objects(): address = models.Address.objects.create(user=user) hobby = models.Hobby.objects.create() user.hobbies.add(hobby) + medicine = models.Medicine.objects.create(name="Allergix") return locals() @@ -133,6 +135,23 @@ def test_many_to_many_reverse_prefetch(self, objects, calls): assert len(calls) == 0 +@pytest.mark.django_db +class TestDeferred: + + def test_deferred(self, objects, calls): + medicine = list(models.Medicine.objects.defer('name'))[0] + medicine.name + assert len(calls) == 1 + call = calls[0] + assert call.objects == (models.Medicine, 'Medicine:1', 'name') + assert 'medicine.name' in ''.join(call.frame[4]) + + def test_non_deferred(self, objects, calls): + medicine = list(models.Medicine.objects.all())[0] + medicine.name + assert len(calls) == 0 + + @pytest.fixture def logger(monkeypatch): mock_logger = mock.Mock() @@ -272,16 +291,29 @@ def test_select_nested_unused(self, objects, client, logger): assert any('Pet.user' in call[1] for call in calls) assert any('User.occupation' in call[1] for call in calls) + @override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.User'}]) def test_many_to_many_whitelist(self, objects, client, logger): - settings.NPLUSONE_WHITELIST = [{'model': 'testapp.User'}] client.get('/many_to_many/') assert not logger.log.called + @override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.*'}]) def test_many_to_many_whitelist_wildcard(self, objects, client, logger): - settings.NPLUSONE_WHITELIST = [{'model': 'testapp.*'}] client.get('/many_to_many/') assert not logger.log.called + def test_deferred(self, objects, client, logger): + client.get('/deferred/') + assert len(logger.log.call_args_list) == 1 + args = logger.log.call_args[0] + assert 'Medicine.name' in args[1] + + def test_double_deferred(self, objects, client, logger): + client.get('/double_deferred/') + assert len(logger.log.call_args_list) == 2 + messages = sorted({args[0][1] for args in logger.log.call_args_list}) + assert 'Medicine.name' in messages[0] + assert 'Medicine.prescription' in messages[1] + @pytest.mark.django_db def test_values(objects, lazy_listener): diff --git a/tests/testapp/testapp/urls.py b/tests/testapp/testapp/urls.py index 510183b..223186d 100644 --- a/tests/testapp/testapp/urls.py +++ b/tests/testapp/testapp/urls.py @@ -25,4 +25,6 @@ url(r'^prefetch_nested_unused/$', views.prefetch_nested_unused), url(r'^select_nested/$', views.select_nested), url(r'^select_nested_unused/$', views.select_nested_unused), + url(r'^deferred/$', views.deferred), + url(r'^double_deferred/$', views.double_deferred), ] diff --git a/tests/testapp/testapp/views.py b/tests/testapp/testapp/views.py index 048b6fd..4945227 100644 --- a/tests/testapp/testapp/views.py +++ b/tests/testapp/testapp/views.py @@ -127,3 +127,14 @@ def select_nested(request): def select_nested_unused(request): pets = list(models.Pet.objects.all().select_related('user__occupation')) return HttpResponse(pets[0]) + + +def deferred(request): + meds = list(models.Medicine.objects.defer('name')) + return HttpResponse("; ".join(med.name for med in meds)) + +def double_deferred(request): + meds = list(models.Medicine.objects.only('id')) + return HttpResponse("; ".join( + med.name + (' *' if med.prescription else '') for med in meds + ))