Skip to content

[WIP] Atlas search lookups #325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 36 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
1838811
Create django_mongodb_backend.expressions package
timgraham Jun 25, 2025
a8b1c03
First approach.
WaVEV Jun 16, 2025
e8dce30
Add SearchExpressions
WaVEV Jun 22, 2025
7c40032
Add test
WaVEV Jun 24, 2025
6be0799
Refactor.
WaVEV Jun 26, 2025
3057c01
Add search index test
WaVEV Jun 29, 2025
2768e24
Add moreLikeThis lookup.
WaVEV Jun 30, 2025
228ea1f
CombinedSearchExpression
WaVEV Jul 1, 2025
c39ab78
Add vector search expression.
WaVEV Jul 2, 2025
dfaefc7
Add combinable operators.
WaVEV Jul 5, 2025
9249e05
Edits
WaVEV Jul 6, 2025
c1e9493
Add __str__ method
WaVEV Jul 6, 2025
a32a8de
Add combined expressions test.
WaVEV Jul 6, 2025
4845480
Add vector search test.
WaVEV Jul 7, 2025
152aa46
Add combinable test
WaVEV Jul 7, 2025
d19aa10
Refactor
WaVEV Jul 7, 2025
995a6b1
Remove unused parameter
WaVEV Jul 7, 2025
9a1543c
Fix unit test
WaVEV Jul 10, 2025
2df3729
Move search expression to expression/search
WaVEV Jul 10, 2025
9efd0eb
Improve unit test.
WaVEV Jul 10, 2025
1131111
Add expression wrap to parameters.
WaVEV Jul 12, 2025
7ce99b5
Adding source and set source.
WaVEV Jul 12, 2025
aaadb9d
Edits.
WaVEV Jul 12, 2025
aaeea0d
Edits
WaVEV Jul 12, 2025
c351f6b
Resolve value as direct value.
WaVEV Jul 12, 2025
e6660f9
Add vibe docstrings to MongoDB Atlas search expressions.
WaVEV Jul 12, 2025
318b010
Fix invalid operation.
WaVEV Jul 12, 2025
62fb3ea
Refactor utils function
WaVEV Jul 12, 2025
fd08772
Add skip flag.
WaVEV Jul 12, 2025
bbfd2b6
Edits
WaVEV Jul 12, 2025
4f99a4d
Add search text lookup.
WaVEV Jul 12, 2025
d4cbe18
Remove unused method.
WaVEV Jul 12, 2025
a467a57
Edits.
WaVEV Jul 12, 2025
a17102c
Fix replacements.
WaVEV Jul 13, 2025
8764f42
Edits.
WaVEV Jul 14, 2025
47bc62b
Edits.
WaVEV Jul 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ repos:
rev: "v2.2.6"
hooks:
- id: codespell
args: ["-L", "nin"]
args: ["-L", "nin", "-L", "searchin"]
2 changes: 1 addition & 1 deletion django_mongodb_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .aggregates import register_aggregates # noqa: E402
from .checks import register_checks # noqa: E402
from .expressions import register_expressions # noqa: E402
from .expressions.builtins import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
from .indexes import register_indexes # noqa: E402
Expand Down
142 changes: 119 additions & 23 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

from .expressions.search import SearchExpression, SearchVector
from .query import MongoQuery, wrap_database_errors


Expand All @@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs):
# A list of OrderBy objects for this query.
self.order_by_objs = None
self.subqueries = []
# Atlas search calls
self.search_pipeline = []

def _get_group_alias_column(self, expr, annotation_group_idx):
"""Generate a dummy field for use in the ids fields in $group."""
Expand All @@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias):
column_target.set_attributes_from_name(alias)
return Col(self.collection_name, column_target)

def _get_replace_expr(self, sub_expr, group, alias):
    """
    Register `sub_expr` in `group` under `alias` and return the expression
    that should stand in for it in the rest of the query.

    `group` is mutated in place: `group[alias]` receives the MQL generated
    for `sub_expr` (wrapped in `$addToSet` when the expression is flagged
    as distinct). The returned expression refers to a column named `alias`
    on this compiler's collection.
    """
    # Build a column called `alias` whose field is a clone of the
    # sub-expression's output field.
    target_field = sub_expr.output_field.clone()
    target_field.db_column = alias
    target_field.set_attributes_from_name(alias)
    alias_col = Col(self.collection_name, target_field)
    # Not every expression passed in defines `distinct`, hence getattr().
    if not getattr(sub_expr, "distinct", False):
        group[alias] = sub_expr.as_mql(self, self.connection)
        replacement = alias_col
    else:
        # If the expression should return distinct values, use $addToSet
        # to deduplicate, then re-apply a copy of the expression on top of
        # the aliased column.
        inner_mql = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
        group[alias] = {"$addToSet": inner_mql}
        replacement = sub_expr.copy()
        replacement.set_source_expressions([alias_col, None])
    # Count must return 0 rather than null.
    if isinstance(sub_expr, Count):
        replacement = Coalesce(replacement, 0)
    # Variance = StdDev^2
    if isinstance(sub_expr, Variance):
        replacement = Power(replacement, 2)
    return replacement

