Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update to Django 5.2 #199

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .evergreen/run-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ python -m pip install -U pip
pip install -e .

# Install django and test dependencies
git clone --branch mongodb-5.1.x https://github.com/mongodb-forks/django django_repo
git clone --branch mongodb-5.2.x https://github.com/mongodb-forks/django django_repo
pushd django_repo/tests/
pip install -e ..
pip install -r requirements/py3.txt
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
uses: actions/checkout@v4
with:
repository: 'mongodb-forks/django'
ref: 'mongodb-5.1.x'
ref: 'mongodb-5.2.x'
path: 'django_repo'
persist-credentials: false
- name: Install system packages for Django's Python test dependencies
Expand Down
2 changes: 1 addition & 1 deletion django_mongodb_backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "5.1.0b1.dev0"
__version__ = "5.2.0a0"

# Check Django compatibility before other imports which may fail if the
# wrong version of Django is installed.
Expand Down
47 changes: 29 additions & 18 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

from .base import Cursor
from .query import MongoQuery, wrap_database_errors


Expand Down Expand Up @@ -403,12 +402,6 @@ def columns(self):
columns = (
self.get_default_columns(select_mask) if self.query.default_cols else self.query.select
)
# Populate QuerySet.select_related() data.
related_columns = []
if self.query.select_related:
self.get_related_selections(related_columns, select_mask)
if related_columns:
related_columns, _ = zip(*related_columns, strict=True)

annotation_idx = 1

Expand All @@ -427,11 +420,28 @@ def project_field(column):
annotation_idx += 1
return target, column

return (
tuple(map(project_field, columns))
+ tuple(self.annotations.items())
+ tuple(map(project_field, related_columns))
)
selected = []
if self.query.selected is None:
selected = [
*(project_field(col) for col in columns),
*self.annotations.items(),
]
else:
for expression in self.query.selected.values():
# Reference to an annotation.
if isinstance(expression, str):
alias, expression = expression, self.annotations[expression]
# Reference to a column.
elif isinstance(expression, int):
alias, expression = project_field(columns[expression])
selected.append((alias, expression))
# Populate QuerySet.select_related() data.
related_columns = []
if self.query.select_related:
self.get_related_selections(related_columns, select_mask)
if related_columns:
related_columns, _ = zip(*related_columns, strict=True)
return tuple(selected) + tuple(map(project_field, related_columns))

@cached_property
def base_table(self):
Expand Down Expand Up @@ -478,7 +488,11 @@ def get_combinator_queries(self):
# If the columns list is limited, then all combined queries
# must have the same columns list. Set the selects defined on
# the query on all combined queries, if not already set.
if not compiler_.query.values_select and self.query.values_select:
selected = self.query.selected
if selected is not None and compiler_.query.selected is None:
compiler_.query = compiler_.query.clone()
compiler_.query.set_values(selected)
elif not compiler_.query.values_select and self.query.values_select:
compiler_.query = compiler_.query.clone()
compiler_.query.set_values(
(
Expand Down Expand Up @@ -690,15 +704,12 @@ def collection_name(self):

class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
def execute_sql(self, result_type=MULTI):
cursor = Cursor()
try:
query = self.build_query()
except EmptyResultSet:
rowcount = 0
return 0
else:
rowcount = query.delete()
cursor.rowcount = rowcount
return cursor
return query.delete()

def check_query(self):
super().check_query()
Expand Down
20 changes: 19 additions & 1 deletion django_mongodb_backend/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Col,
CombinedExpression,
Exists,
ExpressionList,
ExpressionWrapper,
F,
NegatedExpression,
Expand All @@ -24,6 +25,8 @@
)
from django.db.models.sql import Query

from .query_utils import process_lhs


def case(self, compiler, connection):
case_parts = []
Expand Down Expand Up @@ -83,6 +86,10 @@ def expression_wrapper(self, compiler, connection):
return self.expression.as_mql(compiler, connection)


def expression_list(self, compiler, connection):
return process_lhs(self, compiler, connection)


def f(self, compiler, connection): # noqa: ARG001
return f"${self.name}"

Expand Down Expand Up @@ -150,7 +157,11 @@ def ref(self, compiler, connection): # noqa: ARG001
if isinstance(self.source, Col) and self.source.alias != compiler.collection_name
else ""
)
return f"${prefix}{self.refs}"
if hasattr(self, "ordinal"):
refs, _ = compiler.columns[self.ordinal - 1]
else:
refs = self.refs
return f"${prefix}{refs}"


