From f06de2f7d856e4919417ad46c59d25f437608231 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Mon, 25 Mar 2024 13:11:40 +0000 Subject: [PATCH] add select_related and partial prefetch_related support implement support for a single query for select related base fetches across polymorphic models. adds a polymorphic QuerySet Mixin to enable non polymorphic models to fetch related models. fixes: #198 #436 #359 #244 possible fixes: #498: support for prefetch_related cannot fetch attributes not on all child models or via class names related: #531 --- polymorphic/query.py | 531 ++++++++++++++++- polymorphic/tests/migrations/0001_initial.py | 265 ++++++++- polymorphic/tests/models.py | 69 ++- polymorphic/tests/test_orm.py | 570 ++++++++++++++++++- 4 files changed, 1387 insertions(+), 48 deletions(-) diff --git a/polymorphic/query.py b/polymorphic/query.py index 8e93281a..475e6c20 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -3,18 +3,24 @@ """ import copy +import functools +import operator from collections import defaultdict from django.contrib.contenttypes.models import ContentType from django.core.exceptions import FieldDoesNotExist from django.db.models import FilteredRelation +from django.db.models.constants import LOOKUP_SEP from django.db.models.query import ModelIterable, Q, QuerySet - +from django.db.models.query import BaseIterable, RelatedPopulator from .query_translate import ( translate_polymorphic_field_path, translate_polymorphic_filter_definitions_in_args, translate_polymorphic_filter_definitions_in_kwargs, translate_polymorphic_Q_object, + _get_query_related_name, + _get_all_sub_models, + _create_base_path, ) # chunk-size: maximum number of objects requested per db-request @@ -22,52 +28,398 @@ Polymorphic_QuerySet_objects_per_request = 100 +def merge_dicts(primary, secondary): + """Deep merge two dicts + + Items from the primary dict are preserved in preference to those on the + secondary dict""" + + for k, v in secondary.items(): + if k in primary: + primary[k] = merge_dicts(primary[k], v) + else: + primary[k] = copy.deepcopy(v) + return primary + + +def search_object_cache(obj, source_model, target_model): + for search_part in _create_base_path(source_model, target_model).split("__"): + try: + obj = obj._state.fields_cache[search_part] + except KeyError: + return + return obj + + +class VanillaRelatedPopulator(RelatedPopulator): + def __init__(self, klass_info, select, db): + super().__init__(klass_info, select, db) + self.field = klass_info["field"] + self.reverse = klass_info["reverse"] + + def build_related(self, row, from_obj, *_): + self.populate(row, from_obj) + + +class RelatedPolymorphicPopulator: + """ + RelatedPopulator is used for select_related() object instantiation. + The idea is that each select_related() model will be populated by a + different RelatedPopulator instance. The RelatedPopulator instances get + klass_info and select (computed in SQLCompiler) plus the used db as + input for initialization. That data is used to compute which columns + to use, how to instantiate the model, and how to populate the links + between the objects. + The actual creation of the objects is done in populate() method. This + method gets row and from_obj as input and populates the select_related() + model instance. + """ + + def __init__(self, klass_info, select, db): + self.db = db + # Pre-compute needed attributes. The attributes are: + # - model_cls: the possibly deferred model class to instantiate + # - either: + # - cols_start, cols_end: usually the columns in the row are + # in the same order model_cls.__init__ expects them, so we + # can instantiate by model_cls(*row[cols_start:cols_end]) + # - reorder_for_init: When select_related descends to a child + # class, then we want to reuse the already selected parent + # data. However, in this case the parent data isn't necessarily + # in the same order that Model.__init__ expects it to be, so + # we have to reorder the parent data. The reorder_for_init + # attribute contains a function used to reorder the field data + # in the order __init__ expects it. + # - pk_idx: the index of the primary key field in the reordered + # model data. Used to check if a related object exists at all. + # - init_list: the field attnames fetched from the database. For + # deferred models this isn't the same as all attnames of the + # model's fields. + # - related_populators: a list of RelatedPopulator instances if + # select_related() descends to related models from this model. + # - local_setter, remote_setter: Methods to set cached values on + # the object being populated and on the remote object. Usually + # these are Field.set_cached_value() methods. + select_fields = klass_info["select_fields"] + from_parent = klass_info["from_parent"] + if not from_parent: + self.cols_start = select_fields[0] + self.cols_end = select_fields[-1] + 1 + self.init_list = [f[0].target.attname for f in select[self.cols_start : self.cols_end]] + self.reorder_for_init = None + else: + attname_indexes = {select[idx][0].target.attname: idx for idx in select_fields} + model_init_attnames = (f.attname for f in klass_info["model"]._meta.concrete_fields) + self.init_list = [ + attname for attname in model_init_attnames if attname in attname_indexes + ] + self.reorder_for_init = operator.itemgetter( + *[attname_indexes[attname] for attname in self.init_list] + ) + + self.model_cls = klass_info["model"] + self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) + self.related_populators = get_related_populators(klass_info, select, self.db) + self.local_setter = klass_info["local_setter"] + self.remote_setter = klass_info["remote_setter"] + self.field = klass_info["field"] + self.reverse = klass_info["reverse"] + self.content_type_manager = ContentType.objects.db_manager(self.db) + self.model_class_id = self.content_type_manager.get_for_model( + self.model_cls, for_concrete_model=False + ).pk + self.concrete_model_class_id = self.content_type_manager.get_for_model( + self.model_cls, for_concrete_model=True + ).pk + + def build_related(self, row, from_obj, post_actions): + if self.reorder_for_init: + obj_data = self.reorder_for_init(row) + else: + obj_data = row[self.cols_start : self.cols_end] + + if obj_data[self.pk_idx] is None: + obj = None + else: + obj = self.model_cls.from_db(self.db, self.init_list, obj_data) + self.post_build_modify( + obj, + from_obj, + post_actions, + functools.partial(self._populate, row, from_obj, post_actions), + ) + + def _populate(self, row, from_obj, post_actions, obj): + for rel_iter in self.related_populators: + rel_iter.build_related(row, obj, post_actions) + + self.local_setter(from_obj, obj) + if obj is not None: + self.remote_setter(obj, from_obj) + + def post_build_modify(self, base_object, from_obj, post_actions, populate_fn): + if base_object.polymorphic_ctype_id == self.model_class_id: + # Real class is exactly the same as base class, go straight to results + populate_fn(base_object) + else: + real_concrete_class = base_object.get_real_instance_class() + real_concrete_class_id = base_object.get_real_concrete_instance_class_id() + + if real_concrete_class_id is None: + # Dealing with a stale content type + populate_fn(None) + return False + elif real_concrete_class_id == self.concrete_model_class_id: + # Real and base classes share the same concrete ancestor, + # upcast it and put it in the results + populate_fn(transmogrify(real_concrete_class, base_object)) + return False + else: + # This model has a concrete derived class: either track it for bulk + # retrieval or if it is already fetched as part of a select_related + # enable pivoting to that object + real_concrete_class = self.content_type_manager.get_for_id( + real_concrete_class_id + ).model_class() + populate_fn(base_object) + post_actions.append( + ( + functools.partial( + self.pivot_onto_cached_subclass, + from_obj, + base_object, + real_concrete_class, + ), + populate_fn, + ) + ) + + def pivot_onto_cached_subclass(self, from_obj, obj, model_target_cls): + """Pivot to final polymorphic class. + + Pivot the object created from the base query onto the true polymorphic + instance, we need to ensure that this is only done on objects that are + from non parent-child type relationships. + + If we cannot pivot we return info to be used in the PolymorphicModelIterable + to ensure the correct model loaded from the additional bulk queries + """ + + original = obj + parents = model_target_cls()._get_inheritance_relation_fields_and_models() + for cls in reversed(model_target_cls.mro()[: -len(self.model_cls.mro())]): + for rel_iter in self.related_populators: + if not isinstance( + rel_iter, (VanillaRelatedPopulator, RelatedPolymorphicPopulator) + ): + continue + if rel_iter.reverse and rel_iter.model_cls is cls: + if rel_iter.field.name in parents.keys(): + obj = getattr(obj, rel_iter.field.remote_field.name) + + if not isinstance(obj, model_target_cls): + # This allow pivoting of object that are descendants of the original field + if not original._meta.get_path_to_parent(from_obj._meta.model): + obj = search_object_cache(original, original._meta.model, model_target_cls) + + if isinstance(obj, model_target_cls): + # We only want to pivot onto a field from a different object, ie not a parent/child + # relationship as this will break the cache and other object relationships + if not original._meta.get_path_to_parent(from_obj._meta.model): + self.local_setter(from_obj, obj) + if obj is not None: + self.remote_setter(obj, from_obj) + return None, None + + pk_name = self.model_cls.polymorphic_primary_key_name + return model_target_cls, (getattr(original, pk_name), self.field.name) + + +def get_related_populators(klass_info, select, db): + from .models import PolymorphicModel + + iterators = [] + related_klass_infos = klass_info.get("related_klass_infos", []) + for rel_klass_info in related_klass_infos: + model = rel_klass_info["model"] + if issubclass(model, PolymorphicModel): + rel_cls = RelatedPolymorphicPopulator(rel_klass_info, select, db) + else: + rel_cls = VanillaRelatedPopulator(rel_klass_info, select, db) + iterators.append(rel_cls) + return iterators + + class PolymorphicModelIterable(ModelIterable): """ ModelIterable for PolymorphicModel Yields real instances if qs.polymorphic_disabled is False, - otherwise acts like a regular ModelIterable. + otherwise acts like a regular ModelIterable. We inherit from + ModelIterable non base BaseIterable even though we completely + replace it, but this allows Django test in Prefetch to work """ def __iter__(self): - base_iter = super().__iter__() - if self.queryset.polymorphic_disabled: - return base_iter - return self._polymorphic_iterator(base_iter) - - def _polymorphic_iterator(self, base_iter): - """ - Here we do the same as:: - - real_results = queryset._get_real_instances(list(base_iter)) - for o in real_results: yield o - - but it requests the objects in chunks from the database, - with Polymorphic_QuerySet_objects_per_request per chunk - """ + queryset = self.queryset + db = queryset.db + compiler = queryset.query.get_compiler(using=db) + # Execute the query. This will also fill compiler.select, klass_info, + # and annotations. + results = compiler.execute_sql( + chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size + ) + select, klass_info, annotation_col_map = ( + compiler.select, + compiler.klass_info, + compiler.annotation_col_map, + ) + model_cls = klass_info["model"] + select_fields = klass_info["select_fields"] + model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1 + init_list = [f[0].target.attname for f in select[model_fields_start:model_fields_end]] + related_populators = get_related_populators(klass_info, select, db) + known_related_objects = [ + ( + field, + related_objs, + operator.attrgetter( + *[ + field.attname + if from_field == "self" + else queryset.model._meta.get_field(from_field).attname + for from_field in field.from_fields + ] + ), + ) + for field, related_objs in queryset._known_related_objects.items() + ] + base_iter = compiler.results_iter(results) while True: + result_objects = [] base_result_objects = [] reached_end = False # Make sure the base iterator is read in chunks instead of # reading it completely, in case our caller read only a few objects. + post_actions = list() for i in range(Polymorphic_QuerySet_objects_per_request): + # dict contains one entry per unique model type occurring in result, + # in the format idlist_per_model[modelclass]=[list-of-object-ids] try: - o = next(base_iter) - base_result_objects.append(o) + row = next(base_iter) + obj = model_cls.from_db( + db, init_list, row[model_fields_start:model_fields_end] + ) + for rel_populator in related_populators: + rel_populator.build_related(row, obj, post_actions) + base_result_objects.append([row, obj]) except StopIteration: reached_end = True break - real_results = self.queryset._get_real_instances(base_result_objects) + if not self.queryset.polymorphic_disabled: + self.fetch_polymorphic(post_actions, base_result_objects) + + for row, obj in base_result_objects: + if annotation_col_map: + for attr_name, col_pos in annotation_col_map.items(): + setattr(obj, attr_name, row[col_pos]) + + # Add the known related objects to the model. + for field, rel_objs, rel_getter in known_related_objects: + # Avoid overwriting objects loaded by, e.g., select_related(). + if field.is_cached(obj): + continue + rel_obj_id = rel_getter(obj) + try: + rel_obj = rel_objs[rel_obj_id] + except KeyError: + pass # May happen in qs1 | qs2 scenarios. + else: + setattr(obj, field.name, rel_obj) + result_objects.append(obj) + + if not self.queryset.polymorphic_disabled: + result_objects = self.queryset._get_real_instances(result_objects) - for o in real_results: + for o in result_objects: yield o if reached_end: return + def apply_select_related(self, qs, relations): + if self.queryset.query.select_related is True: + return qs.select_related() + + model_name = qs.model.__name__.lower() + if isinstance(self.queryset.query.select_related, dict): + select_related = {} + if isinstance(qs.query.select_related, dict): + select_related = qs.query.select_related + for k, v in self.queryset.query.select_related.items(): + if k in relations: + if not isinstance(select_related, dict): + select_related = {} + if isinstance(v, dict): + if model_name in v: + select_related = merge_dicts( + select_related, v[model_name]) + else: + for field in qs.model._meta.fields: + if field.name in v: + select_related = merge_dicts( + select_related, v[field.name]) + else: + select_related = merge_dicts(select_related, v) + qs.query.select_related = select_related + return qs + + def fetch_polymorphic(self, post_actions, base_result_objects): + update_fn_per_model = defaultdict(list) + idlist_per_model = defaultdict(list) + + for action, populate_fn in post_actions: + target_class, pk_info = action() + if target_class: + pk, name = pk_info + idlist_per_model[target_class].append((pk, name)) + update_fn_per_model[target_class].append((populate_fn, pk)) + + # For each model in "idlist_per_model" request its objects (the real model) + # from the db and store them in results[]. + # Then we copy the annotate fields from the base objects to the real objects. + # Then we copy the extra() select fields from the base objects to the real objects. + # TODO: defer(), only(): support for these would be around here + for real_concrete_class, data in idlist_per_model.items(): + idlist, names = zip(*data) + updates = update_fn_per_model[real_concrete_class] + pk_name = real_concrete_class.polymorphic_primary_key_name + real_objects = real_concrete_class._base_objects.db_manager(self.queryset.db).filter( + **{("%s__in" % pk_name): idlist} + ) + + real_objects = self.apply_select_related(real_objects, set(names)) + real_objects_dict = { + getattr(real_object, pk_name): real_object for real_object in real_objects + } + + for populate_fn, o_pk in updates: + real_object = real_objects_dict.get(o_pk) + if real_object is None: + continue + + # need shallow copy to avoid duplication in caches (see PR #353) + real_object = copy.copy(real_object) + real_class = real_object.get_real_instance_class() + + # If the real class is a proxy, upcast it + if real_class != real_concrete_class: + real_object = transmogrify(real_class, real_object) + + populate_fn(real_object) + def transmogrify(cls, obj): """ @@ -89,7 +441,64 @@ def transmogrify(cls, obj): # PolymorphicQuerySet -class PolymorphicQuerySet(QuerySet): +class PolymorphicQuerySetMixin(QuerySet): + def select_related(self, *fields): + if fields == (None,) or not len(fields): + return super().select_related(*fields) + field_with_poly = list(self.convert_related_fieldnames(fields)) + return super().select_related(*field_with_poly) + + def _convert_field_name_part(self, field_parts, model): + """ + recursively convert a fieldname into (model, filedname) + """ + field = None + part = field_parts[0] + next_parts = field_parts[1:] + field_path = [] + rel_model = None + try: + field = model._meta.get_field(part) + field_path = [part] + yield field_path + + if field.is_relation: + rel_model = field.related_model + if next_parts: + self._convert_field_name_part(next_parts, rel_model) + else: + rel_model = model + + except FieldDoesNotExist: + submodels = _get_all_sub_models(model) + rel_model = submodels.get(part, None) + field_path = list(_create_base_path(model, rel_model).split("__")) + for field_part_idx in range(0, len(field_path)): + yield field_path[0 : 1 + field_part_idx] + + if next_parts: + child_selectors = self._convert_field_name_part(next_parts, rel_model) + for selector in child_selectors: + all_field_path = field_path + selector + for field_part_idx in range(0, len(all_field_path)): + yield all_field_path[0 : 1 + field_part_idx] + + def convert_related_fieldnames(self, fields, opts=None): + """ + convert the field name which may contain polymorphic models names into + raw filed names that can be used with django select_related and + prefetch_related. + """ + if not opts: + opts = self.model + for field_name in fields: + field_parts = field_name.split(LOOKUP_SEP) + selectors = self._convert_field_name_part(field_parts, opts) + for selector in selectors: + yield "__".join(selector) + + +class PolymorphicQuerySet(PolymorphicQuerySetMixin, QuerySet): """ QuerySet for PolymorphicModel @@ -387,22 +796,37 @@ class self.model, but as a class derived from self.model. We want to re-fetch real_concrete_class = content_type_manager.get_for_id( real_concrete_class_id ).model_class() - idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name)) - indexlist_per_model[real_concrete_class].append((i, len(resultlist))) - resultlist.append(None) + + cached_obj = search_object_cache(base_object, self.model, real_concrete_class) + if cached_obj: + resultlist.append(cached_obj) + else: + idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name)) + indexlist_per_model[real_concrete_class].append((i, len(resultlist))) + resultlist.append(None) # For each model in "idlist_per_model" request its objects (the real model) # from the db and store them in results[]. # Then we copy the annotate fields from the base objects to the real objects. # Then we copy the extra() select fields from the base objects to the real objects. # TODO: defer(), only(): support for these would be around here + # Also see PolymorphicModelIterable.fetch_polymorphic + + filter_relations = [ + _get_query_related_name(mdl_cls) + for mdl_cls in _get_all_sub_models(self.model).values() + ] + for real_concrete_class, idlist in idlist_per_model.items(): indices = indexlist_per_model[real_concrete_class] real_objects = real_concrete_class._base_objects.db_manager(self.db).filter( **{(f"{pk_name}__in"): idlist} ) # copy select related configuration to new qs - real_objects.query.select_related = self.query.select_related + current_relation = real_objects.model.__name__.lower() + real_objects = self.apply_select_related( + real_objects, current_relation, filter_relations + ) # Copy deferred fields configuration to the new queryset deferred_loading_fields = [] @@ -484,6 +908,37 @@ class self.model, but as a class derived from self.model. We want to re-fetch return resultlist + def apply_select_related(self, qs, relation, filtered): + if self.query.select_related is True: + return qs.select_related() + + model_name = qs.model.__name__.lower() + if isinstance(self.query.select_related, dict): + select_related = {} + if isinstance(qs.query.select_related, dict): + select_related = qs.query.select_related + for k, v in self.query.select_related.items(): + if k in filtered and k != relation: + continue + else: + if not isinstance(select_related, dict): + select_related = {} + if k == relation: + if isinstance(v, dict): + if model_name in v: + select_related = merge_dicts(select_related, v[model_name]) + else: + for field in qs.model._meta.fields: + if field.name in v: + select_related = merge_dicts(select_related, v[field.name]) + else: + select_related = merge_dicts(select_related, v) + else: + select_related[k] = v + + qs.query.select_related = select_related + return qs + def __repr__(self, *args, **kwargs): if self.model.polymorphic_query_multiline_output: result = ",\n ".join(repr(o) for o in self.all()) @@ -516,3 +971,27 @@ def get_real_instances(self, base_result_objects=None): return olist clist = PolymorphicQuerySet._p_list_class(olist) return clist + + +################################################################################### +# PolymorphicRelatedQuerySet + + +class PolymorphicRelatedQuerySetMixin(PolymorphicQuerySetMixin): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._iterable_class = PolymorphicModelIterable + self.polymorphic_disabled = False + + def _clone(self, *args, **kwargs): + # Django's _clone only copies its own variables, so we need to copy ours here + new = super()._clone(*args, **kwargs) + new.polymorphic_disabled = self.polymorphic_disabled + return new + + def _get_real_instances(self, base_result_objects): + return base_result_objects + + +class PolymorphicRelatedQuerySet(PolymorphicRelatedQuerySetMixin, QuerySet): + pass diff --git a/polymorphic/tests/migrations/0001_initial.py b/polymorphic/tests/migrations/0001_initial.py index 9e1dc4fb..bdaa55d0 100644 --- a/polymorphic/tests/migrations/0001_initial.py +++ b/polymorphic/tests/migrations/0001_initial.py @@ -8,7 +8,6 @@ class Migration(migrations.Migration): - initial = True dependencies = [("contenttypes", "0002_remove_content_type_name")] @@ -2064,4 +2063,268 @@ class Migration(migrations.Migration): }, bases=("auth.group", models.Model), ), + migrations.CreateModel( + name="NonSymRelationBase", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("field_base", models.CharField(max_length=10)), + ( + "fk", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="relationbase_set", + to="tests.nonsymrelationbase", + ), + ), + ("m2m", models.ManyToManyField(to="tests.nonsymrelationbase")), + ( + "polymorphic_ctype", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="polymorphic_%(app_label)s.%(class)s_set+", + to="contenttypes.contenttype", + ), + ), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + ), + migrations.CreateModel( + name="ParentModel", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("name", models.CharField(max_length=10)), + ( + "polymorphic_ctype", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="polymorphic_%(app_label)s.%(class)s_set+", + to="contenttypes.contenttype", + ), + ), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + ), + migrations.CreateModel( + name="PlainModel", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ( + "relation", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="tests.parentmodel" + ), + ), + ], + ), + migrations.CreateModel( + name="PlainModelWithM2M", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("field1", models.CharField(max_length=10)), + ("m2m", models.ManyToManyField(to="tests.parentmodel")), + ], + ), + migrations.CreateModel( + name="AltChildModel", + fields=[ + ( + "parentmodel_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="tests.parentmodel", + ), + ), + ("other_name", models.CharField(max_length=10)), + ( + "link_on_altchild", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to="tests.plaina", + ), + ), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + bases=("tests.parentmodel",), + ), + migrations.CreateModel( + name="ChildModel", + fields=[ + ( + "parentmodel_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="tests.parentmodel", + ), + ), + ("other_name", models.CharField(max_length=10)), + ( + "link_on_child", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to="tests.modelextraexternal", + ), + ), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + bases=("tests.parentmodel",), + ), + migrations.CreateModel( + name="NonSymRelationA", + fields=[ + ( + "nonsymrelationbase_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="tests.nonsymrelationbase", + ), + ), + ("field_a", models.CharField(max_length=10)), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + bases=("tests.nonsymrelationbase",), + ), + migrations.CreateModel( + name="NonSymRelationB", + fields=[ + ( + "nonsymrelationbase_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="tests.nonsymrelationbase", + ), + ), + ("field_b", models.CharField(max_length=10)), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + bases=("tests.nonsymrelationbase",), + ), + migrations.CreateModel( + name="NonSymRelationBC", + fields=[ + ( + "nonsymrelationbase_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="tests.nonsymrelationbase", + ), + ), + ("field_c", models.CharField(max_length=10)), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + bases=("tests.nonsymrelationbase",), + ), + migrations.CreateModel( + name="AltChildAsBaseModel", + fields=[ + ( + "altchildmodel_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="tests.altchildmodel", + ), + ), + ("more_name", models.CharField(max_length=10)), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + bases=("tests.altchildmodel",), + ), + migrations.CreateModel( + name="AltChildWithM2MModel", + fields=[ + ( + "altchildmodel_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="tests.altchildmodel", + ), + ), + ("m2m", models.ManyToManyField(to="tests.plaina")), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + bases=("tests.altchildmodel",), + ), ] diff --git a/polymorphic/tests/models.py b/polymorphic/tests/models.py index 76e1b626..b960ecfd 100644 --- a/polymorphic/tests/models.py +++ b/polymorphic/tests/models.py @@ -8,7 +8,11 @@ from polymorphic.managers import PolymorphicManager from polymorphic.models import PolymorphicModel -from polymorphic.query import PolymorphicQuerySet +from polymorphic.query import ( + PolymorphicQuerySet, + PolymorphicRelatedQuerySetMixin, + PolymorphicRelatedQuerySet, +) from polymorphic.showfields import ShowFieldContent, ShowFieldType, ShowFieldTypeAndContent @@ -348,6 +352,8 @@ class NonProxyChild(ProxyBase): # base -> proxy -> real models + + class ProxiedBase(ShowFieldTypeAndContent, PolymorphicModel): name = models.CharField(max_length=10) @@ -496,3 +502,64 @@ class SubclassSelectorProxyConcreteModel(SubclassSelectorProxyModel): class NonPolymorphicParent(PolymorphicModel, Group): test = models.CharField(max_length=22, default="test_non_polymorphic_parent") + + +class NonSymRelationBase(PolymorphicModel): + field_base = models.CharField(max_length=10) + fk = models.ForeignKey( + "self", on_delete=models.CASCADE, null=True, related_name="relationbase_set" + ) + m2m = models.ManyToManyField("self", symmetrical=False) + + +class NonSymRelationA(NonSymRelationBase): + field_a = models.CharField(max_length=10) + + +class NonSymRelationB(NonSymRelationBase): + field_b = models.CharField(max_length=10) + + +class NonSymRelationBC(NonSymRelationBase): + field_c = models.CharField(max_length=10) + + +class CustomPolySupportingQuerySet(PolymorphicRelatedQuerySetMixin, models.QuerySet): + pass + + +class ParentModel(PolymorphicModel): + name = models.CharField(max_length=10) + + +class ChildModel(ParentModel): + other_name = models.CharField(max_length=10) + link_on_child = models.ForeignKey( + ModelExtraExternal, on_delete=models.CASCADE, null=True, related_name="+" + ) + + +class AltChildModel(ParentModel): + other_name = models.CharField(max_length=10) + link_on_altchild = models.ForeignKey( + PlainA, on_delete=models.CASCADE, null=True, related_name="+" + ) + + +class AltChildAsBaseModel(AltChildModel): + more_name = models.CharField(max_length=10) + + +class PlainModel(models.Model): + relation = models.ForeignKey(ParentModel, on_delete=models.CASCADE) + objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + + +class PlainModelWithM2M(models.Model): + field1 = models.CharField(max_length=10) + m2m = models.ManyToManyField(ParentModel) + objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + + +class AltChildWithM2MModel(AltChildModel): + m2m = models.ManyToManyField(PlainA) diff --git a/polymorphic/tests/test_orm.py b/polymorphic/tests/test_orm.py index 45f0746b..9ad5a754 100644 --- a/polymorphic/tests/test_orm.py +++ b/polymorphic/tests/test_orm.py @@ -2,16 +2,20 @@ import re import uuid +from unittest import expectedFailure from django.contrib.contenttypes.models import ContentType from django.db import models -from django.db.models import Case, Count, FilteredRelation, Q, When +from django.db.models import Case, Count, FilteredRelation, Q, When, F, Prefetch + from django.db.utils import IntegrityError from django.test import TransactionTestCase - -from polymorphic import query_translate +from polymorphic import compat, query_translate from polymorphic.managers import PolymorphicManager from polymorphic.models import PolymorphicTypeInvalid, PolymorphicTypeUndefined from polymorphic.tests.models import ( + AltChildAsBaseModel, + AltChildModel, + AltChildWithM2MModel, ArtProject, Base, BlogA, @@ -19,6 +23,7 @@ BlogBase, BlogEntry, BlogEntry_limit_choices_to, + ChildModel, ChildModelWithManager, CustomPkBase, CustomPkInherit, @@ -54,13 +59,20 @@ MyManagerQuerySet, NonPolymorphicParent, NonProxyChild, + NonSymRelationA, + NonSymRelationB, + NonSymRelationBase, + NonSymRelationBC, One2OneRelatingModel, One2OneRelatingModelDerived, + ParentModel, ParentModelWithManager, PlainA, PlainB, PlainC, PlainChildModelWithManager, + PlainModel, + PlainModelWithM2M, PlainMyManager, PlainMyManagerQuerySet, PlainParentModelWithManager, @@ -691,54 +703,54 @@ def test_relation_base(self): objects = RelationBase.objects.all() assert ( repr(objects[0]) - == '' + f'', ) assert ( repr(objects[1]) - == '' + f'', ) assert ( repr(objects[2]) - == '' + f'', ) assert ( repr(objects[3]) - == '' + f'', ) assert len(objects) == 4 - oa = RelationBase.objects.get(id=2) + boa = RelationBase.objects.get(id=oa.pk) assert ( - repr(oa.fk) - == '' + repr(boa.fk), + f'', ) - objects = oa.relationbase_set.all() + objects = boa.relationbase_set.all() assert ( repr(objects[0]) - == '' + f'', ) assert ( repr(objects[1]) - == '' + f'', ) assert len(objects) == 2 - ob = RelationBase.objects.get(id=3) + bob = RelationBase.objects.get(id=ob.pk) assert ( - repr(ob.fk) - == '' + repr(bob.fk), + f'', ) - oa = RelationA.objects.get() - objects = oa.m2m.all() + aoa = RelationA.objects.get() + objects = aoa.m2m.all() assert ( repr(objects[0]) - == '' + f'', ) assert ( repr(objects[1]) - == '' + f'', ) assert len(objects) == 2 @@ -1241,3 +1253,521 @@ def test_refresh_from_db_fields(self): def test_non_polymorphic_parent(self): obj = NonPolymorphicParent.objects.create() assert obj.delete() + + def test_normal_django_to_poly_related_give_poly_type(self): + obj1 = ParentModel.objects.create(name="m1") + obj2 = ChildModel.objects.create(name="m2", other_name="m2") + obj3 = ChildModel.objects.create(name="m1") + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + + ct = ContentType.objects.get_for_model(AltChildModel, for_concrete_model=True) + + with self.assertNumQueries(6): + # Queries will be + # * 1 for All PlainModels object (1) + # * 1 for each relations ParentModel (4) + # * 1 for each relations ChilModel is needed (3) + multi_q = [ + # these obj.relation values will have their proper sub type + obj.relation + for obj in PlainModel.objects.all() + ] + multi_q_types = [type(obj) for obj in multi_q] + + with self.assertNumQueries(2): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + obj.relation + for obj in PlainModel.objects.select_related("relation") + ] + grouped_q_types = [type(obj) for obj in grouped_q] + + self.assertListEqual(multi_q_types, grouped_q_types) + self.assertListEqual(grouped_q, [obj1, obj2, obj3]) + + def test_normal_django_to_poly_related_give_poly_type_using_select_related_true(self): + obj1 = ParentModel.objects.create(name="m1") + obj2 = ChildModel.objects.create(name="m2", other_name="m2") + obj3 = ChildModel.objects.create(name="m1") + obj4 = AltChildAsBaseModel.objects.create( + name="ac2", other_name="ac2name", more_name="ac2morename" + ) + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + + with self.assertNumQueries(8): + # Queries will be + # * 1 for All PlainModels object (x1) + # * 1 for each relations ParentModel (x4) + # * 1 for each relations ChildModel is needed (x2) + # * 1 for each relations AltChildAsBaseModel is needed (x1) + multi_q = [ + # these obj.relation values will have their proper sub type + obj.relation + for obj in PlainModel.objects.all() + ] + multi_q_types = [type(obj) for obj in multi_q] + + with self.assertNumQueries(3): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + # ATM: we require 1 query fro each type. Although this can + # be reduced by specifying the relations to the polymorphic + # classes. BUT this has the downside of making the query have + # a large number of joins + obj.relation + for obj in PlainModel.objects.select_related() + ] + grouped_q_types = [type(obj) for obj in grouped_q] + + self.assertListEqual(multi_q_types, grouped_q_types) + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4]) + + def test_prefetch_base_load_359(self): + obj1_1 = ModelShow1_plain.objects.create(field1="1") + obj2_1 = ModelShow2_plain.objects.create(field1="2", field2="1") + obj3_2 = ModelShow2_plain.objects.create(field1="3", field2="2") + + with self.assertNumQueries(1): + obj = ModelShow2_plain.objects.filter(pk=obj2_1.pk)[0] + _ = (obj.field1, obj.field1) + + def test_select_related_on_poly_classes(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__childmodel__link_on_child", + "relation__altchildmodel__link_on_altchild", + ).order_by("pk") + ) + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "ac1") + self.assertEqual(obj_list[3].relation.name, "ac2") + obj_list[1].relation.link_on_child + obj_list[2].relation.link_on_altchild + obj_list[3].relation.link_on_altchild + + def test_select_related_on_poly_classes_supports_multi_level_inheritance(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_acab2 = AltChildAsBaseModel.objects.create( + name="ac2ab", + other_name="acab2name", + more_name="acab2morename", + link_on_altchild=plain_a_obj_2, + ) + + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_acab2) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__childmodel__link_on_child", + "relation__altchildmodel__link_on_altchild", + "relation__altchildmodel__altchildasbasemodel__link_on_altchild", + ).order_by("pk") + ) + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "ac1") + self.assertEqual(obj_list[3].relation.name, "ac2ab") + self.assertEqual(obj_list[3].relation.more_name, "acab2morename") + obj_list[1].relation.link_on_child + obj_list[2].relation.link_on_altchild + obj_list[3].relation.link_on_altchild + + def test_select_related_on_poly_classes_with_modelname(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_acab2 = AltChildAsBaseModel.objects.create( + name="acab2", + other_name="acab2name", + more_name="acab2morename", + link_on_altchild=plain_a_obj_1, + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_acab2) + + ct = ContentType.objects.get_for_model(AltChildModel, for_concrete_model=True) + + with self.assertNumQueries(1): + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__ChildModel__link_on_child", + "relation__AltChildAsBaseModel__link_on_altchild", + ).order_by("pk") + ) + + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "acab2") + obj_list[1].relation.link_on_child + obj_list[2].relation.link_on_altchild + + def test_prefetch_related_from_basepoly(self): + obja1 = NonSymRelationA.objects.create(field_a="fa1", field_base="fa1") + obja2 = NonSymRelationA.objects.create(field_a="fa2", field_base="fa2") + objb1 = NonSymRelationB.objects.create(field_b="fb1", field_base="fb1") + objbc1 = NonSymRelationBC.objects.create(field_c="fbc1", field_base="fbc1") + + obja3 = NonSymRelationA.objects.create(field_a="fa3", field_base="fa3") + # NOTE: these are symmetric links + obja3.m2m.add(obja2) + obja3.m2m.add(objb1) + obja2.m2m.add(objbc1) + + # NOTE: prefetch content types so query asserts test data fetched. + ct = ContentType.objects.get_for_model(NonSymRelationBase, for_concrete_model=True) + + with self.assertNumQueries(10): + # query for NonSymRelationBase (base) + # query for NonSymRelationA # level 1 (base) + # query for NonSymRelationB # level 1 (base) + # query for NonSymRelationBC # level 1 (base) + # query for prefetch links (m2m) + # query for NonSymRelationA # level 2 (m2m) + # query for NonSymRelationB # level 2 (m2m) + # query for NonSymRelationBC # level 2 (m2m) + # query for prefetch links (m2m__m2m) + # query for NonSymRelationA # level 3 (m2m__m2m) + # query for NonSymRelationB # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + # query for NonSymRelationC # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + + all_objs = { + obj.pk: obj + for obj in NonSymRelationBase.objects.prefetch_related("m2m", "m2m__m2m") + } + + with self.assertNumQueries(0): + relations = {obj.pk: set(obj.m2m.all()) for obj in all_objs.values()} + + with self.assertNumQueries(0): + sub_relations = {a.pk: set(a.m2m.all()) for a in all_objs.get(obja3.pk).m2m.all()} + + self.assertDictEqual( + { + obja1.pk: set(), + obja2.pk: set([objbc1]), + obja3.pk: set([obja2, objb1]), + objb1.pk: set([]), + objbc1.pk: set([]), + }, + relations, + ) + + self.assertDictEqual( + { + obja2.pk: set([objbc1]), + objb1.pk: set([]), + }, + sub_relations, + ) + + def test_prefetch_related_from_subclass(self): + obja1 = NonSymRelationA.objects.create(field_a="fa1", field_base="fa1") + obja2 = NonSymRelationA.objects.create(field_a="fa2", field_base="fa2") + objb1 = NonSymRelationB.objects.create(field_b="fb1", field_base="fb1") + objbc1 = NonSymRelationBC.objects.create(field_c="fbc1", field_base="fbc1") + + obja3 = NonSymRelationA.objects.create(field_a="fa3", field_base="fa3") + # NOTE: these are symmetric links + obja3.m2m.add(obja2) + obja3.m2m.add(objb1) + obja2.m2m.add(objbc1) + + # NOTE: prefetch content types so query asserts test data fetched. + ct = ContentType.objects.get_for_model(NonSymRelationBase, for_concrete_model=True) + + with self.assertNumQueries(7): + # query for NonSymRelationA # level 1 (base) + # query for prefetch links (m2m) + # query for NonSymRelationA # level 2 (m2m) + # query for NonSymRelationB # level 2 (m2m) + # query for NonSymRelationBC # level 2 (m2m) + # query for prefetch links (m2m__m2m) + # query for NonSymRelationA # level 3 (m2m__m2m) + # query for NonSymRelationB # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + # query for NonSymRelationC # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + + all_objs = { + obj.pk: obj for obj in NonSymRelationA.objects.prefetch_related("m2m", "m2m__m2m") + } + + with self.assertNumQueries(0): + relations = {obj.pk: set(obj.m2m.all()) for obj in all_objs.values()} + + with self.assertNumQueries(0): + sub_relations = {a.pk: set(a.m2m.all()) for a in all_objs.get(obja3.pk).m2m.all()} + + self.assertDictEqual( + { + obja1.pk: set(), + obja2.pk: set([objbc1]), + obja3.pk: set([obja2, objb1]), + }, + relations, + ) + + self.assertDictEqual( + { + obja2.pk: set([objbc1]), + objb1.pk: set([]), + }, + sub_relations, + ) + + def test_select_related_field_from_polymorphic_child_class(self): + # 198 + obj_p1 = ParentModel.objects.create(name="p1") + obj_p2 = ParentModel.objects.create(name="p2") + obj_p3 = ParentModel.objects.create(name="p4") + obj_c1 = ChildModel.objects.create(name="c1", other_name="c1name") + obj_c2 = ChildModel.objects.create(name="c2", other_name="c2name") + obj_ac1 = AltChildModel.objects.create(name="ac1", other_name="ac1name") + obj_ac2 = AltChildModel.objects.create(name="ac2", other_name="ac2name") + obj_ac3 = AltChildModel.objects.create(name="ac3", other_name="ac3name") + + with self.assertNumQueries(2): + # Queries will be + # * 1 for All ParentModel object (x4 +bases of all) + # * 1 for ChildModel object (x2) + # * 0 for AltChildModel object as from select_related (x3) + all_objs = [ + obj + for obj in ParentModel.objects.select_related( + "altchildmodel", + ) + ] + + def test_select_related_field_from_polymorphic_child_class_using_modelnames_level1(self): + # 198 + obj_p1 = ParentModel.objects.create(name="p1") + obj_p2 = ParentModel.objects.create(name="p2") + obj_p3 = ParentModel.objects.create(name="p4") + obj_c1 = ChildModel.objects.create(name="c1", other_name="c1name") + obj_c2 = ChildModel.objects.create(name="c2", other_name="c2name") + obj_ac1 = AltChildModel.objects.create(name="ac1", other_name="ac1name") + obj_ac2 = AltChildModel.objects.create(name="ac2", other_name="ac2name") + obj_ac3 = AltChildModel.objects.create(name="ac3", other_name="ac3name") + + with self.assertNumQueries(2): + # Queries will be + # * 1 for All ParentModel object (x4 +bases of all) + # * 1 for ChildModel object (x2) + # * 0 for AltChildModel object as from select_related (x3) + all_objs = [ + obj + for obj in ParentModel.objects.select_related( + "AltChildModel", + ) + ] + + def test_select_related_field_from_polymorphic_child_class_using_modelnames_multi_level(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + + obj_p1 = ParentModel.objects.create(name="p1") + obj_acab2 = AltChildAsBaseModel.objects.create( + name="acab2", + other_name="acab2name", + more_name="acab2morename", + link_on_altchild=plain_a_obj_1, + ) + obj_c1 = ChildModel.objects.create(name="c1", other_name="c1name") + obj_ac3 = ChildModel.objects.create(name="c2", other_name="c3name") + + # NOTE: prefetch content types so query asserts test data fetched. + ct = ContentType.objects.get_for_model(AltChildModel, for_concrete_model=True) + + with self.assertNumQueries(2): + # Queries will be + # * 1 for All ParentModel object (x4 +bases of all) + # * 1 for ChildModel object (x1) + # * 0 for AltChildAsBaseModel object as from select_related (x1) + # * 0 for AltChildModel object as part of select_related form + # AltChildAsBaseModel (x1) + all_objs = [obj for obj in ParentModel.objects.select_related("AltChildAsBaseModel")] + + def test_prefetch_object_is_supported(self): + b1 = RelatingModel.objects.create() + b2 = RelatingModel.objects.create() + + rel1 = Model2A.objects.create(field1="A1") + rel2 = Model2B.objects.create(field1="A2", field2="B2") + + b1.many2many.add(rel1) + b2.many2many.add(rel2) + + rel2.delete(keep_parents=True) + + qs = RelatingModel.objects.order_by("pk").prefetch_related( + Prefetch("many2many", queryset=Model2A.objects.all(), to_attr="poly"), + Prefetch("many2many", queryset=Model2A.objects.non_polymorphic(), to_attr="non_poly"), + ) + + objects = list(qs) + self.assertEqual(len(objects[0].poly), 1) + + # derived object was not fetched + self.assertEqual(len(objects[1].poly), 0) + + # base object always found + self.assertEqual(len(objects[0].non_poly), 1) + self.assertEqual(len(objects[1].non_poly), 1) + + def test_select_related_on_poly_classes_preserves_on_relations_annotations(self): + b1 = RelatingModel.objects.create() + b2 = RelatingModel.objects.create() + b3 = RelatingModel.objects.create() + + rel1 = Model2A.objects.create(field1="A1") + rel2 = Model2B.objects.create(field1="A2", field2="B2") + + b1.many2many.add(rel1) + b2.many2many.add(rel2) + b3.many2many.add(rel2) + + qs = RelatingModel.objects.order_by("pk").prefetch_related( + Prefetch( + "many2many", + queryset=Model2A.objects.annotate(Count("relatingmodel")), + to_attr="poly", + ) + ) + + objects = list(qs) + self.assertEqual(objects[0].poly[0].relatingmodel__count, 1) + self.assertEqual(objects[1].poly[0].relatingmodel__count, 2) + self.assertEqual(objects[2].poly[0].relatingmodel__count, 2) + + @expectedFailure + def test_prefetch_loading_relation_only_on_some_poly_model(self): + plain_a_obj_1 = PlainA.objects.create(field1="p1") + plain_a_obj_2 = PlainA.objects.create(field1="p2") + plain_a_obj_3 = PlainA.objects.create(field1="p3") + plain_a_obj_4 = PlainA.objects.create(field1="p4") + plain_a_obj_5 = PlainA.objects.create(field1="p5") + + ac_m2m_obj = AltChildWithM2MModel.objects.create( + other_name="o1", + ) + ac_m2m_obj.m2m.set([plain_a_obj_1, plain_a_obj_2, plain_a_obj_3]) + + cm_1 = ChildModel.objects.create(other_name="c1") + cm_2 = ChildModel.objects.create(other_name="c2") + cm_3 = ChildModel.objects.create(other_name="c3") + + acm_1 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_4) + acm_2 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_5) + + pm_1 = PlainModelWithM2M.objects.create(field1="pm1") + pm_2 = PlainModelWithM2M.objects.create(field1="pm2") + + pm_1.m2m.set([cm_1, cm_2]) + pm_2.m2m.set( + [ + cm_3, + ] + ) + + # NOTE: prefetch content types so query asserts test data fetched. + ct = ContentType.objects.get_for_model(ParentModel, for_concrete_model=True) + + pm_2.m2m.set([ac_m2m_obj]) + with self.assertNumQueries(4): + # query for PlainModelWithM2M # level 1 (base) + # query for prefetch links (m2m) + # query for ChildModel # level 2 (m2m) + # query for AltChildWithM2MModel # level 2 (m2m) + qs = PlainModelWithM2M.objects.all() + qs = qs.prefetch_related("m2m__altchildmodel__altchildWithm2mmodel__m2m") + all_objs = list(qs) + + @expectedFailure + def test_prefetch_loading_relation_only_on_some_poly_model_using_modelnames(self): + plain_a_obj_1 = PlainA.objects.create(field1="p1") + plain_a_obj_2 = PlainA.objects.create(field1="p2") + plain_a_obj_3 = PlainA.objects.create(field1="p3") + plain_a_obj_4 = PlainA.objects.create(field1="p4") + plain_a_obj_5 = PlainA.objects.create(field1="p5") + + ac_m2m_obj = AltChildWithM2MModel.objects.create( + other_name="o1", + ) + ac_m2m_obj.m2m.set([plain_a_obj_1, plain_a_obj_2, plain_a_obj_3]) + + cm_1 = ChildModel.objects.create(other_name="c1") + cm_2 = ChildModel.objects.create(other_name="c2") + cm_3 = ChildModel.objects.create(other_name="c3") + + acm_1 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_4) + acm_2 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_5) + + pm_1 = PlainModelWithM2M.objects.create(field1="pm1") + pm_2 = PlainModelWithM2M.objects.create(field1="pm2") + + pm_1.m2m.set([cm_1, cm_2]) + pm_2.m2m.set( + [ + cm_3, + ] + ) + + # NOTE: prefetch content types so query asserts test data fetched. + ct = ContentType.objects.get_for_model(ParentModel, for_concrete_model=True) + + pm_2.m2m.set([ac_m2m_obj]) + with self.assertNumQueries(4): + # query for PlainModelWithM2M # level 1 (base) + # query for prefetch links (m2m) + # query for ChildModel # level 2 (m2m) + # query for AltChildWithM2MModel # level 2 (m2m) + qs = PlainModelWithM2M.objects.all() + qs = qs.prefetch_related("m2m__AltChildWithM2MModel__m2m") + all_objs = list(qs)