diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a3301328..188c8f3cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,4 +81,4 @@ repos: rev: "v2.2.6" hooks: - id: codespell - args: ["-L", "nin"] + args: ["-L", "nin", "-L", "searchin"] diff --git a/django_mongodb_backend/__init__.py b/django_mongodb_backend/__init__.py index 00700421a..d21566d9c 100644 --- a/django_mongodb_backend/__init__.py +++ b/django_mongodb_backend/__init__.py @@ -8,7 +8,7 @@ from .aggregates import register_aggregates # noqa: E402 from .checks import register_checks # noqa: E402 -from .expressions import register_expressions # noqa: E402 +from .expressions.builtins import register_expressions # noqa: E402 from .fields import register_fields # noqa: E402 from .functions import register_functions # noqa: E402 from .indexes import register_indexes # noqa: E402 diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 1c727039d..3e1ccc7fc 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -17,6 +17,7 @@ from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING +from .expressions.search import SearchExpression, SearchVector from .query import MongoQuery, wrap_database_errors @@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs): # A list of OrderBy objects for this query. 
    def _get_replace_expr(self, sub_expr, group, alias):
        """
        Register the MQL for `sub_expr` in the `group` dict under `alias`
        and return the expression that stands in for `sub_expr` in the rest
        of the query: a Col pointing at the grouped/stored value.
        """
        column_target = sub_expr.output_field.clone()
        column_target.db_column = alias
        column_target.set_attributes_from_name(alias)
        inner_column = Col(self.collection_name, column_target)
        # getattr() with a default: this helper is also called with search
        # expressions, which have no `distinct` attribute.
        if getattr(sub_expr, "distinct", False):
            # If the expression should return distinct values, use
            # $addToSet to deduplicate.
            rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
            group[alias] = {"$addToSet": rhs}
            replacing_expr = sub_expr.copy()
            replacing_expr.set_source_expressions([inner_column, None])
        else:
            group[alias] = sub_expr.as_mql(self, self.connection)
            replacing_expr = inner_column
        # Count must return 0 rather than null.
        if isinstance(sub_expr, Count):
            replacing_expr = Coalesce(replacing_expr, 0)
        # Variance = StdDev^2
        if isinstance(sub_expr, Variance):
            replacing_expr = Power(replacing_expr, 2)
        return replacing_expr
    def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
        """
        Collect every Atlas Search expression nested inside `expression`
        and, for each one not already seen, add a replacement Col (which
        will hold the stage's score) to `replacements`, mutated in place.
        """
        searches = {}
        for sub_expr in self._get_search_expressions(expression):
            if sub_expr not in replacements:
                # The dotted alias stores each score under the nested
                # __search_expr document in the result.
                # NOTE(review): `searches` is populated by
                # _get_replace_expr() but never read afterwards — confirm
                # the group dict is intentionally discarded here.
                alias = f"__search_expr.search{next(search_idx)}"
                replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)

    def _prepare_search_query_for_aggregation_pipeline(self, order_by):
        """
        Scan every place a search expression may appear — annotations,
        ORDER BY, HAVING, and WHERE — and return a {search_expr: Col}
        replacement map for all of them.
        """
        replacements = {}
        annotation_group_idx = itertools.count(start=1)
        for expr in self.query.annotation_select.values():
            self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)

        for expr, _ in order_by:
            self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)

        self._prepare_search_expressions_for_pipeline(
            self.having, annotation_group_idx, replacements
        )
        self._prepare_search_expressions_for_pipeline(
            self.get_where(), annotation_group_idx, replacements
        )
        return replacements

    def _compound_searches_queries(self, search_replacements):
        """
        Build the leading pipeline stages for the single search in
        `search_replacements`: the $search/$vectorSearch stage itself,
        followed by an $addFields that stores the relevance score under
        the replacement column's path.

        Raises ValueError if more than one search operation is present.
        """
        if not search_replacements:
            return []
        if len(search_replacements) > 1:
            raise ValueError("Cannot perform more than one search operation.")
        pipeline = []
        for search, result_col in search_replacements.items():
            # $vectorSearch stages expose their score under a different
            # $meta keyword than $search stages.
            score_function = (
                "vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
            )
            pipeline.extend(
                [
                    search.as_mql(self, self.connection),
                    {
                        "$addFields": {
                            # as_mql() returns a "$"-prefixed reference;
                            # $addFields keys are plain field paths.
                            result_col.as_mql(self, self.connection).removeprefix("$"): {
                                "$meta": score_function
                            }
                        }
                    },
                ]
            )
        return pipeline
Aggregate): + if isinstance(expr, target_type): yield expr elif hasattr(expr, "get_source_expressions"): stack.extend(expr.get_source_expressions()) @@ -629,6 +695,9 @@ def _get_ordering(self): def get_where(self): return getattr(self, "where", self.query.where) + def set_where(self, value): + self.where = value + def explain_query(self): # Validate format (none supported) and options. options = self.connection.ops.explain_query_prefix( @@ -715,6 +784,9 @@ def check_query(self): def get_where(self): return self.query.where + def set_where(self, value): + self.query.where = value + @cached_property def collection_name(self): return self.query.base_table @@ -786,6 +858,9 @@ def check_query(self): def get_where(self): return self.query.where + def set_where(self, value): + self.query.where = value + @cached_property def collection_name(self): return self.query.base_table diff --git a/django_mongodb_backend/creation.py b/django_mongodb_backend/creation.py index 572e770fa..b787ad06c 100644 --- a/django_mongodb_backend/creation.py +++ b/django_mongodb_backend/creation.py @@ -22,6 +22,10 @@ def _destroy_test_db(self, test_database_name, verbosity): for collection in self.connection.introspection.table_names(): if not collection.startswith("system."): + if self.connection.features.supports_atlas_search: + db_collection = self.connection.database.get_collection(collection) + for search_indexes in db_collection.list_search_indexes(): + db_collection.drop_search_index(search_indexes["name"]) self.connection.database.drop_collection(collection) def create_test_db(self, *args, **kwargs): diff --git a/django_mongodb_backend/expressions/__init__.py b/django_mongodb_backend/expressions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/django_mongodb_backend/expressions.py b/django_mongodb_backend/expressions/builtins.py similarity index 97% rename from django_mongodb_backend/expressions.py rename to django_mongodb_backend/expressions/builtins.py index 
def cast_as_value(value):
    """Wrap a plain value in Value(); pass expressions (and None) through."""
    if value is None:
        return None
    return Value(value) if not hasattr(value, "resolve_expression") else value