def star(self, compiler, connection): # noqa: ARG001
Expand All @@ -175,6 +186,12 @@ def when(self, compiler, connection):

def value(self, compiler, connection): # noqa: ARG001
value = self.value
output_field = self._output_field_or_none
if output_field is not None:
if self.for_save:
value = output_field.get_db_prep_save(value, connection=connection)
else:
value = output_field.get_db_prep_value(value, connection=connection)
if isinstance(value, int):
# Wrap numbers in $literal to prevent ambiguity when Value appears in
# $project.
Expand Down Expand Up @@ -202,6 +219,7 @@ def register_expressions():
Col.as_mql = col
CombinedExpression.as_mql = combined_expression
Exists.as_mql = exists
ExpressionList.as_mql = expression_list
ExpressionWrapper.as_mql = expression_wrapper
F.as_mql = f
NegatedExpression.as_mql = negated_expression
Expand Down
28 changes: 26 additions & 2 deletions django_mongodb_backend/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
greatest_least_ignores_nulls = True
has_json_object_function = False
has_native_json_field = True
rounds_to_even = True
supports_boolean_expr_in_select_clause = True
supports_collation_on_charfield = False
supports_column_check_constraints = False
Expand Down Expand Up @@ -56,8 +57,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# Pattern lookups that use regexMatch don't work on JSONField:
# Unsupported conversion from array to string in $convert
"model_fields.test_jsonfield.TestQuerying.test_icontains",
# MongoDB gives ROUND(365, -1)=360 instead of 370 like other databases.
"db_functions.math.test_round.RoundTests.test_integer_with_negative_precision",
# Truncating in another timezone doesn't work becauase MongoDB converts
# the result back to UTC.
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_func_with_timezone",
Expand Down Expand Up @@ -88,6 +87,19 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# of $setIsSubset must be arrays. Second argument is of type: null"
# https://jira.mongodb.org/browse/SERVER-99186
"model_fields_.test_arrayfield.QueryingTests.test_contained_by_subquery",
# JSONArray not implemented.
"db_functions.json.test_json_array.JSONArrayTests",
# Some usage of prefetch_related() raises "ColPairs is not supported."
"known_related_objects.tests.ExistingRelatedInstancesTests.test_one_to_one_multi_prefetch_related",
"known_related_objects.tests.ExistingRelatedInstancesTests.test_one_to_one_prefetch_related",
"prefetch_related.tests.DeprecationTests.test_prefetch_one_level_fallback",
"prefetch_related.tests.MultiDbTests.test_using_is_honored_fkey",
"prefetch_related.tests.MultiDbTests.test_using_is_honored_inheritance",
"prefetch_related.tests.NestedPrefetchTests.test_nested_prefetch_is_not_overwritten_by_related_object",
"prefetch_related.tests.NullableTest.test_prefetch_nullable",
"prefetch_related.tests.Ticket19607Tests.test_bug",
# {'$project': {'name': Decimal128('1')} is broken? (gives None)
"expressions.tests.ValueTests.test_output_field_decimalfield",
}
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
_django_test_expected_failures_bitwise = {
Expand All @@ -112,6 +124,7 @@ def django_test_expected_failures(self):
# bson.errors.InvalidDocument: cannot encode object:
# <django.db.models.expressions.DatabaseDefault
"basic.tests.ModelInstanceCreationTests.test_save_primary_with_db_default",
"basic.tests.ModelInstanceCreationTests.test_save_primary_with_falsey_db_default",
"constraints.tests.UniqueConstraintTests.test_database_default",
"field_defaults.tests.DefaultTests",
"migrations.test_operations.OperationTests.test_add_field_both_defaults",
Expand Down Expand Up @@ -194,9 +207,13 @@ def django_test_expected_failures(self):
"prefetch_related.tests.Ticket21410Tests",
"queryset_pickle.tests.PickleabilityTestCase.test_pickle_prefetch_related_with_m2m_and_objects_deletion",
"serializers.test_json.JsonSerializerTestCase.test_serialize_prefetch_related_m2m",
"serializers.test_json.JsonSerializerTestCase.test_serialize_prefetch_related_m2m_with_natural_keys",
"serializers.test_jsonl.JsonlSerializerTestCase.test_serialize_prefetch_related_m2m",
"serializers.test_jsonl.JsonlSerializerTestCase.test_serialize_prefetch_related_m2m_with_natural_keys",
"serializers.test_xml.XmlSerializerTestCase.test_serialize_prefetch_related_m2m",
"serializers.test_xml.XmlSerializerTestCase.test_serialize_prefetch_related_m2m_with_natural_keys",
"serializers.test_yaml.YamlSerializerTestCase.test_serialize_prefetch_related_m2m",
"serializers.test_yaml.YamlSerializerTestCase.test_serialize_prefetch_related_m2m_with_natural_keys",
},
"AutoField not supported.": {
"bulk_create.tests.BulkCreateTests.test_bulk_insert_nullable_fields",
Expand Down Expand Up @@ -599,6 +616,13 @@ def django_test_expected_failures(self):
"foreign_object.tests.MultiColumnFKTests",
"foreign_object.tests.TestExtraJoinFilterQ",
},
"Tuple lookups are not supported.": {
"foreign_object.test_tuple_lookups.TupleLookupsTests",
},
"ColPairs is not supported.": {
# 'ColPairs' object has no attribute 'as_mql'
"auth_tests.test_views.CustomUserCompositePrimaryKeyPasswordResetTest",
},
"Custom lookups are not supported.": {
"custom_lookups.tests.BilateralTransformTests",
"custom_lookups.tests.LookupTests.test_basic_lookup",
Expand Down
7 changes: 4 additions & 3 deletions django_mongodb_backend/lookups.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.db import NotSupportedError
from django.db.models.fields.related_lookups import In, MultiColSource, RelatedIn
from django.db.models.expressions import ColPairs
from django.db.models.fields.related_lookups import In, RelatedIn
from django.db.models.lookups import (
BuiltinLookup,
FieldGetDbPrepValueIterableMixin,
Expand Down Expand Up @@ -34,8 +35,8 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param):


def in_(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
raise NotImplementedError("MultiColSource is not supported.")
if isinstance(self.lhs, ColPairs):
raise NotImplementedError("ColPairs is not supported.")
db_rhs = getattr(self.rhs, "_db", None)
if db_rhs is not None and db_rhs != connection.alias:
raise ValueError(
Expand Down
22 changes: 0 additions & 22 deletions django_mongodb_backend/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,28 +195,6 @@ def execute_sql_flush(self, tables):
if not options.get("capped", False):
collection.delete_many({})

def prep_lookup_value(self, value, field, lookup):
"""
Perform type-conversion on `value` before using as a filter parameter.
"""
if getattr(field, "rel", None) is not None:
field = field.rel.get_related_field()
field_kind = field.get_internal_type()

if lookup in ("in", "range"):
return [
self._prep_lookup_value(subvalue, field, field_kind, lookup) for subvalue in value
]
return self._prep_lookup_value(value, field, field_kind, lookup)

def _prep_lookup_value(self, value, field, field_kind, lookup):
if value is None:
return None

if field_kind == "DecimalField":
value = self.adapt_decimalfield_value(value, field.max_digits, field.decimal_places)
return value

def explain_query_prefix(self, format=None, **options):
# Validate options.
validated_options = {}
Expand Down
5 changes: 1 addition & 4 deletions django_mongodb_backend/query_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ def process_rhs(node, compiler, connection):
value = value[0]
if hasattr(node, "prep_lookup_value_mongo"):
value = node.prep_lookup_value_mongo(value)
# No need to prepare expressions like F() objects.
if hasattr(rhs, "resolve_expression"):
return value
return connection.ops.prep_lookup_value(value, node.lhs.output_field, node.lookup_name)
return value


def regex_match(field, regex_vals, insensitive=False):
Expand Down
19 changes: 11 additions & 8 deletions tests/expressions_/test_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,44 @@
from decimal import Decimal

from bson import Decimal128
from django.db import connection
from django.db.models import Value
from django.test import SimpleTestCase


class ValueTests(SimpleTestCase):
def test_date(self):
self.assertEqual(
Value(datetime.date(2025, 1, 1)).as_mql(None, None),
Value(datetime.date(2025, 1, 1)).as_mql(None, connection),
datetime.datetime(2025, 1, 1),
)

def test_datetime(self):
self.assertEqual(
Value(datetime.datetime(2025, 1, 1, 9, 8, 7)).as_mql(None, None),
Value(datetime.datetime(2025, 1, 1, 9, 8, 7)).as_mql(None, connection),
datetime.datetime(2025, 1, 1, 9, 8, 7),
)

def test_decimal(self):
self.assertEqual(Value(Decimal("1.0")).as_mql(None, None), Decimal128("1.0"))
self.assertEqual(Value(Decimal("1.0")).as_mql(None, connection), Decimal128("1.0"))

def test_time(self):
self.assertEqual(
Value(datetime.time(9, 8, 7)).as_mql(None, None),
Value(datetime.time(9, 8, 7)).as_mql(None, connection),
datetime.datetime(1, 1, 1, 9, 8, 7),
)

def test_timedelta(self):
self.assertEqual(Value(datetime.timedelta(3600)).as_mql(None, None), 311040000000.0)
self.assertEqual(
Value(datetime.timedelta(3600)).as_mql(None, connection), {"$literal": 311040000000}
)

def test_int(self):
self.assertEqual(Value(1).as_mql(None, None), {"$literal": 1})
self.assertEqual(Value(1).as_mql(None, connection), {"$literal": 1})

def test_str(self):
self.assertEqual(Value("foo").as_mql(None, None), "foo")
self.assertEqual(Value("foo").as_mql(None, connection), "foo")

def test_uuid(self):
value = uuid.UUID(int=1)
self.assertEqual(Value(value).as_mql(None, None), "00000000000000000000000000000001")
self.assertEqual(Value(value).as_mql(None, connection), "00000000000000000000000000000001")
2 changes: 1 addition & 1 deletion tests/indexes_/test_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_composite_index(self):
{
"$and": [
{"number": {"$gte": 3}},
{"$or": [{"body": {"$gt": "test1"}}, {"body": {"$in": ["A", "B"]}}]},
{"$or": [{"body": {"$gt": "test1"}}, {"body": {"$in": ("A", "B")}}]},
]
},
)
Expand Down
8 changes: 4 additions & 4 deletions tests/model_forms_/test_embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def test_some_missing_data(self):
required id="id_title">
</div>
<div>
<fieldset>
<fieldset aria-describedby="id_publisher_error">
<legend>Publisher:</legend>
<ul class="errorlist">
<ul class="errorlist" id="id_publisher_error">
<li>Enter all required values.</li>
</ul>
<div>
Expand Down Expand Up @@ -252,9 +252,9 @@ def test_invalid_field_data(self):
maxlength="50" required id="id_title">
</div>
<div>
<fieldset>
<fieldset aria-describedby="id_publisher_error">
<legend>Publisher:</legend>
<ul class="errorlist">
<ul class="errorlist" id="id_publisher_error">
<li>Ensure this value has at most 2 characters (it has 8).</li>
</ul>
<div>
Expand Down