Skip to content

Commit

Permalink
Merge remote-tracking branch 'otto001/related_accessor' into select_r…
Browse files Browse the repository at this point in the history
…elated
  • Loading branch information
pgammans committed Jun 27, 2023
2 parents 7ab11a3 + 705acf4 commit 5eb5e3e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
21 changes: 10 additions & 11 deletions polymorphic/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,15 @@ def __init__(self, *args, **kwargs):
return
self.__class__.polymorphic_super_sub_accessors_replaced = True

def create_accessor_function_for_model(model, accessor_name):
NOT_PROVIDED = object()

def create_accessor_function_for_model(model, field):
def accessor_function(self):
attr = NOT_PROVIDED
try:
attr = self._state.fields_cache[accessor_name]
pass
rel_obj = field.get_cached_value(self)
except KeyError:
pass
if attr is NOT_PROVIDED:
objects = getattr(model, "_base_objects", model.objects)
attr = objects.get(pk=self.pk)
return attr
rel_obj = objects.get(pk=self.pk)
field.set_cached_value(self, rel_obj)
return rel_obj

return accessor_function

Expand All @@ -223,10 +218,14 @@ def accessor_function(self):
type(orig_accessor),
(ReverseOneToOneDescriptor, ForwardManyToOneDescriptor),
):

field = orig_accessor.related \
if isinstance(orig_accessor, ReverseOneToOneDescriptor) else orig_accessor.field

setattr(
self.__class__,
name,
property(create_accessor_function_for_model(model, name)),
property(create_accessor_function_for_model(model, field)),
)

def _get_inheritance_relation_fields_and_models(self):
Expand Down
23 changes: 23 additions & 0 deletions polymorphic/tests/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,29 @@ def test_parent_link_and_related_name(self):
# test that we can delete the object
t.delete()

def test_polymorphic__accessor_caching(self):
blog_a = BlogA.objects.create(name="blog")

blog_base = BlogBase.objects.non_polymorphic().get(id=blog_a.id)
blog_a = BlogA.objects.get(id=blog_a.id)

# test reverse accessor & check that we get back cached object on repeated access
self.assertEqual(blog_base.bloga, blog_a)
self.assertIs(blog_base.bloga, blog_base.bloga)
cached_blog_a = blog_base.bloga

# test forward accessor & check that we get back cached object on repeated access
self.assertEqual(blog_a.blogbase_ptr, blog_base)
self.assertIs(blog_a.blogbase_ptr, blog_a.blogbase_ptr)
cached_blog_base = blog_a.blogbase_ptr

# check that refresh_from_db correctly clears cached related objects
blog_base.refresh_from_db()
blog_a.refresh_from_db()

self.assertIsNot(cached_blog_a, blog_base.bloga)
self.assertIsNot(cached_blog_base, blog_a.blogbase_ptr)

def test_polymorphic__aggregate(self):
"""test ModelX___field syntax on aggregate (should work for annotate either)"""

Expand Down

0 comments on commit 5eb5e3e

Please sign in to comment.