def cast_as_field(path):
    """Turn a string path into an F() reference; pass expressions through."""
    return F(path) if isinstance(path, str) else path


class Operator:
    """Boolean connector used when combining Atlas Search expressions."""

    AND = "AND"
    OR = "OR"
    NOT = "NOT"

    def __init__(self, operator):
        self.operator = operator

    def __eq__(self, other):
        # Comparable both against plain strings ("AND") and against other
        # Operator instances.
        if isinstance(other, str):
            return self.operator == other
        return self.operator == other.operator

    def negate(self):
        # De Morgan: negating a conjunction gives a disjunction and vice
        # versa; any other operator (e.g. NOT) negates to itself.
        if self.operator == self.AND:
            return Operator(self.OR)
        if self.operator == self.OR:
            return Operator(self.AND)
        return Operator(self.operator)

    def __hash__(self):
        return hash(self.operator)

    def __str__(self):
        return self.operator

    def __repr__(self):
        return self.operator


class SearchCombinable:
    """Mixin providing `&`, `|`, and `~` composition for search expressions."""

    def _combine(self, other, connector):
        # Promote bare operators to single-clause CompoundExpressions so the
        # combined tree is homogeneous.
        if not isinstance(self, CompoundExpression | CombinedSearchExpression):
            lhs = CompoundExpression(must=[self])
        else:
            lhs = self
        if other and not isinstance(other, CompoundExpression | CombinedSearchExpression):
            rhs = CompoundExpression(must=[other])
        else:
            rhs = other
        return CombinedSearchExpression(lhs, connector, rhs)

    def __invert__(self):
        return self._combine(None, Operator(Operator.NOT))

    def __and__(self, other):
        return self._combine(other, Operator(Operator.AND))

    def __rand__(self, other):
        return self._combine(other, Operator(Operator.AND))

    def __or__(self, other):
        return self._combine(other, Operator(Operator.OR))

    def __ror__(self, other):
        # Bug fix: the original called self._combine(self, Operator(...), other),
        # passing three positional arguments to the two-argument _combine()
        # (TypeError on any reflected `x | search_expr`). OR is commutative,
        # so combining with `other` preserves the intended semantics.
        return self._combine(other, Operator(Operator.OR))
