Skip to content

Commit

Permalink
perf(Django): Use select_related to avoid N+1 queries (#403)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn authored Aug 28, 2024
1 parent 0628f62 commit 7302a0e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 16 deletions.
11 changes: 6 additions & 5 deletions open_prices/api/prices/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ def setUpTestData(cls):

def test_price_list(self):
# anonymous
response = self.client.get(self.url)
self.assertEqual(response.data["total"], 3)
self.assertEqual(len(response.data["items"]), 3)
self.assertTrue("id" in response.data["items"][0])
self.assertEqual(response.data["items"][0]["price"], 15.00) # default order
with self.assertNumQueries(1 + 1): # thanks to select_related
response = self.client.get(self.url)
self.assertEqual(response.data["total"], 3)
self.assertEqual(len(response.data["items"]), 3)
self.assertTrue("id" in response.data["items"][0])
self.assertEqual(response.data["items"][0]["price"], 15.00) # default order


class PriceListPaginationApiTest(TestCase):
Expand Down
4 changes: 3 additions & 1 deletion open_prices/api/prices/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class PriceViewSet(
ordering = ["created"]

def get_queryset(self):
if self.request.method in ["PATCH", "DELETE"]:
if self.request.method in ["GET"]:
return self.queryset.select_related("product", "location", "proof")
elif self.request.method in ["PATCH", "DELETE"]:
# only return prices owned by the current user
if self.request.user.is_authenticated:
return Price.objects.filter(owner=self.request.user.user_id)
Expand Down
30 changes: 21 additions & 9 deletions open_prices/api/proofs/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from django.db.models import signals
from django.test import TestCase
from django.urls import reverse

from open_prices.locations import constants as location_constants
from open_prices.locations.models import (
Location,
location_post_create_fetch_data_from_openstreetmap,
)
from open_prices.proofs import constants as proof_constants
from open_prices.proofs.factories import ProofFactory
from open_prices.proofs.models import Proof
Expand All @@ -20,6 +26,8 @@ def setUpTestData(cls):
ProofFactory(price_count=0)
ProofFactory(
type=proof_constants.TYPE_PRICE_TAG,
location_osm_id=652825274,
location_osm_type=location_constants.OSM_TYPE_NODE,
price_count=50,
owner=cls.user_session.user.user_id,
)
Expand All @@ -34,15 +42,16 @@ def test_proof_list(self):
)
self.assertEqual(response.status_code, 403)
# authenticated
response = self.client.get(
self.url, headers={"Authorization": f"Bearer {self.user_session.token}"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["total"], 2) # only user's proofs
self.assertEqual(len(response.data["items"]), 2)
self.assertEqual(
response.data["items"][0]["id"], self.proof.id
) # default order
with self.assertNumQueries(4 + 1): # thanks to select_related
response = self.client.get(
self.url, headers={"Authorization": f"Bearer {self.user_session.token}"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["total"], 2) # only user's proofs
self.assertEqual(len(response.data["items"]), 2)
self.assertEqual(
response.data["items"][0]["id"], self.proof.id
) # default order


class ProofListOrderApiTest(TestCase):
Expand Down Expand Up @@ -129,6 +138,9 @@ def test_proof_detail(self):
class ProofCreateApiTest(TestCase):
@classmethod
def setUpTestData(cls):
signals.post_save.disconnect(
location_post_create_fetch_data_from_openstreetmap, sender=Location
)
cls.url = reverse("api:proofs-upload")
cls.user_session = SessionFactory()
cls.data = {
Expand Down
5 changes: 4 additions & 1 deletion open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ class ProofViewSet(
def get_queryset(self):
# only return proofs owned by the current user
if self.request.user.is_authenticated:
return Proof.objects.filter(owner=self.request.user.user_id)
queryset = Proof.objects.filter(owner=self.request.user.user_id)
if self.request.method in ["GET"]:
return queryset.select_related("location")
return queryset
return self.queryset

def get_serializer_class(self):
Expand Down

0 comments on commit 7302a0e

Please sign in to comment.