Skip to content

Commit

Permalink
refactor(proofs): simple update_location() method (#437)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn authored Sep 11, 2024
1 parent c2bfb1d commit d54ded6
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 25 deletions.
40 changes: 20 additions & 20 deletions open_prices/prices/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,31 +351,31 @@ def get_or_create_location(self):

def save(self, *args, **kwargs):
self.full_clean()
if not self.id: # new price
# self.set_proof() # should already exist
self.get_or_create_product()
self.get_or_create_location()
# self.set_proof() # should already exist
self.get_or_create_product()
self.get_or_create_location()
super().save(*args, **kwargs)


@receiver(signals.post_save, sender=Price)
def price_post_create_increment_counts(sender, instance, created, **kwargs):
if instance.owner:
User.objects.filter(user_id=instance.owner).update(
price_count=F("price_count") + 1
)
if instance.proof_id:
Proof.objects.filter(id=instance.proof_id).update(
price_count=F("price_count") + 1
)
if instance.product_id:
Product.objects.filter(id=instance.product_id).update(
price_count=F("price_count") + 1
)
if instance.location_id:
Location.objects.filter(id=instance.location_id).update(
price_count=F("price_count") + 1
)
if created:
if instance.owner:
User.objects.filter(user_id=instance.owner).update(
price_count=F("price_count") + 1
)
if instance.proof_id:
Proof.objects.filter(id=instance.proof_id).update(
price_count=F("price_count") + 1
)
if instance.product_id:
Product.objects.filter(id=instance.product_id).update(
price_count=F("price_count") + 1
)
if instance.location_id:
Location.objects.filter(id=instance.location_id).update(
price_count=F("price_count") + 1
)


@receiver(signals.post_delete, sender=Price)
Expand Down
22 changes: 20 additions & 2 deletions open_prices/proofs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,28 @@ def set_location(self):

def save(self, *args, **kwargs):
self.full_clean()
if not self.id:
self.set_location()
self.set_location()
super().save(*args, **kwargs)

def update_price_count(self):
self.price_count = self.prices.count()
self.save(update_fields=["price_count"])

def update_location(self, location_osm_id, location_osm_type):
old_location = self.location
# update proof location
self.location_osm_id = location_osm_id
self.location_osm_type = location_osm_type
self.save()
# update proof's prices location
for price in self.prices.all():
price.location_osm_id = location_osm_id
price.location_osm_type = location_osm_type
price.save()
# update old & new location price counts
self.refresh_from_db()
new_location = self.location
if old_location:
old_location.update_price_count()
if new_location:
new_location.update_price_count()
61 changes: 58 additions & 3 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,22 @@
from django.test import TestCase

from open_prices.locations import constants as location_constants
from open_prices.locations.factories import LocationFactory
from open_prices.prices.factories import PriceFactory
from open_prices.proofs.factories import ProofFactory
from open_prices.proofs.models import Proof

LOCATION_NODE_652825274 = {
"osm_id": 652825274,
"osm_type": "NODE",
"osm_name": "Monoprix",
}
# LOCATION_NODE_6509705997 = {
# "osm_id": 6509705997,
# "osm_type": "NODE",
# "osm_name": "Carrefour",
# }


class ProofModelSaveTest(TestCase):
@classmethod
Expand Down Expand Up @@ -70,9 +82,22 @@ def test_with_stats(self):
class ProofPropertyTest(TestCase):
@classmethod
def setUpTestData(cls):
cls.proof = ProofFactory()
PriceFactory(proof_id=cls.proof.id, price=1.0)
PriceFactory(proof_id=cls.proof.id, price=2.0)
cls.location = LocationFactory(**LOCATION_NODE_652825274)
cls.proof = ProofFactory(
location_osm_id=cls.location.osm_id, location_osm_type=cls.location.osm_type
)
PriceFactory(
proof_id=cls.proof.id,
location_osm_id=cls.location.osm_id,
location_osm_type=cls.location.osm_type,
price=1.0,
)
PriceFactory(
proof_id=cls.proof.id,
location_osm_id=cls.location.osm_id,
location_osm_type=cls.location.osm_type,
price=2.0,
)

def test_update_price_count(self):
self.proof.refresh_from_db()
Expand All @@ -83,3 +108,33 @@ def test_update_price_count(self):
# update_price_count() should fix price_count
self.proof.update_price_count()
self.assertEqual(self.proof.price_count, 0)

def test_update_location(self):
# existing
self.proof.refresh_from_db()
self.location.refresh_from_db()
self.assertEqual(self.proof.price_count, 2)
self.assertEqual(self.proof.location.id, self.location.id)
self.assertEqual(self.location.price_count, 2)
# update location
self.proof.update_location(
location_osm_id=6509705997,
location_osm_type=location_constants.OSM_TYPE_NODE,
)
# check changes
self.proof.refresh_from_db()
self.location.refresh_from_db()
new_location = self.proof.location
self.assertNotEqual(self.location, new_location)
self.assertEqual(self.proof.price_count, 2)
self.assertEqual(new_location.price_count, 2)
self.assertEqual(self.location.price_count, 0)
# update again, same location
self.proof.update_location(
location_osm_id=6509705997,
location_osm_type=location_constants.OSM_TYPE_NODE,
)
self.proof.refresh_from_db()
self.location.refresh_from_db()
self.assertEqual(self.proof.price_count, 2)
self.assertEqual(self.proof.location.price_count, 2)

0 comments on commit d54ded6

Please sign in to comment.