diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 24b890ebcd..d03ab72b87 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -34,13 +34,14 @@ ) from google.cloud.firestore_v1 import async_document +from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery from google.cloud.firestore_v1.base_document import DocumentSnapshot -from typing import AsyncGenerator, List, Optional, Type - -# Types needed only for Type Hints -from google.cloud.firestore_v1.transaction import Transaction +from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING -from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery +if TYPE_CHECKING: # pragma: NO COVER + # Types needed only for Type Hints + from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.field_path import FieldPath class AsyncQuery(BaseQuery): @@ -222,8 +223,8 @@ def count( """Adds a count over the nested query. Args: - alias - (Optional[str]): The alias for the count + alias(Optional[str]): Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. Returns: :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`: @@ -231,6 +232,38 @@ def count( """ return AsyncAggregationQuery(self).count(alias=alias) + def sum( + self, field_ref: str | FieldPath, alias: str | None = None + ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]: + """Adds a sum over the nested query. + + Args: + field_ref(Union[str, google.cloud.firestore_v1.field_path.FieldPath]): The field to aggregate across. + alias(Optional[str]): Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. + + Returns: + :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`: + An instance of an AsyncAggregationQuery object + """ + return AsyncAggregationQuery(self).sum(field_ref, alias=alias) + + def avg( + self, field_ref: str | FieldPath, alias: str | None = None + ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]: + """Adds an avg over the nested query. + + Args: + field_ref(Union[str, google.cloud.firestore_v1.field_path.FieldPath]): The field to aggregate across. + alias(Optional[str]): Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. + + Returns: + :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`: + An instance of an AsyncAggregationQuery object + """ + return AsyncAggregationQuery(self).avg(field_ref, alias=alias) + async def stream( self, transaction=None, diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index 0eb6750a7d..d6097c136b 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -33,8 +33,8 @@ from google.api_core import retry as retries +from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.types import RunAggregationQueryResponse - from google.cloud.firestore_v1.types import StructuredAggregationQuery from google.cloud.firestore_v1 import _helpers @@ -60,6 +60,9 @@ def __repr__(self): class BaseAggregation(ABC): + def __init__(self, alias: str | None = None): + self.alias = alias + @abc.abstractmethod def _to_protobuf(self): """Convert this instance to the protobuf representation""" @@ -67,7 +70,7 @@ def _to_protobuf(self): class CountAggregation(BaseAggregation): def __init__(self, alias: str | None = None): - self.alias = alias + super(CountAggregation, self).__init__(alias=alias) def _to_protobuf(self): """Convert this instance to the protobuf representation""" @@ -77,13 +80,48 @@ def _to_protobuf(self): return aggregation_pb +class SumAggregation(BaseAggregation): + def __init__(self, field_ref: str | FieldPath, alias: str | None = None): + if isinstance(field_ref, FieldPath): + # convert field path to string + field_ref = field_ref.to_api_repr() + self.field_ref = field_ref + super(SumAggregation, self).__init__(alias=alias) + + def _to_protobuf(self): + """Convert this instance to the protobuf representation""" + aggregation_pb = StructuredAggregationQuery.Aggregation() + aggregation_pb.alias = self.alias + aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum() + aggregation_pb.sum.field.field_path = self.field_ref + return aggregation_pb + + +class AvgAggregation(BaseAggregation): + def __init__(self, field_ref: str | FieldPath, alias: str | None = None): + if isinstance(field_ref, FieldPath): + # convert field path to string + field_ref = field_ref.to_api_repr() + self.field_ref = field_ref + super(AvgAggregation, self).__init__(alias=alias) + + def _to_protobuf(self): + """Convert this instance to the protobuf representation""" + aggregation_pb = StructuredAggregationQuery.Aggregation() + aggregation_pb.alias = self.alias + aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg() + aggregation_pb.avg.field.field_path = self.field_ref + return aggregation_pb + + def _query_response_to_result( response_pb: RunAggregationQueryResponse, ) -> List[AggregationResult]: results = [ AggregationResult( alias=key, - value=response_pb.result.aggregate_fields[key].integer_value, + value=response_pb.result.aggregate_fields[key].integer_value + or response_pb.result.aggregate_fields[key].double_value, read_time=response_pb.read_time, ) for key in response_pb.result.aggregate_fields.pb.keys() @@ -95,11 +133,9 @@ def _query_response_to_result( class BaseAggregationQuery(ABC): """Represents an aggregation query to the Firestore API.""" - def __init__( - self, - nested_query, - ) -> None: + def __init__(self, nested_query, alias: str | None = None) -> None: self._nested_query = nested_query + self._alias = alias self._collection_ref = nested_query._parent self._aggregations: List[BaseAggregation] = [] @@ -115,6 +151,22 @@ def count(self, alias: str | None = None): self._aggregations.append(count_aggregation) return self + def sum(self, field_ref: str | FieldPath, alias: str | None = None): + """ + Adds a sum over the nested query + """ + sum_aggregation = SumAggregation(field_ref, alias=alias) + self._aggregations.append(sum_aggregation) + return self + + def avg(self, field_ref: str | FieldPath, alias: str | None = None): + """ + Adds an avg over the nested query + """ + avg_aggregation = AvgAggregation(field_ref, alias=alias) + self._aggregations.append(avg_aggregation) + return self + def add_aggregation(self, aggregation: BaseAggregation) -> None: """ Adds an aggregation operation to the nested query diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index dd74bf1a00..a9d644c4b4 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -13,6 +13,7 @@ # limitations under the License. """Classes for representing collections for the Google Cloud Firestore API.""" +from __future__ import annotations import random from google.api_core import retry as retries @@ -20,6 +21,7 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery +from google.cloud.firestore_v1.base_query import QueryType from typing import ( @@ -35,12 +37,15 @@ NoReturn, Tuple, Union, + TYPE_CHECKING, ) -# Types needed only for Type Hints -from google.cloud.firestore_v1.base_document import DocumentSnapshot -from google.cloud.firestore_v1.base_query import QueryType -from google.cloud.firestore_v1.transaction import Transaction + +if TYPE_CHECKING: # pragma: NO COVER + # Types needed only for Type Hints + from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.field_path import FieldPath _AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" @@ -244,7 +249,7 @@ def where( op_string: Optional[str] = None, value=None, *, - filter=None + filter=None, ) -> QueryType: """Create a "where" query with this collection as parent. @@ -507,6 +512,33 @@ def count(self, alias=None): """ return self._aggregation_query().count(alias=alias) + def sum(self, field_ref: str | FieldPath, alias=None): + """ + Adds a sum over the nested query. + + :type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath] + :param field_ref: The field to aggregate across. + + :type alias: Optional[str] + :param alias: Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. + + """ + return self._aggregation_query().sum(field_ref, alias=alias) + + def avg(self, field_ref: str | FieldPath, alias=None): + """ + Adds an avg over the nested query. + + :type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath] + :param field_ref: The field to aggregate across. + + :type alias: Optional[str] + :param alias: Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. + """ + return self._aggregation_query().avg(field_ref, alias=alias) + def _auto_id() -> str: """Generate a "random" automatically generated ID. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index c179109835..da1e41232e 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -49,11 +49,15 @@ Type, TypeVar, Union, + TYPE_CHECKING, ) # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.field_path import FieldPath + _BAD_DIR_STRING: str _BAD_OP_NAN_NULL: str _BAD_OP_STRING: str @@ -970,6 +974,16 @@ def count( ) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]: raise NotImplementedError + def sum( + self, field_ref: str | FieldPath, alias: str | None = None + ) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]: + raise NotImplementedError + + def avg( + self, field_ref: str | FieldPath, alias: str | None = None + ) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]: + raise NotImplementedError + def get( self, transaction=None, diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 1f3dbbc1e8..d37964dce0 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -38,7 +38,10 @@ from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch -from typing import Any, Callable, Generator, List, Optional, Type +from typing import Any, Callable, Generator, List, Optional, Type, TYPE_CHECKING + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.field_path import FieldPath class Query(BaseQuery): @@ -242,11 +245,42 @@ def count( """ Adds a count over the query. - :type alias: str - :param alias: (Optional) The alias for the count + :type alias: Optional[str] + :param alias: Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. """ return aggregation.AggregationQuery(self).count(alias=alias) + def sum( + self, field_ref: str | FieldPath, alias: str | None = None + ) -> Type["firestore_v1.aggregation.AggregationQuery"]: + """ + Adds a sum over the query. + + :type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath] + :param field_ref: The field to aggregate across. + + :type alias: Optional[str] + :param alias: Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. + """ + return aggregation.AggregationQuery(self).sum(field_ref, alias=alias) + + def avg( + self, field_ref: str | FieldPath, alias: str | None = None + ) -> Type["firestore_v1.aggregation.AggregationQuery"]: + """ + Adds an avg over the query. + + :type field_ref: [Union[str, google.cloud.firestore_v1.field_path.FieldPath] + :param field_ref: The field to aggregate across. + + :type alias: Optional[str] + :param alias: Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. + """ + return aggregation.AggregationQuery(self).avg(field_ref, alias=alias) + def stream( self, transaction=None, diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 4d3bba1dcb..12e3b87b22 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -564,10 +564,14 @@ def query_docs(client, database): @pytest.fixture -def query(query_docs): - collection, stored, allowed_vals = query_docs - query = collection.where(filter=FieldFilter("a", "==", 1)) - return query +def collection(query_docs): + collection, _, _ = query_docs + return collection + + +@pytest.fixture +def query(collection): + return collection.where(filter=FieldFilter("a", "==", 1)) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) @@ -1879,77 +1883,283 @@ def test_count_query_stream_empty_aggregation(query, database): assert "Aggregations can not be empty" in exc_info.value.message -@firestore.transactional -def create_in_transaction(collection_id, transaction, cleanup): - collection = client.collection(collection_id) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_count_query_with_start_at(query, database): + """ + Ensure that count aggregation queries work when chained with a start_at + + eg `col.where(...).startAt(...).count()` + """ + result = query.get() + start_doc = result[1] + # find count excluding first result + expected_count = len(result) - 1 + # start new query that starts at the second result + count_query = query.start_at(start_doc).count("a") + # ensure that the first doc was skipped in sum aggregation + for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == expected_count - query = collection.where(filter=FieldFilter("a", "==", 1)) - count_query = query.count() - result = count_query.get(transaction=transaction) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_get_default_alias(collection, database): + sum_query = collection.sum("stats.product") + result = sum_query.get() + assert len(result) == 1 for r in result[0]: - assert r.value <= 2 - if r.value < 2: - document_id_3 = "doc3" + UNIQUE_RESOURCE_ID - document_3 = client.document(collection_id, document_id_3) - cleanup(document_3.delete) - document_3.create({"a": 1}) - else: - raise ValueError("Collection can't have more than 2 documents") + assert r.alias == "field_1" + assert r.value == 100 -@firestore.transactional -def create_in_transaction_helper(transaction, client, collection_id, cleanup, database): - collection = client.collection(collection_id) - query = collection.where(filter=FieldFilter("a", "==", 1)) - count_query = query.count() - result = count_query.get(transaction=transaction) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_get_with_alias(collection, database): + sum_query = collection.sum("stats.product", alias="total") + result = sum_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 100 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_get_with_limit(collection, database): + # sum without limit + sum_query = collection.sum("stats.product", alias="total") + result = sum_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 100 + + # sum with limit + # limit query = [0,0,0,0,0,0,0,0,0,1,2,2] + sum_query = collection.limit(12).sum("stats.product", alias="total") + result = sum_query.get() + assert len(result) == 1 for r in result[0]: - if r.value < 2: - document_id_3 = "doc3" + UNIQUE_RESOURCE_ID - document_3 = client.document(collection_id, document_id_3) - cleanup(document_3.delete) - document_3.create({"a": 1}) - else: # transaction is rolled back - raise ValueError("Collection can't have more than 2 docs") + assert r.alias == "total" + assert r.value == 5 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_count_query_in_transaction(client, cleanup, database): - collection_id = "doc-create" + UNIQUE_RESOURCE_ID - document_id_1 = "doc1" + UNIQUE_RESOURCE_ID - document_id_2 = "doc2" + UNIQUE_RESOURCE_ID +def test_sum_query_get_multiple_aggregations(collection, database): + sum_query = collection.sum("stats.product", alias="total").sum( + "stats.product", alias="all" + ) - document_1 = client.document(collection_id, document_id_1) - document_2 = client.document(collection_id, document_id_2) + result = sum_query.get() + assert len(result[0]) == 2 - cleanup(document_1.delete) - cleanup(document_2.delete) + expected_aliases = ["total", "all"] + found_alias = set( + [r.alias for r in result[0]] + ) # ensure unique elements in the result + assert len(found_alias) == 2 + assert found_alias == set(expected_aliases) - document_1.create({"a": 1}) - document_2.create({"a": 1}) - transaction = client.transaction() +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_stream_default_alias(collection, database): + sum_query = collection.sum("stats.product") + for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "field_1" + assert aggregation_result.value == 100 - with pytest.raises(ValueError) as exc: - create_in_transaction_helper( - transaction, client, collection_id, cleanup, database - ) - assert str(exc.value) == "Collection can't have more than 2 docs" - collection = client.collection(collection_id) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_stream_with_alias(collection, database): + sum_query = collection.sum("stats.product", alias="total") + for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 100 - query = collection.where(filter=FieldFilter("a", "==", 1)) - count_query = query.count() - result = count_query.get() + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_stream_with_limit(collection, database): + # sum without limit + sum_query = collection.sum("stats.product", alias="total") + for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 100 + + # sum with limit + sum_query = collection.limit(12).sum("stats.product", alias="total") + + for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 5 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_stream_multiple_aggregations(collection, database): + sum_query = collection.sum("stats.product", alias="total").sum( + "stats.product", alias="all" + ) + + for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias in ["total", "all"] + + +# tests for issue reported in b/306241058 +# we will skip test in client for now, until backend fix is implemented +@pytest.mark.skip(reason="backend fix required") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_with_start_at(query, database): + """ + Ensure that sum aggregation queries work when chained with a start_at + + eg `col.where(...).startAt(...).sum()` + """ + result = query.get() + start_doc = result[1] + # find sum excluding first result + expected_sum = sum([doc.get("a") for doc in result[1:]]) + # start new query that starts at the second result + sum_result = query.start_at(start_doc).sum("a").get() + assert len(sum_result) == 1 + # ensure that the first doc was skipped in sum aggregation + assert sum_result[0].value == expected_sum + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_get_default_alias(collection, database): + avg_query = collection.avg("stats.product") + result = avg_query.get() + assert len(result) == 1 for r in result[0]: - assert r.value == 2 # there are still only 2 docs + assert r.alias == "field_1" + assert r.value == 4.0 + assert isinstance(r.value, float) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_and_composite_filter(query_docs, database): - collection, stored, allowed_vals = query_docs +def test_avg_query_get_with_alias(collection, database): + avg_query = collection.avg("stats.product", alias="total") + result = avg_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 4 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_get_with_limit(collection, database): + # avg without limit + avg_query = collection.avg("stats.product", alias="total") + result = avg_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 4.0 + + # avg with limit + # limit result = [0,0,0,0,0,0,0,0,0,1,2,2] + avg_query = collection.limit(12).avg("stats.product", alias="total") + + result = avg_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 5 / 12 + assert isinstance(r.value, float) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_get_multiple_aggregations(collection, database): + avg_query = collection.avg("stats.product", alias="total").avg( + "stats.product", alias="all" + ) + + result = avg_query.get() + assert len(result[0]) == 2 + + expected_aliases = ["total", "all"] + found_alias = set( + [r.alias for r in result[0]] + ) # ensure unique elements in the result + assert len(found_alias) == 2 + assert found_alias == set(expected_aliases) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_stream_default_alias(collection, database): + avg_query = collection.avg("stats.product") + for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "field_1" + assert aggregation_result.value == 4 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_stream_with_alias(collection, database): + avg_query = collection.avg("stats.product", alias="total") + for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 4 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_stream_with_limit(collection, database): + # avg without limit + avg_query = collection.avg("stats.product", alias="total") + for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 4 + + # avg with limit + avg_query = collection.limit(12).avg("stats.product", alias="total") + + for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 5 / 12 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_stream_multiple_aggregations(collection, database): + avg_query = collection.avg("stats.product", alias="total").avg( + "stats.product", alias="all" + ) + + for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias in ["total", "all"] + + +# tests for issue reported in b/306241058 +# we will skip test in client for now, until backend fix is implemented +@pytest.mark.skip(reason="backend fix required") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_with_start_at(query, database): + """ + Ensure that avg aggregation queries work when chained with a start_at + + eg `col.where(...).startAt(...).avg()` + """ + from statistics import mean + + result = query.get() + start_doc = result[1] + # find average, excluding first result + expected_avg = mean([doc.get("a") for doc in result[1:]]) + # start new query that starts at the second result + avg_result = query.start_at(start_doc).avg("a").get() + assert len(avg_result) == 1 + # ensure that the first doc was skipped in avg aggregation + assert avg_result[0].value == expected_avg + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_query_with_and_composite_filter(collection, database): and_filter = And( filters=[ FieldFilter("stats.product", ">", 5), @@ -1964,8 +2174,7 @@ def test_query_with_and_composite_filter(query_docs, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_or_composite_filter(query_docs, database): - collection, stored, allowed_vals = query_docs +def test_query_with_or_composite_filter(collection, database): or_filter = Or( filters=[ FieldFilter("stats.product", ">", 5), @@ -1988,8 +2197,7 @@ def test_query_with_or_composite_filter(query_docs, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_complex_composite_filter(query_docs, database): - collection, stored, allowed_vals = query_docs +def test_query_with_complex_composite_filter(collection, database): field_filter = FieldFilter("b", "==", 0) or_filter = Or( filters=[FieldFilter("stats.sum", "==", 0), FieldFilter("stats.sum", "==", 4)] @@ -2033,48 +2241,140 @@ def test_query_with_complex_composite_filter(query_docs, database): assert b_not_3 is True +@pytest.mark.parametrize( + "aggregation_type,aggregation_args,expected", + [("count", (), 3), ("sum", ("b"), 12), ("avg", ("b"), 4)], +) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_or_query_in_transaction(client, cleanup, database): +def test_aggregation_query_in_transaction( + client, cleanup, database, aggregation_type, aggregation_args, expected +): + """ + Test creating an aggregation query inside a transaction + Should send transaction id along with request. Results should be consistent with non-transactional query + """ collection_id = "doc-create" + UNIQUE_RESOURCE_ID - document_id_1 = "doc1" + UNIQUE_RESOURCE_ID - document_id_2 = "doc2" + UNIQUE_RESOURCE_ID + doc_ids = [f"doc{i}" + UNIQUE_RESOURCE_ID for i in range(4)] + doc_refs = [client.document(collection_id, doc_id) for doc_id in doc_ids] + for doc_ref in doc_refs: + cleanup(doc_ref.delete) + doc_refs[0].create({"a": 3, "b": 1}) + doc_refs[1].create({"a": 5, "b": 1}) + doc_refs[2].create({"a": 5, "b": 10}) + doc_refs[3].create({"a": 10, "b": 0}) # should be ignored by query - document_1 = client.document(collection_id, document_id_1) - document_2 = client.document(collection_id, document_id_2) + collection = client.collection(collection_id) + query = collection.where(filter=FieldFilter("b", ">", 0)) + aggregation_query = getattr(query, aggregation_type)(*aggregation_args) - cleanup(document_1.delete) - cleanup(document_2.delete) + with client.transaction() as transaction: + # should fail if transaction has not been initiated + with pytest.raises(ValueError): + aggregation_query.get(transaction=transaction) - document_1.create({"a": 1, "b": 2}) - document_2.create({"a": 1, "b": 1}) + # should work when transaction is initiated through transactional decorator + @firestore.transactional + def in_transaction(transaction): + global inner_fn_ran + result = aggregation_query.get(transaction=transaction) + assert len(result) == 1 + assert len(result[0]) == 1 + assert result[0][0].value == expected + inner_fn_ran = True - transaction = client.transaction() + in_transaction(transaction) + # make sure we didn't skip assertions in inner function + assert inner_fn_ran is True - with pytest.raises(ValueError) as exc: - create_in_transaction_helper( - transaction, client, collection_id, cleanup, database - ) - assert str(exc.value) == "Collection can't have more than 2 docs" - collection = client.collection(collection_id) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_or_query_in_transaction(client, cleanup, database): + """ + Test running or query inside a transaction. Should pass transaction id along with request + """ + collection_id = "doc-create" + UNIQUE_RESOURCE_ID + doc_ids = [f"doc{i}" + UNIQUE_RESOURCE_ID for i in range(5)] + doc_refs = [client.document(collection_id, doc_id) for doc_id in doc_ids] + for doc_ref in doc_refs: + cleanup(doc_ref.delete) + doc_refs[0].create({"a": 1, "b": 2}) + doc_refs[1].create({"a": 1, "b": 1}) + doc_refs[2].create({"a": 2, "b": 1}) # should be ignored by query + doc_refs[3].create({"a": 1, "b": 0}) # should be ignored by query + collection = client.collection(collection_id) query = collection.where(filter=FieldFilter("a", "==", 1)).where( filter=Or([FieldFilter("b", "==", 1), FieldFilter("b", "==", 2)]) ) - b_1 = False - b_2 = False - count = 0 - for result in query.stream(): - assert result.get("a") == 1 # assert a==1 is True in both results - assert result.get("b") == 1 or result.get("b") == 2 - if result.get("b") == 1: - b_1 = True - if result.get("b") == 2: - b_2 = True - count += 1 - - assert b_1 is True # assert one of them is b == 1 - assert b_2 is True # assert one of them is b == 2 - assert ( - count == 2 - ) # assert only 2 results, the third one was rolledback and not created + + with client.transaction() as transaction: + # should fail if transaction has not been initiated + with pytest.raises(ValueError): + query.get(transaction=transaction) + + # should work when transaction is initiated through transactional decorator + @firestore.transactional + def in_transaction(transaction): + global inner_fn_ran + result = query.get(transaction=transaction) + assert len(result) == 2 + # both documents should have a == 1 + assert result[0].get("a") == 1 + assert result[1].get("a") == 1 + # one document should have b == 1 and the other should have b == 2 + assert (result[0].get("b") == 1 and result[1].get("b") == 2) or ( + result[0].get("b") == 2 and result[1].get("b") == 1 + ) + inner_fn_ran = True + + in_transaction(transaction) + # make sure we didn't skip assertions in inner function + assert inner_fn_ran is True + + +@pytest.mark.parametrize("with_rollback,expected", [(True, 2), (False, 3)]) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_transaction_rollback(client, cleanup, database, with_rollback, expected): + """ + Create a document in a transaction that is rolled back + Document should not show up in later queries + """ + collection_id = "doc-create" + UNIQUE_RESOURCE_ID + doc_ids = [f"doc{i}" + UNIQUE_RESOURCE_ID for i in range(3)] + doc_refs = [client.document(collection_id, doc_id) for doc_id in doc_ids] + for doc_ref in doc_refs: + cleanup(doc_ref.delete) + doc_refs[0].create({"a": 1}) + doc_refs[1].create({"a": 1}) + doc_refs[2].create({"a": 2}) # should be ignored by query + + transaction = client.transaction() + + @firestore.transactional + def in_transaction(transaction, rollback): + """ + create a document in a transaction that is rolled back (raises an exception) + """ + new_document_id = "in_transaction_doc" + UNIQUE_RESOURCE_ID + new_document_ref = client.document(collection_id, new_document_id) + cleanup(new_document_ref.delete) + transaction.create(new_document_ref, {"a": 1}) + if rollback: + raise RuntimeError("rollback") + + if with_rollback: + # run transaction in function that results in a rollback + with pytest.raises(RuntimeError) as exc: + in_transaction(transaction, with_rollback) + assert str(exc.value) == "rollback" + else: + # no rollback expected + in_transaction(transaction, with_rollback) + + collection = client.collection(collection_id) + + query = collection.where(filter=FieldFilter("a", "==", 1)).count() + result = query.get() + assert len(result) == 1 + assert len(result[0]) == 1 + assert result[0][0].value == expected diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 3d75f61298..5201149167 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -609,11 +609,14 @@ async def query_docs(client): @pytest_asyncio.fixture -async def async_query(query_docs): - collection, stored, allowed_vals = query_docs - query = collection.where(filter=FieldFilter("a", "==", 1)) +async def collection(query_docs): + collection, _, _ = query_docs + yield collection - return query + +@pytest_asyncio.fixture +async def async_query(collection): + return collection.where(filter=FieldFilter("a", "==", 1)) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) @@ -1575,7 +1578,7 @@ async def test_async_count_query_get_empty_aggregation(async_query, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_count_async_query_stream_default_alias(async_query, database): +async def test_async_count_query_stream_default_alias(async_query, database): count_query = async_query.count() async for result in count_query.stream(): @@ -1642,6 +1645,201 @@ async def test_async_count_query_stream_empty_aggregation(async_query, database) assert "Aggregations can not be empty" in exc_info.value.message +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_get_default_alias(collection, database): + sum_query = collection.sum("stats.product") + result = await sum_query.get() + for r in result[0]: + assert r.alias == "field_1" + assert r.value == 100 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_get_with_alias(collection, database): + sum_query = collection.sum("stats.product", alias="total") + result = await sum_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 100 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_get_with_limit(collection, database): + sum_query = collection.sum("stats.product", alias="total") + result = await sum_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 100 + + # sum with limit + sum_query = collection.limit(12).sum("stats.product", alias="total") + result = await sum_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 5 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_get_multiple_aggregations(collection, database): + sum_query = collection.sum("stats.product", alias="total").sum( + "stats.product", alias="all" + ) + + result = await sum_query.get() + assert len(result[0]) == 2 + + expected_aliases = ["total", "all"] + found_alias = set( + [r.alias for r in result[0]] + ) # ensure unique elements in the result + assert len(found_alias) == 2 + assert found_alias == set(expected_aliases) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_stream_default_alias(collection, database): + sum_query = collection.sum("stats.product") + + async for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "field_1" + assert aggregation_result.value == 100 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_stream_with_alias(collection, database): + sum_query = collection.sum("stats.product", alias="total") + async for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_stream_with_limit(collection, database): + # sum without limit + sum_query = collection.sum("stats.product", alias="total") + async for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == 100 + + # sum with limit + sum_query = collection.limit(12).sum("stats.product", alias="total") + async for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == 5 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_stream_multiple_aggregations(collection, database): + sum_query = collection.sum("stats.product", alias="total").sum( + "stats.product", alias="all" + ) + + async for result in sum_query.stream(): + assert len(result) == 2 + for aggregation_result in result: + assert aggregation_result.alias in ["total", "all"] + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_get_default_alias(collection, database): + avg_query = collection.avg("stats.product") + result = await avg_query.get() + for r in result[0]: + assert r.alias == "field_1" + assert r.value == 4 + assert isinstance(r.value, float) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_get_with_alias(collection, database): + avg_query = collection.avg("stats.product", alias="total") + result = await avg_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 4 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_get_with_limit(collection, database): + avg_query = collection.avg("stats.product", alias="total") + result = await avg_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 4 + + # avg with limit + avg_query = collection.limit(12).avg("stats.product", alias="total") + result = await avg_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 5 / 12 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_get_multiple_aggregations(collection, database): + avg_query = collection.avg("stats.product", alias="total").avg( + "stats.product", alias="all" + ) + + result = await avg_query.get() + assert len(result[0]) == 2 + + expected_aliases = ["total", "all"] + found_alias = set( + [r.alias for r in result[0]] + ) # ensure unique elements in the result + assert len(found_alias) == 2 + assert found_alias == set(expected_aliases) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_stream_default_alias(collection, database): + avg_query = collection.avg("stats.product") + + async for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "field_1" + assert aggregation_result.value == 4.0 + assert isinstance(aggregation_result.value, float) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_stream_with_alias(collection, database): + avg_query = collection.avg("stats.product", alias="total") + async for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_stream_with_limit(collection, database): + # avg without limit + avg_query = collection.avg("stats.product", alias="total") + async for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == 4.0 + + # avg with limit + avg_query = collection.limit(12).avg("stats.product", alias="total") + async for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == 5 / 12 + assert isinstance(aggregation_result.value, float) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_stream_multiple_aggregations(collection, database): + avg_query = collection.avg("stats.product", alias="total").avg( + "stats.product", alias="all" + ) + + async for result in avg_query.stream(): + assert len(result) == 2 + for aggregation_result in result: + assert aggregation_result.alias in ["total", "all"] + + @firestore.async_transactional async def create_in_transaction_helper( transaction, client, collection_id, cleanup, database diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 7b07aa9afa..d19cf69e81 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -21,6 +21,8 @@ from google.cloud.firestore_v1.base_aggregation import ( CountAggregation, + SumAggregation, + AvgAggregation, AggregationResult, ) from tests.unit.v1._test_helpers import ( @@ -46,6 +48,58 @@ def test_count_aggregation_to_pb(): assert count_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_sum_aggregation_w_field_path(): + """ + SumAggregation should convert FieldPath inputs into strings + """ + from google.cloud.firestore_v1.field_path import FieldPath + + field_path = FieldPath("foo", "bar") + sum_aggregation = SumAggregation(field_path, alias="total") + assert sum_aggregation.field_ref == "foo.bar" + + +def test_avg_aggregation_w_field_path(): + """ + AvgAggregation should convert FieldPath inputs into strings + """ + from google.cloud.firestore_v1.field_path import FieldPath + + field_path = FieldPath("foo", "bar") + avg_aggregation = AvgAggregation(field_path, alias="total") + assert avg_aggregation.field_ref == "foo.bar" + + +def test_sum_aggregation_to_pb(): + from google.cloud.firestore_v1.types import query as query_pb2 + + sum_aggregation = SumAggregation("someref", alias="total") + + expected_aggregation_query_pb = query_pb2.StructuredAggregationQuery.Aggregation() + expected_aggregation_query_pb.sum = ( + query_pb2.StructuredAggregationQuery.Aggregation.Sum() + ) + expected_aggregation_query_pb.sum.field.field_path = "someref" + + expected_aggregation_query_pb.alias = sum_aggregation.alias + assert sum_aggregation._to_protobuf() == expected_aggregation_query_pb + + +def test_avg_aggregation_to_pb(): + from google.cloud.firestore_v1.types import query as query_pb2 + + avg_aggregation = AvgAggregation("someref", alias="total") + + expected_aggregation_query_pb = query_pb2.StructuredAggregationQuery.Aggregation() + expected_aggregation_query_pb.avg = ( + query_pb2.StructuredAggregationQuery.Aggregation.Avg() + ) + expected_aggregation_query_pb.avg.field.field_path = "someref" + expected_aggregation_query_pb.alias = avg_aggregation.alias + + assert avg_aggregation._to_protobuf() == expected_aggregation_query_pb + + def test_aggregation_query_constructor(): client = make_client() parent = client.collection("dee") @@ -64,11 +118,23 @@ def test_aggregation_query_add_aggregation(): query = make_query(parent) aggregation_query = make_aggregation_query(query) aggregation_query.add_aggregation(CountAggregation(alias="all")) + aggregation_query.add_aggregation(SumAggregation("sumref", alias="sum_all")) + aggregation_query.add_aggregation(AvgAggregation("avgref", alias="avg_all")) - assert len(aggregation_query._aggregations) == 1 + assert len(aggregation_query._aggregations) == 3 assert aggregation_query._aggregations[0].alias == "all" assert isinstance(aggregation_query._aggregations[0], CountAggregation) + assert len(aggregation_query._aggregations) == 3 + assert aggregation_query._aggregations[1].alias == "sum_all" + assert aggregation_query._aggregations[1].field_ref == "sumref" + assert isinstance(aggregation_query._aggregations[1], SumAggregation) + + assert len(aggregation_query._aggregations) == 3 + assert aggregation_query._aggregations[2].alias == "avg_all" + assert aggregation_query._aggregations[2].field_ref == "avgref" + assert isinstance(aggregation_query._aggregations[2], AvgAggregation) + def test_aggregation_query_add_aggregations(): client = make_client() @@ -77,15 +143,26 @@ def test_aggregation_query_add_aggregations(): aggregation_query = make_aggregation_query(query) aggregation_query.add_aggregations( - [CountAggregation(alias="all"), CountAggregation(alias="total")] + [ + CountAggregation(alias="all"), + CountAggregation(alias="total"), + SumAggregation("sumref", alias="sum_all"), + AvgAggregation("avgref", alias="avg_all"), + ] ) - assert len(aggregation_query._aggregations) == 2 + assert len(aggregation_query._aggregations) == 4 assert aggregation_query._aggregations[0].alias == "all" assert aggregation_query._aggregations[1].alias == "total" + assert aggregation_query._aggregations[2].alias == "sum_all" + assert aggregation_query._aggregations[2].field_ref == "sumref" + assert aggregation_query._aggregations[3].alias == "avg_all" + assert aggregation_query._aggregations[3].field_ref == "avgref" assert isinstance(aggregation_query._aggregations[0], CountAggregation) assert isinstance(aggregation_query._aggregations[1], CountAggregation) + assert isinstance(aggregation_query._aggregations[2], SumAggregation) + assert isinstance(aggregation_query._aggregations[3], AvgAggregation) def test_aggregation_query_count(): @@ -118,6 +195,102 @@ def test_aggregation_query_count_twice(): assert isinstance(aggregation_query._aggregations[1], CountAggregation) +def test_aggregation_query_sum(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.sum("someref", alias="all") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + + +def test_aggregation_query_sum_twice(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.sum("someref", alias="all").sum("another_ref", alias="total") + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + assert aggregation_query._aggregations[1].alias == "total" + assert aggregation_query._aggregations[1].field_ref == "another_ref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + assert isinstance(aggregation_query._aggregations[1], SumAggregation) + + +def test_aggregation_query_sum_no_alias(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.sum("someref") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias is None + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + + +def test_aggregation_query_avg(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.avg("someref", alias="all") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + + +def test_aggregation_query_avg_twice(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.avg("someref", alias="all").avg("another_ref", alias="total") + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + assert aggregation_query._aggregations[1].alias == "total" + assert aggregation_query._aggregations[1].field_ref == "another_ref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + assert isinstance(aggregation_query._aggregations[1], AvgAggregation) + + +def test_aggregation_query_avg_no_alias(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.avg("someref") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias is None + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + + def test_aggregation_query_to_protobuf(): client = make_client() parent = client.collection("dee") @@ -125,11 +298,15 @@ def test_aggregation_query_to_protobuf(): aggregation_query = make_aggregation_query(query) aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sumall") + aggregation_query.avg("anotherref", alias="avgall") pb = aggregation_query._to_protobuf() assert pb.structured_query == parent._query()._to_protobuf() - assert len(pb.aggregations) == 1 + assert len(pb.aggregations) == 3 assert pb.aggregations[0] == aggregation_query._aggregations[0]._to_protobuf() + assert pb.aggregations[1] == aggregation_query._aggregations[1]._to_protobuf() + assert pb.aggregations[2] == aggregation_query._aggregations[2]._to_protobuf() def test_aggregation_query_prep_stream(): @@ -139,6 +316,8 @@ def test_aggregation_query_prep_stream(): aggregation_query = make_aggregation_query(query) aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sumall") + aggregation_query.avg("anotherref", alias="avgall") request, kwargs = aggregation_query._prep_stream() @@ -163,6 +342,8 @@ def test_aggregation_query_prep_stream_with_transaction(): aggregation_query = make_aggregation_query(query) aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sumall") + aggregation_query.avg("anotherref", alias="avgall") request, kwargs = aggregation_query._prep_stream(transaction=transaction) @@ -194,6 +375,7 @@ def _aggregation_query_get_helper(retry=None, timeout=None, read_time=None): aggregation_query.count(alias="all") aggregation_result = AggregationResult(alias="total", value=5, read_time=read_time) + response_pb = make_aggregation_query_response( [aggregation_result], read_time=read_time ) @@ -446,31 +628,38 @@ def test_aggregation_from_query(): response_pb = make_aggregation_query_response( [aggregation_result], transaction=txn_id ) - firestore_api.run_aggregation_query.return_value = iter([response_pb]) retry = None timeout = None kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - aggregation_query = query.count(alias="total") - returned = aggregation_query.get(transaction=transaction, **kwargs) - assert isinstance(returned, list) - assert len(returned) == 1 - - for result in returned: - for r in result: - assert r.alias == aggregation_result.alias - assert r.value == aggregation_result.value - - # Verify the mock call. - parent_path, _ = parent._parent_info() - - firestore_api.run_aggregation_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_aggregation_query": aggregation_query._to_protobuf(), - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + for aggregation_query in [ + query.count(alias="total"), + query.sum("foo", alias="total"), + query.avg("foo", alias="total"), + ]: + # reset api mock + firestore_api.run_aggregation_query.reset_mock() + firestore_api.run_aggregation_query.return_value = iter([response_pb]) + # run query + returned = aggregation_query.get(transaction=transaction, **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + + for result in returned: + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + + # Verify the mock call. + parent_path, _ = parent._parent_info() + + firestore_api.run_aggregation_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index 711975535e..4ed97ddb98 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -19,6 +19,8 @@ from google.cloud.firestore_v1.base_aggregation import ( CountAggregation, + SumAggregation, + AvgAggregation, AggregationResult, ) @@ -54,11 +56,22 @@ def test_async_aggregation_query_add_aggregation(): aggregation_query = make_async_aggregation_query(query) aggregation_query.add_aggregation(CountAggregation(alias="all")) + aggregation_query.add_aggregation(SumAggregation("someref", alias="sum_all")) + aggregation_query.add_aggregation(AvgAggregation("otherref", alias="avg_all")) + + assert len(aggregation_query._aggregations) == 3 - assert len(aggregation_query._aggregations) == 1 assert aggregation_query._aggregations[0].alias == "all" assert isinstance(aggregation_query._aggregations[0], CountAggregation) + assert aggregation_query._aggregations[1].field_ref == "someref" + assert aggregation_query._aggregations[1].alias == "sum_all" + assert isinstance(aggregation_query._aggregations[1], SumAggregation) + + assert aggregation_query._aggregations[2].field_ref == "otherref" + assert aggregation_query._aggregations[2].alias == "avg_all" + assert isinstance(aggregation_query._aggregations[2], AvgAggregation) + def test_async_aggregation_query_add_aggregations(): client = make_async_client() @@ -67,15 +80,28 @@ def test_async_aggregation_query_add_aggregations(): aggregation_query = make_async_aggregation_query(query) aggregation_query.add_aggregations( - [CountAggregation(alias="all"), CountAggregation(alias="total")] + [ + CountAggregation(alias="all"), + CountAggregation(alias="total"), + SumAggregation("someref", alias="sum_all"), + AvgAggregation("otherref", alias="avg_all"), + ] ) - assert len(aggregation_query._aggregations) == 2 + assert len(aggregation_query._aggregations) == 4 assert aggregation_query._aggregations[0].alias == "all" assert aggregation_query._aggregations[1].alias == "total" + assert aggregation_query._aggregations[2].field_ref == "someref" + assert aggregation_query._aggregations[2].alias == "sum_all" + + assert aggregation_query._aggregations[3].field_ref == "otherref" + assert aggregation_query._aggregations[3].alias == "avg_all" + assert isinstance(aggregation_query._aggregations[0], CountAggregation) assert isinstance(aggregation_query._aggregations[1], CountAggregation) + assert isinstance(aggregation_query._aggregations[2], SumAggregation) + assert isinstance(aggregation_query._aggregations[3], AvgAggregation) def test_async_aggregation_query_count(): @@ -108,6 +134,104 @@ def test_async_aggregation_query_count_twice(): assert isinstance(aggregation_query._aggregations[1], CountAggregation) +def test_async_aggregation_sum(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.sum("someref", alias="sum_all") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "sum_all" + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + + +def test_async_aggregation_query_sum_twice(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.sum("someref", alias="sum_all").sum( + "another_ref", alias="sum_total" + ) + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "sum_all" + assert aggregation_query._aggregations[0].field_ref == "someref" + assert aggregation_query._aggregations[1].alias == "sum_total" + assert aggregation_query._aggregations[1].field_ref == "another_ref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + assert isinstance(aggregation_query._aggregations[1], SumAggregation) + + +def test_async_aggregation_sum_no_alias(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.sum("someref") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias is None + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + + +def test_aggregation_query_avg(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.avg("someref", alias="all") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + + +def test_aggregation_query_avg_twice(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.avg("someref", alias="all").avg("another_ref", alias="total") + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + assert aggregation_query._aggregations[1].alias == "total" + assert aggregation_query._aggregations[1].field_ref == "another_ref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + assert isinstance(aggregation_query._aggregations[1], AvgAggregation) + + +def test_aggregation_query_avg_no_alias(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.avg("someref") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias is None + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + + def test_async_aggregation_query_to_protobuf(): client = make_async_client() parent = client.collection("dee") @@ -115,11 +239,15 @@ def test_async_aggregation_query_to_protobuf(): aggregation_query = make_async_aggregation_query(query) aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sum_all") + aggregation_query.avg("someref", alias="avg_all") pb = aggregation_query._to_protobuf() assert pb.structured_query == parent._query()._to_protobuf() - assert len(pb.aggregations) == 1 + assert len(pb.aggregations) == 3 assert pb.aggregations[0] == aggregation_query._aggregations[0]._to_protobuf() + assert pb.aggregations[1] == aggregation_query._aggregations[1]._to_protobuf() + assert pb.aggregations[2] == aggregation_query._aggregations[2]._to_protobuf() def test_async_aggregation_query_prep_stream(): @@ -129,7 +257,8 @@ def test_async_aggregation_query_prep_stream(): aggregation_query = make_async_aggregation_query(query) aggregation_query.count(alias="all") - + aggregation_query.sum("someref", alias="sum_all") + aggregation_query.avg("someref", alias="avg_all") request, kwargs = aggregation_query._prep_stream() parent_path, _ = parent._parent_info() @@ -152,6 +281,8 @@ def test_async_aggregation_query_prep_stream_with_transaction(): query = make_async_query(parent) aggregation_query = make_async_aggregation_query(query) aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sum_all") + aggregation_query.avg("someref", alias="avg_all") request, kwargs = aggregation_query._prep_stream(transaction=transaction) @@ -318,31 +449,38 @@ async def test_async_aggregation_from_query(): response_pb = make_aggregation_query_response( [aggregation_result], transaction=txn_id ) - firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb]) retry = None timeout = None kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - # Execute the query and check the response. - aggregation_query = query.count(alias="total") - returned = await aggregation_query.get(transaction=transaction, **kwargs) - assert isinstance(returned, list) - assert len(returned) == 1 - - for result in returned: - for r in result: - assert r.alias == aggregation_result.alias - assert r.value == aggregation_result.value - - # Verify the mock call. - parent_path, _ = parent._parent_info() - - firestore_api.run_aggregation_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_aggregation_query": aggregation_query._to_protobuf(), - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + # Execute each aggregation query type and check the response. + for aggregation_query in [ + query.count(alias="total"), + query.sum("foo", alias="total"), + query.avg("foo", alias="total"), + ]: + # reset api mock + firestore_api.run_aggregation_query.reset_mock() + firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb]) + # run query + returned = await aggregation_query.get(transaction=transaction, **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + + for result in returned: + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + + # Verify the mock call. + parent_path, _ = parent._parent_info() + + firestore_api.run_aggregation_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 0599937cca..c5bce0ae8d 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -97,6 +97,36 @@ def test_async_collection_count(): assert aggregation_query._aggregations[0].alias == alias +def test_async_collection_sum(): + firestore_api = AsyncMock(spec=["create_document", "commit"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + collection = _make_async_collection_reference("grand-parent", client=client) + + alias = "total" + field_ref = "someref" + aggregation_query = collection.sum(field_ref, alias=alias) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == alias + assert aggregation_query._aggregations[0].field_ref == field_ref + + +def test_async_collection_avg(): + firestore_api = AsyncMock(spec=["create_document", "commit"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + collection = _make_async_collection_reference("grand-parent", client=client) + + alias = "total" + field_ref = "someref" + aggregation_query = collection.avg(field_ref, alias=alias) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == alias + assert aggregation_query._aggregations[0].field_ref == field_ref + + @pytest.mark.asyncio async def test_asynccollectionreference_add_auto_assigned(): from google.cloud.firestore_v1.types import document diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index b74a215c3f..c0f3d0d9ed 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -160,6 +160,64 @@ async def test_asyncquery_get_limit_to_last(): ) +def test_asyncquery_sum(): + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.base_aggregation import SumAggregation + + client = make_async_client() + parent = client.collection("dee") + field_str = "field_str" + field_path = FieldPath("foo", "bar") + query = make_async_query(parent) + # test with only field populated + sum_query = query.sum(field_str) + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == field_str + assert sum_agg.alias is None + # test with field and alias populated + sum_query = query.sum(field_str, alias="alias") + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == field_str + assert sum_agg.alias == "alias" + # test with field_path + sum_query = query.sum(field_path, alias="alias") + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == "foo.bar" + assert sum_agg.alias == "alias" + + +def test_asyncquery_avg(): + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.base_aggregation import AvgAggregation + + client = make_async_client() + parent = client.collection("dee") + field_str = "field_str" + field_path = FieldPath("foo", "bar") + query = make_async_query(parent) + # test with only field populated + avg_query = query.avg(field_str) + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == field_str + assert avg_agg.alias is None + # test with field and alias populated + avg_query = query.avg(field_str, alias="alias") + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == field_str + assert avg_agg.alias == "alias" + # test with field_path + avg_query = query.avg(field_path, alias="alias") + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == "foo.bar" + assert avg_agg.alias == "alias" + + @pytest.mark.asyncio async def test_asyncquery_chunkify_w_empty(): client = make_async_client() diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 39c0df237d..f3bc099b97 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -81,6 +81,44 @@ def test_collection_count(): assert aggregation_query._aggregations[0].alias == alias +def test_collection_sum(): + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = _make_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + + alias = "total" + field_ref = "someref" + aggregation_query = collection.sum(field_ref, alias=alias) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == alias + assert aggregation_query._aggregations[0].field_ref == field_ref + + +def test_collection_avg(): + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = _make_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + + alias = "total" + field_ref = "someref" + aggregation_query = collection.avg(field_ref, alias=alias) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == alias + assert aggregation_query._aggregations[0].field_ref == field_ref + + def test_constructor(): collection_id1 = "rooms" document_id = "roomA" diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index ad972aa763..a7f2e60162 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -152,6 +152,66 @@ def test_query_get_limit_to_last(database): ) +@pytest.mark.parametrize("database", [None, "somedb"]) +def test_query_sum(database): + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.base_aggregation import SumAggregation + + client = make_client(database=database) + parent = client.collection("dee") + field_str = "field_str" + field_path = FieldPath("foo", "bar") + query = make_query(parent) + # test with only field populated + sum_query = query.sum(field_str) + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == field_str + assert sum_agg.alias is None + # test with field and alias populated + sum_query = query.sum(field_str, alias="alias") + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == field_str + assert sum_agg.alias == "alias" + # test with field_path + sum_query = query.sum(field_path, alias="alias") + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == "foo.bar" + assert sum_agg.alias == "alias" + + +@pytest.mark.parametrize("database", [None, "somedb"]) +def test_query_avg(database): + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.base_aggregation import AvgAggregation + + client = make_client(database=database) + parent = client.collection("dee") + field_str = "field_str" + field_path = FieldPath("foo", "bar") + query = make_query(parent) + # test with only field populated + avg_query = query.avg(field_str) + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == field_str + assert avg_agg.alias is None + # test with field and alias populated + avg_query = query.avg(field_str, alias="alias") + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == field_str + assert avg_agg.alias == "alias" + # test with field_path + avg_query = query.avg(field_path, alias="alias") + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == "foo.bar" + assert avg_agg.alias == "alias" + + @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_chunkify_w_empty(database): client = make_client(database=database)