diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index c0cd785..8471675 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -11,37 +11,29 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] - postgres-version: ["9.4", "9.5", "9.6", "10", "11", "12", "13"] - django-version: ["1.10", "1.11", "2.0", "2.1", "2.2", "3.0", "3.1", "3.2"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + postgres-version: ["9.6", "11", "12", "13", "14", "15"] + django-version: ["2.2", "3.0", "3.1", "3.2", "4.0", "4.1", "4.2"] exclude: - # Django 3.0+ doesn't support PostgreSQL 9.4 - - django-version: "3.0" - postgres-version: "9.4" - - django-version: "3.1" - postgres-version: "9.4" - - django-version: "3.2" - postgres-version: "9.4" - - # python 3.6+ has deprecated issue with django before 1.11 - # https://stackoverflow.com/questions/41343263/provide-classcell-example-for-python-3-6-metaclass\ - - python-version: "3.7" - django-version: "1.10" - - python-version: "3.8" - django-version: "1.10" - - python-version: "3.9" - django-version: "1.10" - - python-version: "3.10" - django-version: "1.10" - - # Django before 2.1 is not compatible with python 3.10 - # as it uses collections.Iterator - - python-version: "3.10" - django-version: "2.0" - - python-version: "3.10" - django-version: "1.11" - - python-version: "3.10" - django-version: "1.10" + # Django 4.0+ doesn't support Pythhon 3.7 + - django-version: "4.0" + python-version: "3.7" + - django-version: "4.1" + python-version: "3.7" + - django-version: "4.2" + python-version: "3.7" + + # Django 4.0+ doesn't support PostgreSQL 9.6 + - django-version: "4.0" + postgres-version: "9.6" + - django-version: "4.1" + postgres-version: "9.6" + - django-version: "4.2" + postgres-version: "9.6" + + # Django 4.2+ doesn't support PostgreSQL 11 + - django-version: "4.2" + postgres-version: "11" services: postgres: diff --git a/src/django_pg_returning/compatibility.py b/src/django_pg_returning/compatibility.py index c1ffd26..b61cc98 100644 --- a/src/django_pg_returning/compatibility.py +++ b/src/django_pg_returning/compatibility.py @@ -1,5 +1,7 @@ +from collections import defaultdict + import django -from typing import Type, Optional, List +from typing import Type, Optional, List, Dict from django.db.models import Model, QuerySet, Field from django.db.models.sql import Query @@ -36,3 +38,71 @@ def get_model_fields(model, concrete=False): # type: (Type[Model], Optional[boo res = [f for f in res if getattr(f, 'concrete', True) and not getattr(f, 'many_to_many', False)] return res + + +def clear_query_ordering(query): # type: (Query) -> Query + """ + Resets query ordering. Parameters changed in django 4.0 + :param query: Query to change + :return: Resulting query + """ + attr_name = 'force_empty' if django.VERSION < (4,) else 'force' + query.clear_ordering(**{attr_name: True}) + return query + + +def prepare_insert_query_kwargs(kwargs): + """ + Prepares kwargs for InsertQuery method based on kwargs from QuerySet._insert(...) + :param kwargs: Original kwargs from QuerySet._insert(obj, fields, **kwargs) + :return: kwargs ready for InsertQuery(model, **kwargs) + """ + if django.VERSION < (2, 2): + query_kwargs = {} + elif django.VERSION < (4, 1): + query_kwargs = {'ignore_conflicts': kwargs.get('ignore_conflicts')} + else: + query_kwargs = { + 'on_conflict': kwargs.get('on_conflict'), + 'update_fields': kwargs.get('update_fields'), + 'unique_fields': kwargs.get('unique_fields') + } + + return query_kwargs + + +def get_not_deferred_fields(qs): # type: (QuerySet) -> Dict[Type[Model], List[Field]] + """ + Gets model fields for query + :param qs: QuerySet for which we get required fields + :return: A dictionary of lists {Model: [Field, Field, ...]} + """ + fields = {} + + if django.VERSION >= (4, 2): + fields = qs.query.get_select_mask() + result_fields = defaultdict(list) + for field in fields.keys(): + result_fields[field.model].append(field) + fields = result_fields + + elif django.VERSION >= (4, 1): + # Django 4.0 changed fields format + qs.query.deferred_to_data(fields) + fields = { + model: [ + model._meta.get_field(field_name) + for field_name in field_names + ] for model, field_names in fields.items() + } + + elif django.VERSION >= (1, 10): + qs.query.deferred_to_data(fields, qs._get_loaded_field_cb) + + else: + # Before django 1.10 pk fields hasn't been returned from postgres. + # In this case, I can't match bulk_create results and return values by primary key. + # So I select all data from returned results + pass + + return fields diff --git a/src/django_pg_returning/manager.py b/src/django_pg_returning/manager.py index 7c6389c..c856cf0 100644 --- a/src/django_pg_returning/manager.py +++ b/src/django_pg_returning/manager.py @@ -4,7 +4,8 @@ from django.db import transaction, models from django.db.models import sql, Field, QuerySet -from .compatibility import chain_query, get_model_fields +from .compatibility import chain_query, get_model_fields, clear_query_ordering, prepare_insert_query_kwargs, \ + get_not_deferred_fields from .queryset import ReturningQuerySet # DEPRECATED class package changed in django 1.11 @@ -31,55 +32,47 @@ def _insert(self, objs, fields, **kwargs): return QuerySet._insert(self, objs, fields, **kwargs) # Returns attname, not column. - # Before django 1.10 pk fields hasn't been returned from postgres. - # In this case, I can't match bulk_create results and return values by primary key. - # So I select all data from returned results - return_fields = self._get_fields(ignore_deferred=(django.VERSION < (1, 10))) + return_fields = self._get_fields() assert len(return_fields) == 1 and list(return_fields.keys())[0] == self.model, \ "You can't fetch relative model fields with returning operation" self._for_write = True using = kwargs.get('using', None) or self.db - query_kwargs = {} if django.VERSION < (2, 2) else {'ignore_conflicts': kwargs.get('ignore_conflicts')} + query_kwargs = prepare_insert_query_kwargs(kwargs) query = sql.InsertQuery(self.model, **query_kwargs) query.insert_values(fields, objs, raw=kwargs.get('raw')) self.model._insert_returning_cache = self._execute_sql(query, return_fields, using=using) - if django.VERSION < (3,): - if not kwargs.get('return_id', False): - return None + if kwargs.get('return_id', False): + # Django before 3.0 inserted_ids = self.model._insert_returning_cache.values_list(self.model._meta.pk.column, flat=True) if not inserted_ids: return None return list(inserted_ids) if len(inserted_ids) > 1 else inserted_ids[0] - else: - returning_fields = kwargs.get('returning_fields', None) - if returning_fields is None: - return None - columns = [f.column for f in returning_fields] + elif kwargs.get('returning_fields', None): + # Django 3.0+ + columns = [f.column for f in kwargs['returning_fields']] # In django 3.0 single result is returned if single object is returned... flat = django.VERSION < (3, 1) and len(objs) <= 1 return self.model._insert_returning_cache.values_list(*columns, flat=flat) + return None + _insert.alters_data = True _insert.queryset_only = False - def _get_fields(self, ignore_deferred=False): # type: (bool) -> Dict[models.Model: List[models.Field]] + def _get_fields(self): # type: () -> Dict[models.Model: List[models.Field]] """ Gets a dictionary of fields for each model, selected by .only() and .defer() methods - :param ignore_deferred: If set, ignores .only() and .defer() filters :return: A dictionary with model as key, fields list as value """ - fields = {} - - if not ignore_deferred: - self.query.deferred_to_data(fields, self._get_loaded_field_cb) + fields = get_not_deferred_fields(self) # No .only() or .defer() operations if not fields: @@ -140,7 +133,7 @@ def _get_returning_qs(self, query_type, values=None, **updates): query._annotations = None query.select_for_update = False query.select_related = False - query.clear_ordering(force_empty=True) + clear_query_ordering(query) return self._execute_sql(query, fields) diff --git a/tests/settings.py b/tests/settings.py index c5916f0..4c3cb4b 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -28,3 +28,5 @@ "src", "tests" ] + +DEFAULT_AUTO_FIELD = "django.db.models.AutoField"