def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
"""
Prepare expressions for the aggregation pipeline.
Expand All @@ -80,29 +106,33 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
alias = (
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
)
column_target = sub_expr.output_field.clone()
column_target.db_column = alias
column_target.set_attributes_from_name(alias)
inner_column = Col(self.collection_name, column_target)
if sub_expr.distinct:
# If the expression should return distinct values, use
# $addToSet to deduplicate.
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
group[alias] = {"$addToSet": rhs}
replacing_expr = sub_expr.copy()
replacing_expr.set_source_expressions([inner_column, None])
else:
group[alias] = sub_expr.as_mql(self, self.connection)
replacing_expr = inner_column
# Count must return 0 rather than null.
if isinstance(sub_expr, Count):
replacing_expr = Coalesce(replacing_expr, 0)
# Variance = StdDev^2
if isinstance(sub_expr, Variance):
replacing_expr = Power(replacing_expr, 2)
replacements[sub_expr] = replacing_expr
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
return replacements, group

def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
    """
    Add a replacement for every search expression found in `expression`.

    `replacements` is updated in place and nothing is returned. Expressions
    that already have an entry are left untouched. `search_idx` is an
    iterator of integers used to number the generated
    `__search_expr.search<N>` aliases.
    """
    # _get_replace_expr() needs a dict to write the generated MQL into; the
    # contents are not used after this method returns.
    scratch = {}
    for search_expr in self._get_search_expressions(expression):
        if search_expr in replacements:
            continue
        field_path = f"__search_expr.search{next(search_idx)}"
        replacements[search_expr] = self._get_replace_expr(search_expr, scratch, field_path)

def _prepare_search_query_for_aggregation_pipeline(self, order_by):
    """
    Collect replacements for the search expressions used by the query.

    The selected annotations, the `order_by` expressions, `self.having`
    and the where node from `get_where()` are scanned, in that order. All
    of them share one counter (starting at 1) and one mapping, so an
    expression that shows up in several places gets a single replacement.

    Return the mapping of search expression -> replacement expression.
    """
    found = {}
    counter = itertools.count(start=1)

    def collect(expression):
        self._prepare_search_expressions_for_pipeline(expression, counter, found)

    for annotation in self.query.annotation_select.values():
        collect(annotation)
    # `order_by` holds pairs; only the first item of each is an expression.
    for ordering, _ in order_by:
        collect(ordering)
    collect(self.having)
    collect(self.get_where())
    return found

def _prepare_annotations_for_aggregation_pipeline(self, order_by):
"""Prepare annotations for the aggregation pipeline."""
replacements = {}
Expand Down Expand Up @@ -207,9 +237,57 @@ def _build_aggregation_pipeline(self, ids, group):
pipeline.append({"$unset": "_id"})
return pipeline

def _compound_searches_queries(self, search_replacements):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to preserve this function for the future: we will probably want to support hybrid search, and this part of the code could be useful then. I know it looks odd to check that the replacements have length 1 and then iterate over them. Also, the exception could be raised before this point. Let me know if you want me to refactor this code.

if not search_replacements:
return []
if len(search_replacements) > 1:
has_search = any(not isinstance(search, SearchVector) for search in search_replacements)
has_vector_search = any(
isinstance(search, SearchVector) for search in search_replacements
)
if has_search and has_vector_search:
raise ValueError(
"Cannot combine a `$vectorSearch` with a `$search` operator. "
"If you need to combine them, consider restructuring your query logic or "
"running them as separate queries."
)
if not has_search:
raise ValueError(
"Cannot combine two `$vectorSearch` operator. "
"If you need to combine them, consider restructuring your query logic or "
"running them as separate queries."
)
raise ValueError(
"Only one $search operation is allowed per query. "
f"Received {len(search_replacements)} search expressions. "
"To combine multiple search expressions, use either a CompoundExpression for "
"fine-grained control or CombinedSearchExpression for simple logical combinations."
)
pipeline = []
for search, result_col in search_replacements.items():
score_function = (
"vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
)
pipeline.extend(
[
search.as_mql(self, self.connection),
{
"$addFields": {
result_col.as_mql(self, self.connection, as_path=True): {
"$meta": score_function
}
}
},
]
)
return pipeline

def pre_sql_setup(self, with_col_aliases=False):
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
all_replacements = {**search_replacements, **group_replacements}
self.search_pipeline = self._compound_searches_queries(search_replacements)
# query.group_by is either:
# - None: no GROUP BY
# - True: group by select fields
Expand All @@ -234,6 +312,9 @@ def pre_sql_setup(self, with_col_aliases=False):
for target, expr in self.query.annotation_select.items()
}
self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by]
if (where := self.get_where()) and search_replacements:
where = where.replace_expressions(search_replacements)
self.set_where(where)
return extra_select, order_by, group_by

def execute_sql(
Expand Down Expand Up @@ -557,10 +638,16 @@ def get_lookup_pipeline(self):
return result

def _get_aggregate_expressions(self, expr):
    """Return the Aggregate expressions found while walking `expr`."""
    return self._get_all_expressions_of_type(expr, Aggregate)

def _get_search_expressions(self, expr):
    """Return the SearchExpression instances found while walking `expr`."""
    return self._get_all_expressions_of_type(expr, SearchExpression)

def _get_all_expressions_of_type(self, expr, target_type):
stack = [expr]
while stack:
expr = stack.pop()
if isinstance(expr, Aggregate):
if isinstance(expr, target_type):
yield expr
elif hasattr(expr, "get_source_expressions"):
stack.extend(expr.get_source_expressions())
Expand Down Expand Up @@ -629,6 +716,9 @@ def _get_ordering(self):
def get_where(self):
    """Return `self.where` if it has been set, else `self.query.where`."""
    return getattr(self, "where", self.query.where)

def set_where(self, value):
    """Store `value` on the compiler as `self.where`."""
    self.where = value

def explain_query(self):
# Validate format (none supported) and options.
options = self.connection.ops.explain_query_prefix(
Expand Down Expand Up @@ -715,6 +805,9 @@ def check_query(self):
def get_where(self):
    """Return the where node stored on the query."""
    return self.query.where

def set_where(self, value):
    """Replace the where node stored on the query with `value`."""
    self.query.where = value

@cached_property
def collection_name(self):
return self.query.base_table
Expand Down Expand Up @@ -786,6 +879,9 @@ def check_query(self):
def get_where(self):
    """Return the where node stored on the query."""
    return self.query.where

def set_where(self, value):
    """Replace the where node stored on the query with `value`."""
    self.query.where = value

@cached_property
def collection_name(self):
return self.query.base_table
Expand Down
4 changes: 4 additions & 0 deletions django_mongodb_backend/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def _destroy_test_db(self, test_database_name, verbosity):

for collection in self.connection.introspection.table_names():
if not collection.startswith("system."):
if self.connection.features.supports_atlas_search:
db_collection = self.connection.database.get_collection(collection)
for search_indexes in db_collection.list_search_indexes():
db_collection.drop_search_index(search_indexes["name"])
self.connection.database.drop_collection(collection)

def create_test_db(self, *args, **kwargs):
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from django.db.models.sql import Query

from .query_utils import process_lhs
from ..query_utils import process_lhs


def case(self, compiler, connection):
Expand Down Expand Up @@ -53,7 +53,7 @@ def case(self, compiler, connection):
}


def col(self, compiler, connection): # noqa: ARG001
def col(self, compiler, connection, as_path=False): # noqa: ARG001
# If the column is part of a subquery and belongs to one of the parent
# queries, it will be stored for reference using $let in a $lookup stage.
# If the query is built with `alias_cols=False`, treat the column as
Expand All @@ -71,7 +71,7 @@ def col(self, compiler, connection): # noqa: ARG001
# Add the column's collection's alias for columns in joined collections.
has_alias = self.alias and self.alias != compiler.collection_name
prefix = f"{self.alias}." if has_alias else ""
return f"${prefix}{self.target.column}"
return f"{prefix}{self.target.column}" if as_path else f"${prefix}{self.target.column}"


def col_pairs(self, compiler, connection):
Expand Down
Loading