diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py
index c5e6a7b7f..89e4edd0e 100644
--- a/google/cloud/firestore_v1/base_aggregation.py
+++ b/google/cloud/firestore_v1/base_aggregation.py
@@ -21,9 +21,10 @@ from __future__ import annotations

 import abc
+import itertools
 from abc import ABC
-from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union, Iterator

 from google.api_core import gapic_v1
 from google.api_core import retry as retries
@@ -33,6 +34,10 @@ from google.cloud.firestore_v1.types import (
     StructuredAggregationQuery,
 )

+from google.cloud.firestore_v1.pipeline_expressions import Accumulator
+from google.cloud.firestore_v1.pipeline_expressions import Count
+from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias
+from google.cloud.firestore_v1.pipeline_expressions import Field

 # Types needed only for Type Hints
 if TYPE_CHECKING:  # pragma: NO COVER
@@ -66,6 +71,9 @@ def __init__(self, alias: str, value: float, read_time=None):
     def __repr__(self):
         return f"<Aggregation alias={self.alias}, value={self.value}, read_time={self.read_time}>"

+    def _to_dict(self):
+        return {self.alias: self.value}
+

 class BaseAggregation(ABC):
     def __init__(self, alias: str | None = None):
@@ -75,6 +83,27 @@ def __init__(self, alias: str | None = None):
     def _to_protobuf(self):
         """Convert this instance to the protobuf representation"""

+    @abc.abstractmethod
+    def _to_pipeline_expr(
+        self, autoindexer: Iterator[int]
+    ) -> ExprWithAlias[Accumulator]:
+        """
+        Convert this instance to a pipeline expression for use with pipeline.aggregate().
+
+        Args:
+            autoindexer: If an alias isn't supplied, one should be created with the
+                format "field_n". The autoindexer is an iterator that provides the
+                `n` value to use for each expression.
+        """
+
+    def _pipeline_alias(self, autoindexer):
+        """
+        Helper to build the alias for the pipeline expression.
+        """
+        if self.alias is not None:
+            return self.alias
+        else:
+            return f"field_{next(autoindexer)}"
+

 class CountAggregation(BaseAggregation):
     def __init__(self, alias: str | None = None):
@@ -88,6 +117,9 @@ def _to_protobuf(self):
         aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count()
         return aggregation_pb

+    def _to_pipeline_expr(self, autoindexer: Iterator[int]):
+        return Count().as_(self._pipeline_alias(autoindexer))
+

 class SumAggregation(BaseAggregation):
     def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
@@ -107,6 +139,9 @@ def _to_protobuf(self):
         aggregation_pb.sum.field.field_path = self.field_ref
         return aggregation_pb

+    def _to_pipeline_expr(self, autoindexer: Iterator[int]):
+        return Field.of(self.field_ref).sum().as_(self._pipeline_alias(autoindexer))
+

 class AvgAggregation(BaseAggregation):
     def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
@@ -126,6 +161,9 @@ def _to_protobuf(self):
         aggregation_pb.avg.field.field_path = self.field_ref
         return aggregation_pb

+    def _to_pipeline_expr(self, autoindexer: Iterator[int]):
+        return Field.of(self.field_ref).avg().as_(self._pipeline_alias(autoindexer))
+

 def _query_response_to_result(
     response_pb,
@@ -317,3 +355,20 @@ def stream(
             StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]:
             A generator of the query results.
""" + + def pipeline(self): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Raises: + - ValueError: raised if Query wasn't created with an associated client + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + # use autoindexer to keep track of which field number to use for un-aliased fields + autoindexer = itertools.count(start=1) + exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations] + return self._nested_query.pipeline().aggregate(*exprs) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 1b1ef0411..a4cc2b1b7 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -602,6 +602,19 @@ def find_nearest( distance_threshold=distance_threshold, ) + def pipeline(self): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Raises: + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + return self._query().pipeline() + def _auto_id() -> str: """Generate a "random" automatically generated ID. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 14df886bc..245605afc 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -59,6 +59,7 @@ query, ) from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1 import pipeline_expressions if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -1128,6 +1129,74 @@ def recursive(self: QueryType) -> QueryType: return copied + def pipeline(self): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Raises: + - ValueError: raised if Query wasn't created with an associated client + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + if not self._client: + raise ValueError("Query does not have an associated client") + if self._all_descendants: + ppl = self._client.pipeline().collection_group(self._parent.id) + else: + ppl = self._client.pipeline().collection(self._parent._path) + + # Filters + for filter_ in self._field_filters: + ppl = ppl.where( + pipeline_expressions.FilterCondition._from_query_filter_pb( + filter_, self._client + ) + ) + + # Projections + if self._projection and self._projection.fields: + ppl = ppl.select(*[field.field_path for field in self._projection.fields]) + + # Orders + orders = self._normalize_orders() + if orders: + exists = [] + orderings = [] + for order in orders: + field = pipeline_expressions.Field.of(order.field.field_path) + exists.append(field.exists()) + direction = ( + "ascending" + if order.direction == StructuredQuery.Direction.ASCENDING + else "descending" + ) + orderings.append(pipeline_expressions.Ordering(field, direction)) + + # Add exists filters to match Query's implicit orderby semantics. 
+            if len(exists) == 1:
+                ppl = ppl.where(exists[0])
+            else:
+                ppl = ppl.where(pipeline_expressions.And(*exists))
+
+            # Add sort orderings
+            ppl = ppl.sort(*orderings)
+
+        # Cursors, Limit and Offset
+        if self._start_at or self._end_at or self._limit_to_last:
+            raise NotImplementedError(
+                "Query to Pipeline conversion: cursors and limit_to_last are not supported yet."
+            )
+        else:  # Limit & Offset without cursors
+            if self._offset:
+                ppl = ppl.offset(self._offset)
+            if self._limit:
+                ppl = ppl.limit(self._limit)
+
+        return ppl
+
     def _comparator(self, doc1, doc2) -> int:
         _orders = self._orders
diff --git a/tests/system/test__helpers.py b/tests/system/test__helpers.py
index c146a5763..7840f4144 100644
--- a/tests/system/test__helpers.py
+++ b/tests/system/test__helpers.py
@@ -19,4 +19,4 @@
 # run all tests against default database, and a named database
 # TODO: add enterprise mode when GA (RunQuery not currently supported)
-TEST_DATABASES = [None, FIRESTORE_OTHER_DB]
+TEST_DATABASES = [None, FIRESTORE_OTHER_DB, FIRESTORE_ENTERPRISE_DB]
diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index bd12815f2..317c1c7e1 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -42,6 +42,7 @@
     MISSING_DOCUMENT,
     RANDOM_ID_REGEX,
     UNIQUE_RESOURCE_ID,
+    ENTERPRISE_MODE_ERROR,
     TEST_DATABASES,
 )

@@ -80,6 +81,44 @@ def cleanup():
         operation()


+def verify_pipeline(query):
+    """
+    This function ensures a pipeline produces the same
+    results as the query it is derived from.
+
+    It can be attached to existing query tests to check both
+    modalities at the same time.
+    """
+    from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
+
+    query_exception = None
+    query_results = None
+    try:
+        try:
+            if isinstance(query, BaseAggregationQuery):
+                # aggregation queries return a list of lists of aggregation results
+                query_results = [[a._to_dict() for a in s] for s in query.get()]
+            else:
+                # other queries return a simple list of results
+                query_results = [s.to_dict() for s in query.get()]
+        except Exception as e:
+            # if we expect the query to fail, capture the exception
+            query_exception = e
+        pipeline = query.pipeline()
+        if query_exception:
+            # ensure that the pipeline raises the same error as the query
+            with pytest.raises(query_exception.__class__):
+                pipeline.execute()
+        else:
+            # ensure results match the query
+            pipeline_results = [s.data() for s in pipeline.execute()]
+            assert query_results == pipeline_results
+    except FailedPrecondition as e:
+        # if testing against a non-enterprise db, skip this check
+        if ENTERPRISE_MODE_ERROR not in e.message:
+            raise e
+
+
 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_collections(client, database):
     collections = list(client.collections())
@@ -1247,6 +1286,7 @@ def test_query_stream_legacy_where(query_docs, database):
     for key, value in values.items():
         assert stored[key] == value
         assert value["a"] == 1
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1258,6 +1298,7 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database):
     for key, value in values.items():
         assert stored[key] == value
         assert value["a"] == 1
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1269,6 +1310,7 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database):
     for key, value in values.items():
         assert stored[key] == value
         assert value["a"] == 1
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1281,6 +1323,7 @@ def test_query_stream_w_simple_field_in_op(query_docs, database):
     for key, value in values.items():
         assert stored[key] == value
         assert value["a"] == 1
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1303,6 +1346,7 @@ def test_query_stream_w_not_eq_op(query_docs, database):
         ]
     )
     assert expected_ab_pairs == ab_pairs2
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1315,6 +1359,7 @@ def test_query_stream_w_simple_not_in_op(query_docs, database):
     values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}

     assert len(values) == 22
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1329,6 +1374,7 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database)
     for key, value in values.items():
         assert stored[key] == value
         assert value["a"] == 1
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1343,6 +1389,7 @@ def test_query_stream_w_order_by(query_docs, database):
         b_vals.append(value["b"])
     # Make sure the ``b``-values are in DESCENDING order.
     assert sorted(b_vals, reverse=True) == b_vals
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1381,6 +1428,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database):
     for key, value in values:
         assert stored[key] == value
         assert value["a"] == num_vals - 2
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1390,6 +1438,7 @@ def test_query_stream_wo_results(query_docs, database):
     query = collection.where(filter=FieldFilter("b", "==", num_vals + 100))
     values = list(query.stream())
     assert len(values) == 0
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1407,6 +1456,7 @@ def test_query_stream_w_projection(query_docs, database):
             "stats": {"product": stored[key]["stats"]["product"]},
         }
         assert expected == value
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1427,6 +1477,7 @@ def test_query_stream_w_multiple_filters(query_docs, database):
         assert stored[key] == value
         pair = (value["a"], value["b"])
         assert pair in matching_pairs
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1443,6 +1494,7 @@ def test_query_stream_w_offset(query_docs, database):
     for key, value in values.items():
         assert stored[key] == value
         assert value["b"] == 2
+    verify_pipeline(query)


 @pytest.mark.skipif(
@@ -1607,6 +1659,7 @@ def test_query_stream_w_read_time(query_docs, cleanup, database):
     assert len(new_values) == num_vals + 1
     assert new_ref.id in new_values
     assert new_values[new_ref.id] == new_data
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1622,9 +1675,11 @@ def test_query_with_order_dot_key(client, cleanup, database):
     query = collection.order_by("wordcount.page1").limit(3)
     data = [doc.to_dict()["wordcount"]["page1"] for doc in query.stream()]
     assert [100, 110, 120] == data
+    verify_pipeline(query)
     query2 = collection.order_by("wordcount.page1").limit(3)
     for snapshot in query2.stream():
         last_value = snapshot.get("wordcount.page1")
+    verify_pipeline(query2)
     cursor_with_nested_keys = {"wordcount": {"page1": last_value}}
     query3 = (
         collection.order_by("wordcount.page1")
@@ -1638,6 +1693,7 @@ def test_query_with_order_dot_key(client, cleanup, database):
         {"count": 50, "wordcount": {"page1": 150}},
     ]
     assert found_data == [snap.to_dict() for snap in found]
+    verify_pipeline(query3)
     cursor_with_dotted_paths = {"wordcount.page1": last_value}
     query4 = (
         collection.order_by("wordcount.page1")
@@ -1646,6 +1702,7 @@ def test_query_with_order_dot_key(client, cleanup, database):
     )
     cursor_with_key_data = list(query4.stream())
     assert found_data == [snap.to_dict() for snap in cursor_with_key_data]
+    verify_pipeline(query4)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1674,6 +1731,7 @@ def test_query_unary(client, cleanup, database):
     snapshot0 = values0[0]
     assert snapshot0.reference._path == document0._path
     assert snapshot0.to_dict() == {field_name: None}
+    verify_pipeline(query0)

     # 1. Query for a NAN.
     query1 = collection.where(filter=FieldFilter(field_name, "==", nan_val))
@@ -1684,6 +1742,7 @@ def test_query_unary(client, cleanup, database):
     data1 = snapshot1.to_dict()
     assert len(data1) == 1
     assert math.isnan(data1[field_name])
+    verify_pipeline(query1)

     # 2. Query for not null
     query2 = collection.where(filter=FieldFilter(field_name, "!=", None))
@@ -1734,6 +1793,7 @@ def test_collection_group_queries(client, cleanup, database):
     found = [snapshot.id for snapshot in snapshots]
     expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"]
     assert found == expected
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1767,6 +1827,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database):
     snapshots = list(query.stream())
     found = set(snapshot.id for snapshot in snapshots)
     assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"])
+    verify_pipeline(query)

     query = (
         client.collection_group(collection_group)
@@ -1777,6 +1838,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database):
     snapshots = list(query.stream())
     found = set(snapshot.id for snapshot in snapshots)
     assert found == set(["cg-doc2"])
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1822,6 +1884,7 @@ def test_collection_group_queries_filters(client, cleanup, database):
     snapshots = list(query.stream())
     found = set(snapshot.id for snapshot in snapshots)
     assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"])
+    verify_pipeline(query)

     query = (
         client.collection_group(collection_group)
@@ -1843,6 +1906,7 @@ def test_collection_group_queries_filters(client, cleanup, database):
     snapshots = list(query.stream())
     found = set(snapshot.id for snapshot in snapshots)
     assert found == set(["cg-doc2"])
+    verify_pipeline(query)


 @pytest.mark.skipif(
@@ -2152,6 +2216,7 @@ def on_snapshot(docs, changes, read_time):
     query_ran_query = collection_ref.where(filter=FieldFilter("first", "==", "Ada"))
     query_ran = query_ran_query.stream()
     assert len(docs) == len([i for i in query_ran])
+    verify_pipeline(query_ran_query)

     on_snapshot.called_count = 0

@@ -2389,6 +2454,7 @@ def test_recursive_query(client, cleanup, database):
             f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'"
         )
         assert ids[index] == expected_ids[index], error_msg
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -2414,6 +2480,7 @@ def test_nested_recursive_query(client, cleanup, database):
             f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'"
         )
         assert ids[index] == expected_ids[index], error_msg
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -2529,6 +2596,7 @@ def on_snapshot(docs, changes, read_time):
                 ), "expect the sort order to match, born"
                 on_snapshot.called_count += 1
                 on_snapshot.last_doc_count = len(docs)
+                verify_pipeline(query_ref)
         except Exception as e:
             on_snapshot.failed = e

@@ -2594,6 +2662,8 @@ def test_repro_429(client, cleanup, database):

     for snapshot in query2.stream():
         print(f"id: {snapshot.id}")
+    verify_pipeline(query)
+    verify_pipeline(query2)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -3175,6 +3245,7 @@ def test_query_with_and_composite_filter(collection, database):
     for result in query.stream():
         assert result.get("stats.product") > 5
         assert result.get("stats.product") < 10
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -3198,6 +3269,7 @@ def test_query_with_or_composite_filter(collection, database):

     assert gt_5 > 0
     assert lt_10 > 0
+    verify_pipeline(query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -3245,6 +3317,7 @@ def test_aggregation_queries_with_read_time(
     assert len(old_result) == 1
     for r in old_result[0]:
         assert r.value == expected_value
+    verify_pipeline(aggregation_query)


 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -3268,6 +3341,7 @@ def test_query_with_complex_composite_filter(collection, database):

     assert sum_0 > 0
     assert sum_4 > 0
+    verify_pipeline(query)

     # b == 3 || (stats.sum == 4 && a == 4)
     comp_filter = Or(
@@ -3290,6 +3364,7 @@ def test_query_with_complex_composite_filter(collection, database):

     assert b_3 is True
     assert b_not_3 is True
+    verify_pipeline(query)


 @pytest.mark.parametrize(
@@ -3337,6 +3412,7 @@ def in_transaction(transaction):
         assert len(result[0]) == 1
         assert result[0][0].value == expected
         inner_fn_ran = True
+        verify_pipeline(aggregation_query)

     in_transaction(transaction)
     # make sure we didn't skip assertions in inner function
@@ -3382,6 +3458,7 @@ def in_transaction(transaction):
             result[0].get("b") == 2 and result[1].get("b") == 1
         )
         inner_fn_ran = True
+        verify_pipeline(query)

     in_transaction(transaction)
     # make sure we didn't skip assertions in inner function
@@ -3459,6 +3536,7 @@ def in_transaction(transaction):
         assert explain_metrics.execution_stats is not None

         inner_fn_ran = True
+        verify_pipeline(query)

     in_transaction(transaction)
     # make sure we didn't skip assertions in inner function
diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py
index 69ca69ec7..5064e87ae 100644
--- a/tests/unit/v1/test_aggregation.py
+++ b/tests/unit/v1/test_aggregation.py
@@ -20,6 +20,7 @@
 from google.cloud.firestore_v1.base_aggregation import (
     AggregationResult,
     AvgAggregation,
+    BaseAggregation,
     CountAggregation,
     SumAggregation,
 )
@@ -27,6 +28,7 @@
 from google.cloud.firestore_v1.query_results import QueryResultsList
 from google.cloud.firestore_v1.stream_generator import StreamGenerator
 from google.cloud.firestore_v1.types import RunAggregationQueryResponse
+from google.cloud.firestore_v1.field_path import FieldPath
 from google.protobuf.timestamp_pb2 import Timestamp
 from tests.unit.v1._test_helpers import (
     make_aggregation_query,
@@ -121,6 +123,65 @@ def test_avg_aggregation_no_alias_to_pb():
     assert got_pb.alias == ""


+@pytest.mark.parametrize(
+    "in_alias,expected_alias", [("total", "total"), (None, "field_1")]
+)
+def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias):
+    from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias
+    from google.cloud.firestore_v1.pipeline_expressions import Count
+
+    count_aggregation = CountAggregation(alias=in_alias)
+    got = count_aggregation._to_pipeline_expr(iter([1]))
+    assert isinstance(got, ExprWithAlias)
+    assert got.alias == expected_alias
+    assert isinstance(got.expr, Count)
+    assert len(got.expr.params) == 0
+
+
+@pytest.mark.parametrize(
+    "in_alias,expected_path,expected_alias",
+    [("total", "path", "total"), (None, "some_ref", "field_1")],
+)
+def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias):
+    from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias
+    from google.cloud.firestore_v1.pipeline_expressions import Sum
+
+    sum_aggregation = SumAggregation(expected_path, alias=in_alias)
+    got = sum_aggregation._to_pipeline_expr(iter([1]))
+    assert isinstance(got, ExprWithAlias)
+    assert got.alias == expected_alias
+    assert isinstance(got.expr, Sum)
+    assert got.expr.params[0].path == expected_path
+
+
+@pytest.mark.parametrize(
+    "in_alias,expected_path,expected_alias",
+    [("total", "path", "total"), (None, "some_ref", "field_1")],
+)
+def test_avg_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias):
+    from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias
+    from google.cloud.firestore_v1.pipeline_expressions import Avg
+
+    avg_aggregation = AvgAggregation(expected_path, alias=in_alias)
+    got = avg_aggregation._to_pipeline_expr(iter([1]))
+    assert isinstance(got, ExprWithAlias)
+    assert got.alias == expected_alias
+    assert isinstance(got.expr, Avg)
+    assert got.expr.params[0].path == expected_path
+
+
+def test_aggregation__pipeline_alias_increment():
+    """
+    BaseAggregation._pipeline_alias should pull from an autoindexer to populate field numbers
+    """
+    autoindex = iter(range(10))
+    mock_instance = mock.Mock()
+    mock_instance.alias = None
+    for i in range(10):
+        got_name = BaseAggregation._pipeline_alias(mock_instance, autoindex)
+        assert got_name == f"field_{i}"
+
+
 def test_aggregation_query_constructor():
     client = make_client()
     parent = client.collection("dee")
@@ -894,6 +955,16 @@ def test_aggregation_query_stream_w_explain_options_analyze_false():
     _aggregation_query_stream_helper(explain_options=ExplainOptions(analyze=False))


+def test_aggregation__to_dict():
+    expected_alias = "alias"
+    expected_value = "value"
+    instance = AggregationResult(alias=expected_alias, value=expected_value)
+    dict_result = instance._to_dict()
+    assert len(dict_result) == 1
+    assert next(iter(dict_result)) == expected_alias
+    assert dict_result[expected_alias] == expected_value
+
+
 def test_aggregation_from_query():
     from google.cloud.firestore_v1 import _helpers

@@ -952,3 +1023,147 @@ def test_aggregation_from_query():
             metadata=client._rpc_metadata,
             **kwargs,
         )
+
+
+@pytest.mark.parametrize(
+    "field,in_alias,out_alias",
+    [
+        ("field", None, "field_1"),
+        (FieldPath("test"), None, "field_1"),
+        ("field", "overwrite", "overwrite"),
+    ],
+)
+def test_aggregation_to_pipeline_sum(field, in_alias, out_alias):
+    from google.cloud.firestore_v1.pipeline import Pipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate
+    from google.cloud.firestore_v1.pipeline_expressions import Sum
+
+    client = make_client()
+    parent = client.collection("dee")
+    query = make_query(parent)
+    aggregation_query = make_aggregation_query(query)
+    aggregation_query.sum(field, alias=in_alias)
+    pipeline = aggregation_query.pipeline()
+    assert isinstance(pipeline, Pipeline)
+    assert len(pipeline.stages) == 2
+    assert isinstance(pipeline.stages[0], Collection)
+    assert pipeline.stages[0].path == "/dee"
+    aggregate_stage = pipeline.stages[1]
+    assert isinstance(aggregate_stage, Aggregate)
+    assert len(aggregate_stage.accumulators) == 1
+    assert isinstance(aggregate_stage.accumulators[0].expr, Sum)
+    expected_field = field if isinstance(field, str) else field.to_api_repr()
+    assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field
+    assert aggregate_stage.accumulators[0].alias == out_alias
+
+
+@pytest.mark.parametrize(
+    "field,in_alias,out_alias",
+    [
+        ("field", None, "field_1"),
+        (FieldPath("test"), None, "field_1"),
+        ("field", "overwrite", "overwrite"),
+    ],
+)
+def test_aggregation_to_pipeline_avg(field, in_alias, out_alias):
+    from google.cloud.firestore_v1.pipeline import Pipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate
+    from google.cloud.firestore_v1.pipeline_expressions import Avg
+
+    client = make_client()
+    parent = client.collection("dee")
+    query = make_query(parent)
+    aggregation_query = make_aggregation_query(query)
+    aggregation_query.avg(field, alias=in_alias)
+    pipeline = aggregation_query.pipeline()
+    assert isinstance(pipeline, Pipeline)
+    assert len(pipeline.stages) == 2
+    assert isinstance(pipeline.stages[0], Collection)
+    assert pipeline.stages[0].path == "/dee"
+    aggregate_stage = pipeline.stages[1]
+    assert isinstance(aggregate_stage, Aggregate)
+    assert len(aggregate_stage.accumulators) == 1
+    assert isinstance(aggregate_stage.accumulators[0].expr, Avg)
+    expected_field = field if isinstance(field, str) else field.to_api_repr()
+    assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field
+    assert aggregate_stage.accumulators[0].alias == out_alias
+
+
+@pytest.mark.parametrize(
+    "in_alias,out_alias",
+    [
+        (None, "field_1"),
+        ("overwrite", "overwrite"),
+    ],
+)
+def test_aggregation_to_pipeline_count(in_alias, out_alias):
+    from google.cloud.firestore_v1.pipeline import Pipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate
+    from google.cloud.firestore_v1.pipeline_expressions import Count
+
+    client = make_client()
+    parent = client.collection("dee")
+    query = make_query(parent)
+    aggregation_query = make_aggregation_query(query)
+    aggregation_query.count(alias=in_alias)
+    pipeline = aggregation_query.pipeline()
+    assert isinstance(pipeline, Pipeline)
+    assert len(pipeline.stages) == 2
+    assert isinstance(pipeline.stages[0], Collection)
+    assert pipeline.stages[0].path == "/dee"
+    aggregate_stage = pipeline.stages[1]
+    assert isinstance(aggregate_stage, Aggregate)
+    assert len(aggregate_stage.accumulators) == 1
+    assert isinstance(aggregate_stage.accumulators[0].expr, Count)
+    assert aggregate_stage.accumulators[0].alias == out_alias
+
+
+def test_aggregation_to_pipeline_count_increment():
+    """
+    When aliases aren't given, should assign incrementing field_n values
+    """
+    from google.cloud.firestore_v1.pipeline_expressions import Count
+
+    n = 100
+    client = make_client()
+    parent = client.collection("dee")
+    query = make_query(parent)
+    aggregation_query = make_aggregation_query(query)
+    for _ in range(n):
+        aggregation_query.count()
+    pipeline = aggregation_query.pipeline()
+    aggregate_stage = pipeline.stages[1]
+    assert len(aggregate_stage.accumulators) == n
+    for i in range(n):
+        assert isinstance(aggregate_stage.accumulators[i].expr, Count)
+        assert aggregate_stage.accumulators[i].alias == f"field_{i+1}"
+
+
+def test_aggregation_to_pipeline_complex():
+    from google.cloud.firestore_v1.pipeline import Pipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select
+    from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count
+
+    client = make_client()
+    query = client.collection("my_col").select(["field_a", "field_b.c"])
+    aggregation_query = make_aggregation_query(query)
+    aggregation_query.sum("field", alias="alias")
+    aggregation_query.count()
+    aggregation_query.avg("other")
+    aggregation_query.sum("final")
+    pipeline = aggregation_query.pipeline()
+    assert isinstance(pipeline, Pipeline)
+    assert len(pipeline.stages) == 3
+    assert isinstance(pipeline.stages[0], Collection)
+    assert isinstance(pipeline.stages[1], Select)
+    assert isinstance(pipeline.stages[2], Aggregate)
+    aggregate_stage = pipeline.stages[2]
+    assert len(aggregate_stage.accumulators) == 4
+    assert isinstance(aggregate_stage.accumulators[0].expr, Sum)
+    assert aggregate_stage.accumulators[0].alias == "alias"
+    assert isinstance(aggregate_stage.accumulators[1].expr, Count)
+    assert aggregate_stage.accumulators[1].alias == "field_1"
+    assert isinstance(aggregate_stage.accumulators[2].expr, Avg)
+    assert aggregate_stage.accumulators[2].alias == "field_2"
+    assert isinstance(aggregate_stage.accumulators[3].expr, Sum)
+    assert aggregate_stage.accumulators[3].alias == "field_3"
diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py
index 9140f53e8..fdd4a1450 100644
--- a/tests/unit/v1/test_async_aggregation.py
+++ b/tests/unit/v1/test_async_aggregation.py
@@ -31,6 +31,7 @@
 from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
 from google.cloud.firestore_v1.query_profile import ExplainMetrics, QueryExplainError
 from google.cloud.firestore_v1.query_results import QueryResultsList
+from google.cloud.firestore_v1.field_path import FieldPath

 _PROJECT = "PROJECT"

@@ -696,3 +697,147 @@ async def test_aggregation_query_stream_w_explain_options_analyze_false():

     explain_options = ExplainOptions(analyze=False)
     await _async_aggregation_query_stream_helper(explain_options=explain_options)
+
+
+@pytest.mark.parametrize(
+    "field,in_alias,out_alias",
+    [
+        ("field", None, "field_1"),
+        (FieldPath("test"), None, "field_1"),
+        ("field", "overwrite", "overwrite"),
+    ],
+)
+def test_async_aggregation_to_pipeline_sum(field, in_alias, out_alias):
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate
+    from google.cloud.firestore_v1.pipeline_expressions import Sum
+
+    client = make_async_client()
+    parent = client.collection("dee")
+    query = make_async_query(parent)
+    aggregation_query = make_async_aggregation_query(query)
+    aggregation_query.sum(field, alias=in_alias)
+    pipeline = aggregation_query.pipeline()
+    assert isinstance(pipeline, AsyncPipeline)
+    assert len(pipeline.stages) == 2
+    assert isinstance(pipeline.stages[0], Collection)
+    assert pipeline.stages[0].path == "/dee"
+    aggregate_stage = pipeline.stages[1]
+    assert isinstance(aggregate_stage, Aggregate)
+    assert len(aggregate_stage.accumulators) == 1
+    assert isinstance(aggregate_stage.accumulators[0].expr, Sum)
+    expected_field = field if isinstance(field, str) else field.to_api_repr()
+    assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field
+    assert aggregate_stage.accumulators[0].alias == out_alias
+
+
+@pytest.mark.parametrize(
+    "field,in_alias,out_alias",
+    [
+        ("field", None, "field_1"),
+        (FieldPath("test"), None, "field_1"),
+        ("field", "overwrite", "overwrite"),
+    ],
+)
+def test_async_aggregation_to_pipeline_avg(field, in_alias, out_alias):
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate
+    from google.cloud.firestore_v1.pipeline_expressions import Avg
+
+    client = make_async_client()
+    parent = client.collection("dee")
+    query = make_async_query(parent)
+    aggregation_query = make_async_aggregation_query(query)
+    aggregation_query.avg(field, alias=in_alias)
+    pipeline = aggregation_query.pipeline()
+    assert isinstance(pipeline, AsyncPipeline)
+    assert len(pipeline.stages) == 2
+    assert isinstance(pipeline.stages[0], Collection)
+    assert pipeline.stages[0].path == "/dee"
+    aggregate_stage = pipeline.stages[1]
+    assert isinstance(aggregate_stage, Aggregate)
+    assert len(aggregate_stage.accumulators) == 1
+    assert isinstance(aggregate_stage.accumulators[0].expr, Avg)
+    expected_field = field if isinstance(field, str) else field.to_api_repr()
+    assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field
+    assert aggregate_stage.accumulators[0].alias == out_alias
+
+
+@pytest.mark.parametrize(
+    "in_alias,out_alias",
+    [
+        (None, "field_1"),
+        ("overwrite", "overwrite"),
+    ],
+)
+def test_async_aggregation_to_pipeline_count(in_alias, out_alias):
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate
+    from google.cloud.firestore_v1.pipeline_expressions import Count
+
+    client = make_async_client()
+    parent = client.collection("dee")
+    query = make_async_query(parent)
+    aggregation_query = make_async_aggregation_query(query)
+    aggregation_query.count(alias=in_alias)
+    pipeline = aggregation_query.pipeline()
+    assert isinstance(pipeline, AsyncPipeline)
+    assert len(pipeline.stages) == 2
+    assert isinstance(pipeline.stages[0], Collection)
+    assert pipeline.stages[0].path == "/dee"
+    aggregate_stage = pipeline.stages[1]
+    assert isinstance(aggregate_stage, Aggregate)
+    assert len(aggregate_stage.accumulators) == 1
+    assert isinstance(aggregate_stage.accumulators[0].expr, Count)
+    assert aggregate_stage.accumulators[0].alias == out_alias
+
+
+def test_async_aggregation_to_pipeline_count_increment():
+    """
+    When aliases aren't given, should assign incrementing field_n values
+    """
+    from google.cloud.firestore_v1.pipeline_expressions import Count
+
+    n = 100
+    client = make_async_client()
+    parent = client.collection("dee")
+    query = make_async_query(parent)
+    aggregation_query = make_async_aggregation_query(query)
+    for _ in range(n):
+        aggregation_query.count()
+    pipeline = aggregation_query.pipeline()
+    aggregate_stage = pipeline.stages[1]
+    assert len(aggregate_stage.accumulators) == n
+    for i in range(n):
+        assert isinstance(aggregate_stage.accumulators[i].expr, Count)
+        assert aggregate_stage.accumulators[i].alias == f"field_{i+1}"
+
+
+def test_async_aggregation_to_pipeline_complex():
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select
+    from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count
+
+    client = make_async_client()
+    query = client.collection("my_col").select(["field_a", "field_b.c"])
+    aggregation_query = make_async_aggregation_query(query)
+    aggregation_query.sum("field", alias="alias")
+    aggregation_query.count()
+    aggregation_query.avg("other")
+    aggregation_query.sum("final")
+    pipeline = aggregation_query.pipeline()
+    assert isinstance(pipeline, AsyncPipeline)
+    assert len(pipeline.stages) == 3
+    assert isinstance(pipeline.stages[0], Collection)
+    assert isinstance(pipeline.stages[1], Select)
+    assert isinstance(pipeline.stages[2], Aggregate)
+    aggregate_stage = pipeline.stages[2]
+    assert len(aggregate_stage.accumulators) == 4
+    assert isinstance(aggregate_stage.accumulators[0].expr, Sum)
+    assert aggregate_stage.accumulators[0].alias == "alias"
+    assert isinstance(aggregate_stage.accumulators[1].expr, Count)
+    assert aggregate_stage.accumulators[1].alias == "field_1"
+    assert isinstance(aggregate_stage.accumulators[2].expr, Avg)
+    assert aggregate_stage.accumulators[2].alias == "field_2"
+    assert isinstance(aggregate_stage.accumulators[3].expr, Sum)
+    assert aggregate_stage.accumulators[3].alias == "field_3"
diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py
index a0194ace5..353997b8e 100644
--- a/tests/unit/v1/test_async_collection.py
+++ b/tests/unit/v1/test_async_collection.py
@@ -601,3 +601,23 @@ def test_asynccollectionreference_recursive():
     col = _make_async_collection_reference("collection")

     assert isinstance(col.recursive(), AsyncQuery)
+
+
+def test_asynccollectionreference_pipeline():
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection
+
+    client = make_async_client()
+    collection = _make_async_collection_reference("collection", client=client)
+    pipeline = collection.pipeline()
+    assert isinstance(pipeline, AsyncPipeline)
+    # should have single "Collection" stage
+    assert len(pipeline.stages) == 1
+    assert isinstance(pipeline.stages[0], Collection)
+    assert pipeline.stages[0].path == "/collection"
+
+
+def test_asynccollectionreference_pipeline_no_client():
+    collection = _make_async_collection_reference("collection")
+    with pytest.raises(ValueError, match="client"):
+        collection.pipeline()
diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py
index 54c80e5ad..dc5eb9e8a 100644
--- a/tests/unit/v1/test_async_query.py
+++ b/tests/unit/v1/test_async_query.py
@@ -909,3 +909,22 @@ async def test_asynccollectiongroup_get_partitions_w_offset():
     query = _make_async_collection_group(parent).offset(10)
     with pytest.raises(ValueError):
         [i async for i in query.get_partitions(2)]
+
+
+def test_asyncquery_collection_pipeline_type():
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+
+    client = make_async_client()
+    parent = client.collection("test")
+    query = parent._query()
+    ppl = query.pipeline()
+    assert isinstance(ppl, AsyncPipeline)
+
+
+def test_asyncquery_collectiongroup_pipeline_type():
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+
+    client = make_async_client()
+    query = client.collection_group("test")
+    ppl = query.pipeline()
+    assert isinstance(ppl, AsyncPipeline)
diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py
index 22baa0c5f..7f7be9c07 100644
--- a/tests/unit/v1/test_base_collection.py
+++ b/tests/unit/v1/test_base_collection.py
@@ -422,6 +422,20 @@ def test_basecollectionreference_end_at(mock_query):
         assert query == mock_query.end_at.return_value


+@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True)
+def test_basecollectionreference_pipeline(mock_query):
+    from google.cloud.firestore_v1.base_collection import BaseCollectionReference
+
+    with mock.patch.object(BaseCollectionReference, "_query") as _query:
+        _query.return_value = mock_query
+
+        collection = _make_base_collection_reference("collection")
+        pipeline = collection.pipeline()
+
+        mock_query.pipeline.assert_called_once_with()
+        assert pipeline == mock_query.pipeline.return_value
+
+
 @mock.patch("random.choice")
 def test__auto_id(mock_rand_choice):
     from google.cloud.firestore_v1.base_collection import _AUTO_ID_CHARS, _auto_id
diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py
index 7f6b0e5e2..a0abdccb2 100644
--- a/tests/unit/v1/test_base_query.py
+++ b/tests/unit/v1/test_base_query.py
@@ -18,6 +18,7 @@
 import pytest

 from tests.unit.v1._test_helpers import make_client
+from google.cloud.firestore_v1 import _pipeline_stages as stages


 def _make_base_query(*args, **kwargs):
@@ -1980,6 +1981,175 @@ def test__collection_group_query_response_to_snapshot_response():
     assert snapshot.update_time == response_pb._pb.document.update_time


+def test__query_pipeline_no_client():
+    mock_parent = mock.Mock()
+    mock_parent._client = None
+    query = _make_base_query(mock_parent)
+    with pytest.raises(ValueError, match="client"):
+        query.pipeline()
+
+
+def test__query_pipeline_descendants():
+    client = make_client()
+    query = client.collection_group("my_col")
+    pipeline = query.pipeline()
+
+    assert len(pipeline.stages) == 1
+    stage = pipeline.stages[0]
+    assert isinstance(stage, stages.CollectionGroup)
+    assert stage.collection_id == "my_col"
+
+
+@pytest.mark.parametrize(
+    "in_path,out_path",
+    [
+        ("my_col/doc/", "/my_col/doc/"),
+        ("/my_col/doc", "/my_col/doc"),
+        ("my_col/doc/sub_col", "/my_col/doc/sub_col"),
+    ],
+)
+def test__query_pipeline_no_descendants(in_path, out_path):
+    client = make_client()
+    collection = client.collection(in_path)
+    query = collection._query()
+    pipeline = query.pipeline()
+
+    assert len(pipeline.stages) == 1
+    stage = pipeline.stages[0]
+    assert isinstance(stage, stages.Collection)
+    assert stage.path == out_path
+
+
+def test__query_pipeline_composite_filter():
+    from google.cloud.firestore_v1 import FieldFilter
+    from google.cloud.firestore_v1 import pipeline_expressions as expr
+
+    client = make_client()
+    in_filter = FieldFilter("field_a", "==", "value_a")
+    query = client.collection("my_col").where(filter=in_filter)
+    with mock.patch.object(
+        expr.FilterCondition, "_from_query_filter_pb"
+    ) as convert_mock:
+        pipeline = query.pipeline()
+        convert_mock.assert_called_once_with(in_filter._to_pb(), client)
+        assert len(pipeline.stages) == 2
+        stage = pipeline.stages[1]
+        assert isinstance(stage, stages.Where)
+        assert stage.condition == convert_mock.return_value
+
+
+def test__query_pipeline_projections():
+    client = make_client()
+    query = client.collection("my_col").select(["field_a", "field_b.c"])
+    pipeline = query.pipeline()
+
+    assert len(pipeline.stages) == 2
+    stage = pipeline.stages[1]
+    assert isinstance(stage, stages.Select)
+    assert len(stage.projections) == 2
+    assert stage.projections[0].path == "field_a"
+    assert stage.projections[1].path == "field_b.c"
+
+
+def test__query_pipeline_order_exists_multiple():
+    from google.cloud.firestore_v1 import pipeline_expressions as expr
+
+    client = make_client()
+    query = client.collection("my_col").order_by("field_a").order_by("field_b")
+    pipeline = query.pipeline()
+
+    # should have collection, where, and sort
+    # we're interested in where
+    assert len(pipeline.stages) == 3
+    where_stage = pipeline.stages[1]
+    assert isinstance(where_stage, stages.Where)
+    # should have an And combining both orderings
+    assert isinstance(where_stage.condition, expr.And)
+    assert len(where_stage.condition.params) == 2
+    operands = [p for p in where_stage.condition.params]
+    assert isinstance(operands[0], expr.Exists)
+    assert operands[0].params[0].path == "field_a"
+    assert isinstance(operands[1], expr.Exists)
+    assert operands[1].params[0].path == "field_b"
+
+
+def test__query_pipeline_order_exists_single():
+    from google.cloud.firestore_v1 import pipeline_expressions as expr
+
+    client = make_client()
+    query_single = client.collection("my_col").order_by("field_c")
+    pipeline_single = query_single.pipeline()
+
+    # should have collection, where, and sort
+    # we're interested in where
+    assert len(pipeline_single.stages) == 3
+    where_stage_single = pipeline_single.stages[1]
+    assert isinstance(where_stage_single, stages.Where)
+    assert isinstance(where_stage_single.condition, expr.Exists)
+    assert where_stage_single.condition.params[0].path == "field_c"
+
+
+def test__query_pipeline_order_sorts():
+    from google.cloud.firestore_v1 import pipeline_expressions as expr
+    from google.cloud.firestore_v1.base_query import BaseQuery
+
+    client = make_client()
+    query = (
+        client.collection("my_col")
+        .order_by("field_a", direction=BaseQuery.ASCENDING)
+        .order_by("field_b", direction=BaseQuery.DESCENDING)
+    )
+    pipeline = query.pipeline()
+
+    assert len(pipeline.stages) == 3
+    sort_stage = pipeline.stages[2]
+    assert isinstance(sort_stage, stages.Sort)
+    assert len(sort_stage.orders) == 2
+    assert isinstance(sort_stage.orders[0], expr.Ordering)
+    assert sort_stage.orders[0].expr.path == "field_a"
+    assert sort_stage.orders[0].order_dir == expr.Ordering.Direction.ASCENDING
+    assert isinstance(sort_stage.orders[1], expr.Ordering)
+    assert sort_stage.orders[1].expr.path == "field_b"
+    assert sort_stage.orders[1].order_dir == expr.Ordering.Direction.DESCENDING
+
+
+def test__query_pipeline_unsupported():
+    client = make_client()
+    query_start = client.collection("my_col").start_at({"field_a": "value"})
+    with pytest.raises(NotImplementedError, match="cursors"):
+        query_start.pipeline()
+
+    query_end = client.collection("my_col").end_at({"field_a": "value"})
+    with pytest.raises(NotImplementedError, match="cursors"):
+        query_end.pipeline()
+
+    query_limit_last = client.collection("my_col").limit_to_last(10)
+    with pytest.raises(NotImplementedError, match="limit_to_last"):
+        query_limit_last.pipeline()
+
+
+def test__query_pipeline_limit():
+    client = make_client()
+    query = client.collection("my_col").limit(15)
+    pipeline = query.pipeline()
+
+    assert len(pipeline.stages) == 2
+    stage = pipeline.stages[1]
+    assert isinstance(stage, stages.Limit)
+    assert stage.limit == 15
+
+
+def test__query_pipeline_offset():
+    client = make_client()
+    query = client.collection("my_col").offset(5)
+    pipeline = query.pipeline()
+
+    assert len(pipeline.stages) == 2
+    stage = pipeline.stages[1]
+    assert isinstance(stage, stages.Offset)
+    assert stage.offset == 5
+
+
 def _make_order_pb(field_path, direction):
     from google.cloud.firestore_v1.types import query

diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py
index da91651b9..9e615541a 100644
--- a/tests/unit/v1/test_collection.py
+++ b/tests/unit/v1/test_collection.py
@@ -15,6 +15,7 @@
 import types

 import mock
+import pytest
 from datetime import datetime, timezone

 from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT
@@ -510,6 +511,27 @@ def test_stream_w_read_time(query_class):
     )


+def test_collectionreference_pipeline():
+    from tests.unit.v1 import _test_helpers
+    from google.cloud.firestore_v1.pipeline import Pipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection
+
+    client = _test_helpers.make_client()
+    collection = _make_collection_reference("collection", client=client)
+    pipeline = collection.pipeline()
+    assert isinstance(pipeline, Pipeline)
+    # should have single "Collection" stage
+    assert len(pipeline.stages) == 1
+    assert isinstance(pipeline.stages[0], Collection)
+    assert pipeline.stages[0].path == "/collection"
+
+
+def test_collectionreference_pipeline_no_client():
+    collection = _make_collection_reference("collection")
+    with pytest.raises(ValueError, match="client"):
+        collection.pipeline()
+
+
 @mock.patch("google.cloud.firestore_v1.collection.Watch", autospec=True)
 def test_on_snapshot(watch):
     collection = _make_collection_reference("collection")
diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py
index b8c37cf84..8b1217370 100644
--- a/tests/unit/v1/test_query.py
+++ b/tests/unit/v1/test_query.py
@@ -1046,3 +1046,22 @@ def test_collection_group_get_partitions_w_offset(database):
     query = _make_collection_group(parent).offset(10)
     with pytest.raises(ValueError):
         list(query.get_partitions(2))
+
+
+def test_query_collection_pipeline_type():
+    from google.cloud.firestore_v1.pipeline import Pipeline
+
+    client = make_client()
+    parent = client.collection("test")
+    query = parent._query()
+    ppl = query.pipeline()
+    assert isinstance(ppl, Pipeline)
+
+
+def test_query_collectiongroup_pipeline_type():
+    from google.cloud.firestore_v1.pipeline import Pipeline
+
+    client = make_client()
+    query = client.collection_group("test")
+    ppl = query.pipeline()
+    assert isinstance(ppl, Pipeline)
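
Usage note: the tests above exercise the new Query-to-Pipeline conversion end to end. The following is a minimal sketch of how the API added in this diff fits together, assuming a database where pipelines are enabled (the `field_n` auto-aliasing and the snapshot `.data()` accessor follow the behavior asserted in the tests above):

    from google.cloud import firestore
    from google.cloud.firestore_v1 import FieldFilter

    client = firestore.Client()

    # A classic query converts stage-by-stage: Collection -> Where/Exists -> Sort -> Limit.
    query = (
        client.collection("my_col")
        .where(filter=FieldFilter("a", "==", 1))
        .order_by("b")
        .limit(10)
    )
    pipeline = query.pipeline()
    results = [snapshot.data() for snapshot in pipeline.execute()]

    # Aggregation queries map onto Pipeline.aggregate(); un-aliased aggregations
    # are auto-named field_1, field_2, ... by the autoindexer.
    agg_query = client.collection("my_col").count()  # no alias -> "field_1"
    agg_results = [snapshot.data() for snapshot in agg_query.pipeline().execute()]

Queries that carry a cursor (`start_at`/`end_at`) or `limit_to_last` raise NotImplementedError on conversion, as covered by test__query_pipeline_unsupported.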