+ """ + + output_field = FloatField() + + def __str__(self): + cls = self.identity[0] + kwargs = dict(self.identity[1:]) + arg_str = ", ".join(f"{k}={v!r}" for k, v in kwargs.items()) + return f"{cls.__name__}({arg_str})" + + def __repr__(self): + return str(self) + + def as_sql(self, compiler, connection): + return "", [] + + def get_source_expressions(self): + return [] + + def _get_indexed_fields(self, mappings): + for field, definition in mappings.get("fields", {}).items(): + yield field + for path in self._get_indexed_fields(definition): + yield f"{field}.{path}" + + def _get_query_index(self, fields, compiler): + fields = set(fields) + for search_indexes in compiler.collection.list_search_indexes(): + mappings = search_indexes["latestDefinition"]["mappings"] + indexed_fields = set(self._get_indexed_fields(mappings)) + if mappings["dynamic"] or fields.issubset(indexed_fields): + return search_indexes["name"] + return "default" + + def search_operator(self, compiler, connection): + raise NotImplementedError + + def as_mql(self, compiler, connection): + index = self._get_query_index(self.get_search_fields(compiler, connection), compiler) + return {"$search": {**self.search_operator(compiler, connection), "index": index}} + + +class SearchAutocomplete(SearchExpression): + """ + Atlas Search expression that matches input using the **autocomplete** operator. + + This expression enables autocomplete behavior by querying against a field + indexed as `"type": "autocomplete"` in MongoDB Atlas. It can be used in + `filter()`, `annotate()` or any context that accepts a Django expression. + + Example: + SearchAutocomplete("title", "harry", fuzzy={"maxEdits": 1}) + + Args: + path: The document path to search (as string or expression). + query: The input string to autocomplete. + fuzzy: Optional dictionary of fuzzy matching parameters. + token_order: Optional value for `"tokenOrder"`; controls sequential vs. + any-order token matching. 
+ score: Optional expression to adjust score relevance (e.g., `{"boost": {"value": 5}}`). + + Notes: + * Requires an Atlas Search index with `autocomplete` mappings. + * The operator is injected under the `$search` stage in the aggregation pipeline. + """ + + def __init__(self, path, query, fuzzy=None, token_order=None, score=None): + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.fuzzy = cast_as_value(fuzzy) + self.token_order = cast_as_value(token_order) + self.score = score + super().__init__() + + def get_source_expressions(self): + return [self.path, self.query, self.fuzzy, self.token_order] + + def set_source_expressions(self, exprs): + self.path, self.query, self.fuzzy, self.token_order = exprs + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.value, + } + if self.score is not None: + params["score"] = self.score.as_mql(compiler, connection) + if self.fuzzy is not None: + params["fuzzy"] = self.fuzzy.value + if self.token_order is not None: + params["tokenOrder"] = self.token_order.value + return {"autocomplete": params} + + +class SearchEquals(SearchExpression): + """ + Atlas Search expression that matches documents with a field equal to the given value. + + This expression uses the **equals** operator to perform exact matches + on fields indexed in a MongoDB Atlas Search index. + + Example: + SearchEquals("category", "fiction") + + Args: + path: The document path to compare (as string or expression). + value: The exact value to match against. + score: Optional expression to modify the relevance score. + + Notes: + * The field must be indexed with a supported type for `equals`. + * Supports numeric, string, boolean, and date values. + * Score boosting can be applied using the `score` parameter. 
+ """ + + def __init__(self, path, value, score=None): + self.path = cast_as_field(path) + self.value = cast_as_value(value) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.value] + + def set_source_expressions(self, exprs): + self.path, self.value = exprs + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + "value": self.value.value, + } + if self.score is not None: + params["score"] = self.score.as_mql(compiler, connection) + return {"equals": params} + + +class SearchExists(SearchExpression): + """ + Atlas Search expression that matches documents where a field exists. + + This expression uses the **exists** operator to check whether a given + path is present in the document. Useful for filtering documents that + include (or exclude) optional fields. + + Example: + SearchExists("metadata__author") + + Args: + path: The document path to check (as string or expression). + score: Optional expression to modify the relevance score. + + Notes: + * The target field must be mapped in the Atlas Search index. + * This does not test for null—only for presence. 
+ """ + + def __init__(self, path, score=None): + self.path = cast_as_field(path) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path] + + def set_source_expressions(self, exprs): + (self.path,) = exprs + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + } + if self.score is not None: + params["score"] = self.score.definitions + return {"exists": params} + + +class SearchIn(SearchExpression): + def __init__(self, path, value, score=None): + self.path = cast_as_field(path) + self.value = cast_as_value(value) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.value] + + def set_source_expressions(self, exprs): + self.path, self.value = exprs + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + "value": self.value.value, + } + if self.score is not None: + params["score"] = self.score.as_mql(compiler, connection) + return {"in": params} + + +class SearchPhrase(SearchExpression): + """ + Atlas Search expression that matches a phrase in the specified field. + + This expression uses the **phrase** operator to search for exact or near-exact + sequences of terms. It supports optional slop (word distance) and synonym sets. + + Example: + SearchPhrase("description__text", "climate change", slop=2) + + Args: + path: The document path to search (as string or expression). + query: The phrase to match as a single string or list of terms. + slop: Optional maximum word distance allowed between phrase terms. + synonyms: Optional name of a synonym mapping defined in the Atlas index. 
+ score: Optional expression to modify the relevance score. + + Notes: + * The field must be mapped as `"type": "string"` with appropriate analyzers. + * Slop allows flexibility in word positioning, like `"quick brown fox"` + matching `"quick fox"` if `slop=1`. + """ + + def __init__(self, path, query, slop=None, synonyms=None, score=None): + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.slop = cast_as_value(slop) + self.synonyms = cast_as_value(synonyms) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.query, self.slop, self.synonyms] + + def set_source_expressions(self, exprs): + self.path, self.query, self.slop, self.synonyms = exprs + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.value, + } + if self.score is not None: + params["score"] = self.score.as_mql(compiler, connection) + if self.slop is not None: + params["slop"] = self.slop.value + if self.synonyms is not None: + params["synonyms"] = self.synonyms.value + return {"phrase": params} + + +class SearchQueryString(SearchExpression): + """ + Atlas Search expression that matches using a Lucene-style query string. + + This expression uses the **queryString** operator to parse and execute + full-text queries written in a simplified Lucene syntax. It supports + advanced constructs like boolean operators, wildcards, and field-specific terms. + + Example: + SearchQueryString("content__text", "django AND (search OR query)") + + Args: + path: The document path to query (as string or expression). + query: The Lucene-style query string. + score: Optional expression to modify the relevance score. + + Notes: + * The query string syntax must conform to Atlas Search rules. 
+ * This operator is powerful but can be harder to validate or sanitize. + """ + + def __init__(self, path, query, score=None): + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.query] + + def set_source_expressions(self, exprs): + self.path, self.query = exprs + + def search_operator(self, compiler, connection): + params = { + "defaultPath": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.value, + } + if self.score is not None: + params["score"] = self.score.as_mql(compiler, connection) + return {"queryString": params} + + +class SearchRange(SearchExpression): + """ + Atlas Search expression that filters documents within a range of values. + + This expression uses the **range** operator to match numeric, date, or + other comparable fields based on upper and/or lower bounds. + + Example: + SearchRange("published__year", gte=2000, lt=2020) + + Args: + path: The document path to filter (as string or expression). + lt: Optional exclusive upper bound (`<`). + lte: Optional inclusive upper bound (`<=`). + gt: Optional exclusive lower bound (`>`). + gte: Optional inclusive lower bound (`>=`). + score: Optional expression to modify the relevance score. + + Notes: + * At least one of `lt`, `lte`, `gt`, or `gte` must be provided. + * The field must be mapped in the Atlas Search index as a comparable type. 
+ """ + + def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): + self.path = cast_as_field(path) + self.lt = cast_as_value(lt) + self.lte = cast_as_value(lte) + self.gt = cast_as_value(gt) + self.gte = cast_as_value(gte) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.lt, self.lte, self.gt, self.gte] + + def set_source_expressions(self, exprs): + self.path, self.lt, self.lte, self.gt, self.gte = exprs + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + } + if self.score is not None: + params["score"] = self.score.as_mql(compiler, connection) + if self.lt is not None: + params["lt"] = self.lt.value + if self.lte is not None: + params["lte"] = self.lte.value + if self.gt is not None: + params["gt"] = self.gt.value + if self.gte is not None: + params["gte"] = self.gte.value + return {"range": params} + + +class SearchRegex(SearchExpression): + """ + Atlas Search expression that matches strings using a regular expression. + + This expression uses the **regex** operator to apply a regular expression + against the contents of a specified field. + + Example: + SearchRegex("username", r"^admin_") + + Args: + path: The document path to match (as string or expression). + query: The regular expression pattern to apply. + allow_analyzed_field: Whether to allow matching against analyzed fields (default is False). + score: Optional expression to modify the relevance score. + + Notes: + * Regular expressions must follow JavaScript regex syntax. + * By default, the field must be mapped as `"analyzer": "keyword"` + unless `allow_analyzed_field=True`. 
+ """ + + def __init__(self, path, query, allow_analyzed_field=None, score=None): + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.allow_analyzed_field = cast_as_value(allow_analyzed_field) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.query, self.allow_analyzed_field] + + def set_source_expressions(self, exprs): + self.path, self.query, self.allow_analyzed_field = exprs + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.value, + } + if self.score: + params["score"] = self.score.as_mql(compiler, connection) + if self.allow_analyzed_field is not None: + params["allowAnalyzedField"] = self.allow_analyzed_field.value + return {"regex": params} + + +class SearchText(SearchExpression): + """ + Atlas Search expression that performs full-text search using the **text** operator. + + This expression matches terms in a specified field with options for + fuzzy matching, match criteria, and synonyms. + + Example: + SearchText("description__content", "mongodb", fuzzy={"maxEdits": 1}, match_criteria="all") + + Args: + path: The document path to search (as string or expression). + query: The search term or phrase. + fuzzy: Optional dictionary to configure fuzzy matching parameters. + match_criteria: Optional criteria for term matching (e.g., "all" or "any"). + synonyms: Optional name of a synonym mapping defined in the Atlas index. + score: Optional expression to adjust relevance scoring. + + Notes: + * The target field must be indexed for full-text search in Atlas. + * Fuzzy matching helps match terms with minor typos or variations. 
+ """ + + def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, score=None): + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.fuzzy = cast_as_value(fuzzy) + self.match_criteria = cast_as_value(match_criteria) + self.synonyms = cast_as_value(synonyms) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.query, self.fuzzy, self.match_criteria, self.synonyms] + + def set_source_expressions(self, exprs): + self.path, self.query, self.fuzzy, self.match_criteria, self.synonyms = exprs + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.value, + } + if self.score: + params["score"] = self.score.as_mql(compiler, connection) + if self.fuzzy is not None: + params["fuzzy"] = self.fuzzy.value + if self.match_criteria is not None: + params["matchCriteria"] = self.match_criteria.value + if self.synonyms is not None: + params["synonyms"] = self.synonyms.value + return {"text": params} + + +class SearchWildcard(SearchExpression): + """ + Atlas Search expression that matches strings using wildcard patterns. + + This expression uses the **wildcard** operator to search for terms + matching a pattern with `*` and `?` wildcards. + + Example: + SearchWildcard("filename", "report_202?_final*") + + Args: + path: The document path to search (as string or expression). + query: The wildcard pattern to match. + allow_analyzed_field: Whether to allow matching against analyzed fields (default is False). + score: Optional expression to modify the relevance score. + + Notes: + * Wildcard patterns follow standard syntax, where `*` matches any sequence of characters + and `?` matches a single character. 
+ * By default, the field should be keyword or unanalyzed + unless `allow_analyzed_field=True`. + """ + + def __init__(self, path, query, allow_analyzed_field=None, score=None): + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.allow_analyzed_field = cast_as_value(allow_analyzed_field) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.query, self.allow_analyzed_field] + + def set_source_expressions(self, exprs): + self.path, self.query, self.allow_analyzed_field = exprs + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.value, + } + if self.score: + params["score"] = self.score.query.as_mql(compiler, connection) + if self.allow_analyzed_field is not None: + params["allowAnalyzedField"] = self.allow_analyzed_field.value + return {"wildcard": params} + + +class SearchGeoShape(SearchExpression): + """ + Atlas Search expression that filters documents by spatial relationship with a geometry. + + This expression uses the **geoShape** operator to match documents where + a geo field relates to a specified geometry by a spatial relation. + + Example: + SearchGeoShape("location", "within", {"type": "Polygon", "coordinates": [...]}) + + Args: + path: The document path to the geo field (as string or expression). + relation: The spatial relation to test (e.g., "within", "intersects", "disjoint"). + geometry: The GeoJSON geometry to compare against. + score: Optional expression to modify the relevance score. + + Notes: + * The field must be indexed as a geo shape type in Atlas Search. + * Geometry must conform to GeoJSON specification. 
+ """ + + def __init__(self, path, relation, geometry, score=None): + self.path = cast_as_field(path) + self.relation = cast_as_value(relation) + self.geometry = cast_as_value(geometry) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.relation, self.geometry] + + def set_source_expressions(self, exprs): + self.path, self.relation, self.geometry = exprs + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + "relation": self.relation.value, + "geometry": self.geometry.value, + } + if self.score: + params["score"] = self.score.as_mql(compiler, connection) + return {"geoShape": params} + + +class SearchGeoWithin(SearchExpression): + """ + Atlas Search expression that filters documents with geo fields + contained within a specified shape. + + This expression uses the **geoWithin** operator to match documents where + the geo field lies entirely within the given geometry. + + Example: + SearchGeoWithin("location", "Polygon", {"type": "Polygon", "coordinates": [...]}) + + Args: + path: The document path to the geo field (as string or expression). + kind: The GeoJSON geometry type (e.g., "Polygon", "MultiPolygon"). + geo_object: The GeoJSON geometry defining the boundary. + score: Optional expression to adjust the relevance score. + + Notes: + * The geo field must be indexed appropriately in the Atlas Search index. + * The geometry must follow GeoJSON format. 
+ """ + + def __init__(self, path, kind, geo_object, score=None): + self.path = cast_as_field(path) + self.kind = cast_as_value(kind) + self.geo_object = cast_as_value(geo_object) + self.score = score + super().__init__() + + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.kind, self.geo_object] + + def set_source_expressions(self, exprs): + self.path, self.kind, self.geo_object = exprs + + def search_operator(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection, as_path=True), + self.kind.value: self.geo_object.value, + } + if self.score: + params["score"] = self.score.as_mql(compiler, connection) + return {"geoWithin": params} + + +class SearchMoreLikeThis(SearchExpression): + """ + Atlas Search expression that finds documents similar to given examples. + + This expression uses the **moreLikeThis** operator to search for documents + that resemble the specified sample documents. + + Example: + SearchMoreLikeThis([{"_id": ObjectId("...")}, {"title": "Example"}]) + + Args: + documents: A list of example documents or expressions to find similar documents. + score: Optional expression to modify the relevance scoring. + + Notes: + * The documents should be representative examples to base similarity on. + * Supports various field types depending on the Atlas Search configuration. 
+ """ + + def __init__(self, documents, score=None): + self.documents = cast_as_value(documents) + self.score = score + super().__init__() + + def get_source_expressions(self): + return [self.documents] + + def set_source_expressions(self, exprs): + (self.documents,) = exprs + + def search_operator(self, compiler, connection): + params = { + "like": self.documents.as_mql(compiler, connection), + } + if self.score: + params["score"] = self.score.as_mql(compiler, connection) + return {"moreLikeThis": params} + + def get_search_fields(self, compiler, connection): + needed_fields = set() + for doc in self.documents: + needed_fields.update(set(doc.keys())) + return needed_fields + + +class CompoundExpression(SearchExpression): + """ + Compound expression that combines multiple search clauses using boolean logic. + + This expression corresponds to the **compound** operator in MongoDB Atlas Search, + allowing fine-grained control by combining multiple sub-expressions with + `must`, `must_not`, `should`, and `filter` clauses. + + Example: + CompoundExpression( + must=[expr1, expr2], + must_not=[expr3], + should=[expr4], + minimum_should_match=1 + ) + + Args: + must: List of expressions that **must** match. + must_not: List of expressions that **must not** match. + should: List of expressions that **should** match (optional relevance boost). + filter: List of expressions to filter results without affecting relevance. + score: Optional expression to adjust scoring. + minimum_should_match: Minimum number of `should` clauses that must match. + + Notes: + * This is the most flexible way to build complex Atlas Search queries. + * Supports nesting of expressions to any depth. 
+ """ + + def __init__( + self, + must=None, + must_not=None, + should=None, + filter=None, + score=None, + minimum_should_match=None, + ): + self.must = must or [] + self.must_not = must_not or [] + self.should = should or [] + self.filter = filter or [] + self.score = score + self.minimum_should_match = minimum_should_match + + def get_search_fields(self, compiler, connection): + fields = set() + for clause in self.must + self.should + self.filter + self.must_not: + fields.update(clause.get_search_fields(compiler, connection)) + return fields + + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): + c = self.copy() + c.is_summary = summarize + c.must = [ + expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.must + ] + c.must_not = [ + expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.must_not + ] + c.should = [ + expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.should + ] + c.filter = [ + expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.filter + ] + return c + + def search_operator(self, compiler, connection): + params = {} + if self.must: + params["must"] = [clause.search_operator(compiler, connection) for clause in self.must] + if self.must_not: + params["mustNot"] = [ + clause.search_operator(compiler, connection) for clause in self.must_not + ] + if self.should: + params["should"] = [ + clause.search_operator(compiler, connection) for clause in self.should + ] + if self.filter: + params["filter"] = [ + clause.search_operator(compiler, connection) for clause in self.filter + ] + if self.minimum_should_match is not None: + params["minimumShouldMatch"] = self.minimum_should_match + return {"compound": params} + + def negate(self): + return CompoundExpression(must_not=[self]) + + +class CombinedSearchExpression(SearchExpression): + """ + Combines two search expressions with a logical 
operator. + + This expression allows combining two Atlas Search expressions + (left-hand side and right-hand side) using a boolean operator + such as `and`, `or`, or `not`. + + Example: + CombinedSearchExpression(expr1, "and", expr2) + + Args: + lhs: The left-hand search expression. + operator: The boolean operator as a string (e.g., "and", "or", "not"). + rhs: The right-hand search expression. + + Notes: + * The operator must be supported by MongoDB Atlas Search boolean logic. + * This class enables building complex nested search queries. + """ + + def __init__(self, lhs, operator, rhs): + self.lhs = lhs + self.operator = operator + self.rhs = rhs + + def get_source_expressions(self): + return [self.lhs, self.rhs] + + def set_source_expressions(self, exprs): + self.lhs, self.rhs = exprs + + @staticmethod + def resolve(node, negated=False): + if node is None: + return None + # Leaf, resolve the compoundExpression + if isinstance(node, CompoundExpression): + return node.negate() if negated else node + # Apply De Morgan's Laws. + operator = node.operator.negate() if negated else node.operator + negated = negated != (node.operator == Operator.NOT) + lhs_compound = node.resolve(node.lhs, negated) + rhs_compound = node.resolve(node.rhs, negated) + if operator == Operator.OR: + return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1) + if operator == Operator.AND: + return CompoundExpression(must=[lhs_compound, rhs_compound]) + return lhs_compound + + def as_mql(self, compiler, connection): + expression = self.resolve(self) + return expression.as_mql(compiler, connection) + + +class SearchVector(SearchExpression): + """ + Atlas Search expression that performs vector similarity search on embedded vectors. + + This expression uses the **knnBeta** operator to find documents whose vector + embeddings are most similar to a given query vector. 
+
+    Example:
+        SearchVector("embedding", [0.1, 0.2, 0.3], limit=10, num_candidates=100)
+
+    Args:
+        path: The document path to the vector field (as string or expression).
+        query_vector: The query vector to compare against.
+        limit: Maximum number of matching documents to return.
+        num_candidates: Optional number of candidates to consider during search.
+        exact: Optional flag to enforce exact matching.
+        filter: Optional filter expression to narrow candidate documents.
+
+    Notes:
+        * The vector field must be indexed as a vector type in Atlas Search.
+        * Parameters like `num_candidates` and `exact` control search
+          performance and accuracy trade-offs.
+    """
+
+    def __init__(
+        self,
+        path,
+        query_vector,
+        limit,
+        num_candidates=None,
+        exact=None,
+        filter=None,
+    ):
+        self.path = cast_as_field(path)
+        self.query_vector = cast_as_value(query_vector)
+        self.limit = cast_as_value(limit)
+        self.num_candidates = cast_as_value(num_candidates)
+        self.exact = cast_as_value(exact)
+        self.filter = cast_as_value(filter)
+        super().__init__()
+
+    def __invert__(self):
+        raise ValueError("SearchVector cannot be negated")
+
+    def __and__(self, other):
+        raise NotSupportedError("SearchVector cannot be combined")
+
+    def __rand__(self, other):
+        raise NotSupportedError("SearchVector cannot be combined")
+
+    def __or__(self, other):
+        raise NotSupportedError("SearchVector cannot be combined")
+
+    def __ror__(self, other):
+        raise NotSupportedError("SearchVector cannot be combined")
+
+    def get_search_fields(self, compiler, connection):
+        return {self.path.as_mql(compiler, connection, as_path=True)}
+
+    def get_source_expressions(self):
+        return [
+            self.path,
+            self.query_vector,
+            self.limit,
+            self.num_candidates,
+            self.exact,
+            self.filter,
+        ]
+
+    def set_source_expressions(self, exprs):
+        (
+            self.path,
+            self.query_vector,
+            self.limit,
+            self.num_candidates,
+            self.exact,
+            self.filter,
+        ) = exprs
+
+    def _get_query_index(self, fields, compiler):
+        for search_indexes
in compiler.collection.list_search_indexes(): + if search_indexes["type"] == "vectorSearch": + index_field = { + field["path"] for field in search_indexes["latestDefinition"]["fields"] + } + if fields.issubset(index_field): + return search_indexes["name"] + return "default" + + def as_mql(self, compiler, connection): + params = { + "index": self._get_query_index(self.get_search_fields(compiler, connection), compiler), + "path": self.path.as_mql(compiler, connection, as_path=True), + "queryVector": self.query_vector.value, + "limit": self.limit.value, + } + if self.num_candidates is not None: + params["numCandidates"] = self.num_candidates.value + if self.exact is not None: + params["exact"] = self.exact.value + if self.filter is not None: + params["filter"] = self.filter.as_mql(compiler, connection) + return {"$vectorSearch": params} + + +class SearchScoreOption(Expression): + """Class to mutate scoring on a search operation""" + + def __init__(self, definitions=None): + self.definitions = definitions + + def as_mql(self, compiler, connection): + return self.definitions + + +class SearchTextLookup(Lookup): + lookup_name = "search" + + def __init__(self, lhs, rhs): + super().__init__(lhs, rhs) + self.lhs = SearchText(self.lhs, self.rhs) + self.rhs = Value(0) + + def __str__(self): + return f"SearchText({self.lhs}, {self.rhs})" + + def __repr__(self): + return f"SearchText({self.lhs}, {self.rhs})" + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + return {"$gte": [lhs_mql, value]} + + +CharField.register_lookup(SearchTextLookup) +TextField.register_lookup(SearchTextLookup) diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 4b49a4710..b7f562841 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -184,12 +184,16 @@ def get_transform(self, 
name): f"{suggestion}" ) - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_path=False): previous = self key_transforms = [] while isinstance(previous, KeyTransform): key_transforms.insert(0, previous.key_name) previous = previous.lhs + if as_path: + mql = previous.as_mql(compiler, connection, as_path=True) + mql_path = ".".join(key_transforms) + return f"{mql}.{mql_path}" mql = previous.as_mql(compiler, connection) for key in key_transforms: mql = {"$getField": {"input": mql, "field": key}} diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index d59bc1631..e6290ead4 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -49,6 +49,7 @@ def __init__(self, compiler): self.lookup_pipeline = None self.project_fields = None self.aggregation_pipeline = compiler.aggregation_pipeline + self.search_pipeline = compiler.search_pipeline self.extra_fields = None self.combinator_pipeline = None # $lookup stage that encapsulates the pipeline for performing a nested @@ -81,6 +82,8 @@ def get_cursor(self): def get_pipeline(self): pipeline = [] + if self.search_pipeline: + pipeline.extend(self.search_pipeline) if self.lookup_pipeline: pipeline.extend(self.lookup_pipeline) for query in self.subqueries or (): diff --git a/tests/expressions_/test_combinable_search_expression.py b/tests/expressions_/test_combinable_search_expression.py new file mode 100644 index 000000000..2ff597050 --- /dev/null +++ b/tests/expressions_/test_combinable_search_expression.py @@ -0,0 +1,76 @@ +from django.test import SimpleTestCase + +from django_mongodb_backend.expressions.search import ( + CombinedSearchExpression, + CompoundExpression, + SearchEquals, +) + + +class CombinedSearchExpressionResolutionTest(SimpleTestCase): + def test_combined_expression_and_or_not_resolution(self): + A = SearchEquals(path="headline", value="A") + B = SearchEquals(path="headline", value="B") + C = SearchEquals(path="headline", 
value="C") + D = SearchEquals(path="headline", value="D") + expr = (~A | B) & (C | D) + solved = CombinedSearchExpression.resolve(expr) + self.assertIsInstance(solved, CompoundExpression) + solved_A = CompoundExpression(must_not=[CompoundExpression(must=[A])]) + solved_B = CompoundExpression(must=[B]) + solved_C = CompoundExpression(must=[C]) + solved_D = CompoundExpression(must=[D]) + self.assertCountEqual(solved.must[0].should, [solved_A, solved_B]) + self.assertEqual(solved.must[0].minimum_should_match, 1) + self.assertEqual(solved.must[1].should, [solved_C, solved_D]) + + def test_combined_expression_de_morgans_resolution(self): + A = SearchEquals(path="headline", value="A") + B = SearchEquals(path="headline", value="B") + C = SearchEquals(path="headline", value="C") + D = SearchEquals(path="headline", value="D") + expr = ~(A | B) & (C | D) + solved_A = CompoundExpression(must_not=[CompoundExpression(must=[A])]) + solved_B = CompoundExpression(must_not=[CompoundExpression(must=[B])]) + solved_C = CompoundExpression(must=[C]) + solved_D = CompoundExpression(must=[D]) + solved = CombinedSearchExpression.resolve(expr) + self.assertIsInstance(solved, CompoundExpression) + self.assertCountEqual(solved.must[0].must, [solved_A, solved_B]) + self.assertEqual(solved.must[0].minimum_should_match, None) + self.assertEqual(solved.must[1].should, [solved_C, solved_D]) + self.assertEqual(solved.minimum_should_match, None) + + def test_combined_expression_doble_negation(self): + A = SearchEquals(path="headline", value="A") + expr = ~~A + solved = CombinedSearchExpression.resolve(expr) + solved_A = CompoundExpression(must=[A]) + self.assertIsInstance(solved, CompoundExpression) + self.assertEqual(solved, solved_A) + + def test_combined_expression_long_right_tree(self): + A = SearchEquals(path="headline", value="A") + B = SearchEquals(path="headline", value="B") + C = SearchEquals(path="headline", value="C") + D = SearchEquals(path="headline", value="D") + solved_A = 
CompoundExpression(must=[A]) + solved_B = CompoundExpression(must_not=[CompoundExpression(must=[B])]) + solved_C = CompoundExpression(must=[C]) + solved_D = CompoundExpression(must=[D]) + expr = A & ~(B & ~(C & D)) + solved = CombinedSearchExpression.resolve(expr) + self.assertIsInstance(solved, CompoundExpression) + self.assertEqual(len(solved.must), 2) + self.assertEqual(solved.must[0], solved_A) + self.assertEqual(len(solved.must[1].should), 2) + self.assertEqual(solved.must[1].should[0], solved_B) + self.assertCountEqual(solved.must[1].should[1].must, [solved_C, solved_D]) + expr = A | ~(B | ~(C | D)) + solved = CombinedSearchExpression.resolve(expr) + self.assertIsInstance(solved, CompoundExpression) + self.assertEqual(len(solved.should), 2) + self.assertEqual(solved.should[0], solved_A) + self.assertEqual(len(solved.should[1].must), 2) + self.assertEqual(solved.should[1].must[0], solved_B) + self.assertCountEqual(solved.should[1].must[1].should, [solved_C, solved_D]) diff --git a/tests/queries_/models.py b/tests/queries_/models.py index 015102248..21af6fafd 100644 --- a/tests/queries_/models.py +++ b/tests/queries_/models.py @@ -1,6 +1,12 @@ from django.db import models -from django_mongodb_backend.fields import ObjectIdAutoField, ObjectIdField +from django_mongodb_backend.fields import ( + ArrayField, + EmbeddedModelField, + ObjectIdAutoField, + ObjectIdField, +) +from django_mongodb_backend.models import EmbeddedModel class Author(models.Model): @@ -53,3 +59,16 @@ class Meta: def __str__(self): return str(self.pk) + + +class Writer(EmbeddedModel): + name = models.CharField(max_length=10) + + +class Article(models.Model): + headline = models.CharField(max_length=100) + number = models.IntegerField() + body = models.TextField() + location = models.JSONField(null=True) + plot_embedding = ArrayField(models.FloatField(), size=3, null=True) + writer = EmbeddedModelField(Writer, null=True) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py 
new file mode 100644 index 000000000..f0d113ea2 --- /dev/null +++ b/tests/queries_/test_search.py @@ -0,0 +1,572 @@ +import unittest +from collections.abc import Callable +from time import monotonic, sleep + +from django.db import connection +from django.db.utils import DatabaseError +from django.test import TransactionTestCase, skipUnlessDBFeature +from pymongo.operations import SearchIndexModel + +from django_mongodb_backend.expressions.search import ( + CompoundExpression, + SearchAutocomplete, + SearchEquals, + SearchExists, + SearchGeoShape, + SearchGeoWithin, + SearchIn, + SearchMoreLikeThis, + SearchPhrase, + SearchRange, + SearchRegex, + SearchText, + SearchVector, + SearchWildcard, +) + +from .models import Article, Writer + + +def _wait_for_assertion(timeout: float = 120, interval: float = 0.5) -> None: + """Generic to block until the predicate returns true + + Args: + timeout (float, optional): Wait time for predicate. Defaults to TIMEOUT. + interval (float, optional): Interval to check predicate. Defaults to DELAY. + + Raises: + AssertionError: _description_ + """ + + @staticmethod + def _inner_wait_loop(predicate: Callable): + """ + Waits until the given predicate stops raising AssertionError or DatabaseError. + + Args: + predicate (Callable): A function that raises AssertionError (or DatabaseError) + if a condition is not yet met. It should refresh its query each time + it's called (e.g., by using `qs.all()` to avoid cached results). + + Raises: + AssertionError or DatabaseError: If the predicate keeps failing beyond the timeout. 
+ """ + start = monotonic() + while True: + try: + predicate() + except (AssertionError, DatabaseError): + if monotonic() - start > timeout: + raise + sleep(interval) + else: + break + + return _inner_wait_loop + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchUtilsMixin(TransactionTestCase): + available_apps = [] + + @staticmethod + def _get_collection(model): + return connection.database.get_collection(model._meta.db_table) + + @staticmethod + def create_search_index(model, index_name, definition, type="search"): + collection = SearchUtilsMixin._get_collection(model) + idx = SearchIndexModel(definition=definition, name=index_name, type=type) + collection.create_search_index(idx) + + def _tear_down(self, model): + collection = SearchUtilsMixin._get_collection(model) + for search_indexes in collection.list_search_indexes(): + collection.drop_search_index(search_indexes["name"]) + collection.delete_many({}) + + wait_for_assertion = _wait_for_assertion(timeout=3) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchEqualsTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "equals_headline_index", + {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, + ) + self.article = Article.objects.create(headline="cross", number=1, body="body") + Article.objects.create(headline="other thing", number=2, body="body") + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_search_equals(self): + qs = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchAutocompleteTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "autocomplete_headline_index", + { + "mappings": { + "dynamic": False, + "fields": { + "headline": { + "type": "autocomplete", + "analyzer": "lucene.standard", + 
"tokenization": "edgeGram", + "minGrams": 3, + "maxGrams": 5, + "foldDiacritics": False, + }, + "writer": { + "type": "document", + "fields": { + "name": { + "type": "autocomplete", + "analyzer": "lucene.standard", + "tokenization": "edgeGram", + "minGrams": 3, + "maxGrams": 5, + "foldDiacritics": False, + } + }, + }, + }, + } + }, + ) + self.article = Article.objects.create( + headline="crossing and something", + number=2, + body="river", + writer=Writer(name="Joselina A. Ramirez"), + ) + Article.objects.create(headline="Some random text", number=3, body="river") + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_search_autocomplete(self): + qs = Article.objects.annotate( + score=SearchAutocomplete( + path="headline", + query="crossing", + token_order="sequential", # noqa: S106 + fuzzy={"maxEdits": 2}, + ) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + def test_search_autocomplete_embedded_model(self): + qs = Article.objects.annotate( + score=SearchAutocomplete(path="writer__name", query="Joselina") + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchExistsTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "exists_body_index", + {"mappings": {"dynamic": False, "fields": {"body": {"type": "token"}}}}, + ) + self.article = Article.objects.create(headline="ignored", number=3, body="something") + + def test_search_exists(self): + qs = Article.objects.annotate(score=SearchExists(path="body")) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchInTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "in_headline_index", + {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, + ) + self.article = 
Article.objects.create(headline="cross", number=1, body="a") + Article.objects.create(headline="road", number=2, body="b") + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_search_in(self): + qs = Article.objects.annotate(score=SearchIn(path="headline", value=["cross", "river"])) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchPhraseTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "phrase_body_index", + {"mappings": {"dynamic": False, "fields": {"body": {"type": "string"}}}}, + ) + self.article = Article.objects.create( + headline="irrelevant", number=1, body="the quick brown fox" + ) + Article.objects.create(headline="cheetah", number=2, body="fastest animal") + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_search_phrase(self): + qs = Article.objects.annotate(score=SearchPhrase(path="body", query="quick brown")) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchRangeTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "range_number_index", + {"mappings": {"dynamic": False, "fields": {"number": {"type": "number"}}}}, + ) + Article.objects.create(headline="x", number=5, body="z") + self.number20 = Article.objects.create(headline="y", number=20, body="z") + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_search_range(self): + qs = Article.objects.annotate(score=SearchRange(path="number", gte=10, lt=30)) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.number20])) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchRegexTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "regex_headline_index", + { + "mappings": { + "dynamic": False, + 
"fields": {"headline": {"type": "string", "analyzer": "lucene.keyword"}}, + } + }, + ) + self.article = Article.objects.create(headline="hello world", number=1, body="abc") + Article.objects.create(headline="hola mundo", number=2, body="abc") + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_search_regex(self): + qs = Article.objects.annotate( + score=SearchRegex(path="headline", query="hello.*", allow_analyzed_field=True) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchTextTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "text_body_index", + {"mappings": {"dynamic": False, "fields": {"body": {"type": "string"}}}}, + ) + self.article = Article.objects.create( + headline="ignored", number=1, body="The lazy dog sleeps" + ) + Article.objects.create(headline="ignored", number=2, body="The sleepy bear") + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_search_text(self): + qs = Article.objects.annotate(score=SearchText(path="body", query="lazy")) + self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs.all())) + + def test_search_lookup(self): + qs = Article.objects.filter(body__search="lazy") + self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs.all())) + + def test_search_text_with_fuzzy_and_criteria(self): + qs = Article.objects.annotate( + score=SearchText( + path="body", query="lazzy", fuzzy={"maxEdits": 2}, match_criteria="all" + ) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchWildcardTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "wildcard_headline_index", + { + "mappings": { + "dynamic": False, + "fields": {"headline": {"type": "string", "analyzer": "lucene.keyword"}}, + } + }, + 
) + self.article = Article.objects.create(headline="dark-knight", number=1, body="") + Article.objects.create(headline="batman", number=2, body="") + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_search_wildcard(self): + qs = Article.objects.annotate(score=SearchWildcard(path="headline", query="dark-*")) + self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs)) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchGeoShapeTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "geoshape_location_index", + { + "mappings": { + "dynamic": False, + "fields": {"location": {"type": "geo", "indexShapes": True}}, + } + }, + ) + self.article = Article.objects.create( + headline="any", number=1, body="", location={"type": "Point", "coordinates": [40, 5]} + ) + Article.objects.create( + headline="any", number=2, body="", location={"type": "Point", "coordinates": [400, 50]} + ) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_search_geo_shape(self): + polygon = { + "type": "Polygon", + "coordinates": [[[30, 0], [50, 0], [50, 10], [30, 10], [30, 0]]], + } + qs = Article.objects.annotate( + score=SearchGeoShape(path="location", relation="within", geometry=polygon) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + +@skipUnlessDBFeature("supports_atlas_search") +class SearchGeoWithinTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "geowithin_location_index", + {"mappings": {"dynamic": False, "fields": {"location": {"type": "geo"}}}}, + ) + self.article = Article.objects.create( + headline="geo", number=2, body="", location={"type": "Point", "coordinates": [40, 5]} + ) + Article.objects.create( + headline="geo2", number=3, body="", location={"type": "Point", "coordinates": [-40, -5]} + ) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def 
test_search_geo_within(self): + polygon = { + "type": "Polygon", + "coordinates": [[[30, 0], [50, 0], [50, 10], [30, 10], [30, 0]]], + } + qs = Article.objects.annotate( + score=SearchGeoWithin( + path="location", + kind="geometry", + geo_object=polygon, + ) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + +@skipUnlessDBFeature("supports_atlas_search") +@unittest.expectedFailure +class SearchMoreLikeThisTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "mlt_index", + { + "mappings": { + "dynamic": False, + "fields": {"body": {"type": "string"}, "headline": {"type": "string"}}, + } + }, + ) + self.article1 = Article.objects.create( + headline="Space exploration", number=1, body="Webb telescope" + ) + self.article2 = Article.objects.create( + headline="The commodities fall", + number=2, + body="Commodities dropped sharply due to inflation concerns", + ) + Article.objects.create( + headline="irrelevant", + number=3, + body="This is a completely unrelated article about cooking", + ) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_search_more_like_this(self): + like_docs = [ + {"headline": self.article1.headline, "body": self.article1.body}, + {"headline": self.article2.headline, "body": self.article2.body}, + ] + like_docs = [{"body": "NASA launches new satellite to explore the galaxy"}] + qs = Article.objects.annotate(score=SearchMoreLikeThis(documents=like_docs)).order_by( + "score" + ) + self.wait_for_assertion( + lambda: self.assertQuerySetEqual( + qs.all(), [self.article1, self.article2], lambda a: a.headline + ) + ) + + +@skipUnlessDBFeature("supports_atlas_search") +class CompoundSearchTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( + Article, + "compound_index", + { + "mappings": { + "dynamic": False, + "fields": { + "headline": {"type": "token"}, + "body": {"type": "string"}, + "number": {"type": "number"}, + }, + } + }, + ) + 
        self.mars_mission = Article.objects.create(
+            number=1,
+            headline="space exploration",
+            body="NASA launches a new mission to Mars, aiming to study surface geology",
+        )
+
+        self.exoplanet = Article.objects.create(
+            number=2,
+            headline="space exploration",
+            body="Astronomers discover exoplanets orbiting distant stars using Webb telescope",
+        )
+
+        self.icy_moons = Article.objects.create(
+            number=3,
+            headline="space exploration",
+            body="ESA prepares a robotic expedition to explore the icy moons of Jupiter",
+        )
+
+        self.commodities_drop = Article.objects.create(
+            number=4,
+            headline="astronomy news",
+            body="Commodities dropped sharply due to inflation concerns",
+        )
+
+    def tearDown(self):
+        self._tear_down(Article)
+        super().tearDown()
+
+    def test_compound_expression(self):
+        must_expr = SearchEquals(path="headline", value="space exploration")
+        must_not_expr = SearchPhrase(path="body", query="icy moons")
+        should_expr = SearchPhrase(path="body", query="exoplanets")
+
+        compound = CompoundExpression(
+            must=[must_expr],
+            must_not=[must_not_expr],
+            should=[should_expr],
+            minimum_should_match=1,
+        )
+
+        qs = Article.objects.annotate(score=compound).order_by("score")
+        self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.exoplanet]))
+
+    def test_compound_operations(self):
+        expr = SearchEquals(path="headline", value="space exploration") & ~SearchEquals(
+            path="number", value=3
+        )
+        qs = Article.objects.annotate(score=expr)
+        self.wait_for_assertion(
+            lambda: self.assertCountEqual(qs.all(), [self.mars_mission, self.exoplanet])
+        )
+
+
+@skipUnlessDBFeature("supports_atlas_search")
+class SearchVectorTest(SearchUtilsMixin):
+    def setUp(self):
+        self.create_search_index(
+            Article,
+            "vector_index",
+            {
+                "fields": [
+                    {
+                        "type": "vector",
+                        "path": "plot_embedding",
+                        "numDimensions": 3,
+                        "similarity": "cosine",
+                        "quantization": "scalar",
+                    }
+                ]
+            },
+            type="vectorSearch",
+        )
+
+        self.mars = Article.objects.create(
+
headline="Mars landing", + number=1, + body="The rover has landed on Mars", + plot_embedding=[0.1, 0.2, 0.3], + ) + self.cooking = Article.objects.create( + headline="Cooking tips", + number=2, + body="This article is about pasta", + plot_embedding=[0.9, 0.8, 0.7], + ) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() + + def test_vector_search(self): + vector_query = [0.1, 0.2, 0.3] + expr = SearchVector( + path="plot_embedding", + query_vector=vector_query, + num_candidates=5, + limit=2, + ) + qs = Article.objects.annotate(score=expr).order_by("-score") + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.mars, self.cooking]))