Skip to content

Commit

Permalink
feat: Sum/Avg aggregation queries (#715)
Browse files Browse the repository at this point in the history
* Feat: Sum/Avg Feature

Adds the ability to perform sum/avg aggregation query through:
- query.sum(),
- query.avg(),
- async_query.sum(),
- async_query.avg()

* fixed proto sum attribute name

* added query tests with alias unset

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* added async tests

* added missing decorators

* fixed wrong expected values in tests

* fixed empty avg aggregations

* ran blacken

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* aggregation test should cover all aggregations

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* fixed async test

* improved transaction tests

* cleaned up new tests

* removed test logic that belongs in unit tests

* ran blacken

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* reverted removed line

* fix docstrings

* accept FieldPath for aggregations

* fixed docstrings

* made test changes to avoid index requirements

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* fixed lint issues

* added field path to collections

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* fixed docs issue

* added tests with start_at

* add no cover marks to TYPE_CHECKING

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* skip cursor aggregations

* import query type

* fixed no cover comments

---------

Co-authored-by: Daniel Sanche <[email protected]>
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
Co-authored-by: kolea2 <[email protected]>
Co-authored-by: Daniel Sanche <[email protected]>
  • Loading branch information
5 people authored Oct 19, 2023
1 parent ae1247b commit 443475b
Show file tree
Hide file tree
Showing 13 changed files with 1,350 additions and 174 deletions.
47 changes: 40 additions & 7 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@
)

from google.cloud.firestore_v1 import async_document
from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from typing import AsyncGenerator, List, Optional, Type

# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction
from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING

from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery
if TYPE_CHECKING: # pragma: NO COVER
# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction
from google.cloud.firestore_v1.field_path import FieldPath


class AsyncQuery(BaseQuery):
Expand Down Expand Up @@ -222,15 +223,47 @@ def count(
"""Adds a count over the nested query.
Args:
alias
(Optional[str]): The alias for the count
alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
Returns:
:class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
An instance of an AsyncAggregationQuery object
"""
return AsyncAggregationQuery(self).count(alias=alias)

def sum(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]:
"""Adds a sum over the nested query.
Args:
field_ref(Union[str, google.cloud.firestore_v1.field_path.FieldPath]): The field to aggregate across.
alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
Returns:
:class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
An instance of an AsyncAggregationQuery object
"""
return AsyncAggregationQuery(self).sum(field_ref, alias=alias)

def avg(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]:
"""Adds an avg over the nested query.
Args:
field_ref(Union[str, google.cloud.firestore_v1.field_path.FieldPath]): The field to aggregate across.
alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
Returns:
:class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
An instance of an AsyncAggregationQuery object
"""
return AsyncAggregationQuery(self).avg(field_ref, alias=alias)

async def stream(
self,
transaction=None,
Expand Down
66 changes: 59 additions & 7 deletions google/cloud/firestore_v1/base_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from google.api_core import retry as retries


from google.cloud.firestore_v1.field_path import FieldPath
from google.cloud.firestore_v1.types import RunAggregationQueryResponse

from google.cloud.firestore_v1.types import StructuredAggregationQuery
from google.cloud.firestore_v1 import _helpers

Expand All @@ -60,14 +60,17 @@ def __repr__(self):


class BaseAggregation(ABC):
def __init__(self, alias: str | None = None):
self.alias = alias

@abc.abstractmethod
def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""


class CountAggregation(BaseAggregation):
def __init__(self, alias: str | None = None):
self.alias = alias
super(CountAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
Expand All @@ -77,13 +80,48 @@ def _to_protobuf(self):
return aggregation_pb


class SumAggregation(BaseAggregation):
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
if isinstance(field_ref, FieldPath):
# convert field path to string
field_ref = field_ref.to_api_repr()
self.field_ref = field_ref
super(SumAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
aggregation_pb = StructuredAggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum()
aggregation_pb.sum.field.field_path = self.field_ref
return aggregation_pb


class AvgAggregation(BaseAggregation):
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
if isinstance(field_ref, FieldPath):
# convert field path to string
field_ref = field_ref.to_api_repr()
self.field_ref = field_ref
super(AvgAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
aggregation_pb = StructuredAggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg()
aggregation_pb.avg.field.field_path = self.field_ref
return aggregation_pb


def _query_response_to_result(
response_pb: RunAggregationQueryResponse,
) -> List[AggregationResult]:
results = [
AggregationResult(
alias=key,
value=response_pb.result.aggregate_fields[key].integer_value,
value=response_pb.result.aggregate_fields[key].integer_value
or response_pb.result.aggregate_fields[key].double_value,
read_time=response_pb.read_time,
)
for key in response_pb.result.aggregate_fields.pb.keys()
Expand All @@ -95,11 +133,9 @@ def _query_response_to_result(
class BaseAggregationQuery(ABC):
"""Represents an aggregation query to the Firestore API."""

def __init__(
self,
nested_query,
) -> None:
def __init__(self, nested_query, alias: str | None = None) -> None:
self._nested_query = nested_query
self._alias = alias
self._collection_ref = nested_query._parent
self._aggregations: List[BaseAggregation] = []

Expand All @@ -115,6 +151,22 @@ def count(self, alias: str | None = None):
self._aggregations.append(count_aggregation)
return self

def sum(self, field_ref: str | FieldPath, alias: str | None = None):
"""
Adds a sum over the nested query
"""
sum_aggregation = SumAggregation(field_ref, alias=alias)
self._aggregations.append(sum_aggregation)
return self

def avg(self, field_ref: str | FieldPath, alias: str | None = None):
"""
Adds an avg over the nested query
"""
avg_aggregation = AvgAggregation(field_ref, alias=alias)
self._aggregations.append(avg_aggregation)
return self

def add_aggregation(self, aggregation: BaseAggregation) -> None:
"""
Adds an aggregation operation to the nested query
Expand Down
42 changes: 37 additions & 5 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# limitations under the License.

"""Classes for representing collections for the Google Cloud Firestore API."""
from __future__ import annotations
import random

from google.api_core import retry as retries

from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.document import DocumentReference
from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
from google.cloud.firestore_v1.base_query import QueryType


from typing import (
Expand All @@ -35,12 +37,15 @@
NoReturn,
Tuple,
Union,
TYPE_CHECKING,
)

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_query import QueryType
from google.cloud.firestore_v1.transaction import Transaction

if TYPE_CHECKING: # pragma: NO COVER
# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.transaction import Transaction
from google.cloud.firestore_v1.field_path import FieldPath

_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

Expand Down Expand Up @@ -244,7 +249,7 @@ def where(
op_string: Optional[str] = None,
value=None,
*,
filter=None
filter=None,
) -> QueryType:
"""Create a "where" query with this collection as parent.
Expand Down Expand Up @@ -507,6 +512,33 @@ def count(self, alias=None):
"""
return self._aggregation_query().count(alias=alias)

def sum(self, field_ref: str | FieldPath, alias=None):
"""
Adds a sum over the nested query.
:type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath]
:param field_ref: The field to aggregate across.
:type alias: Optional[str]
:param alias: Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
"""
return self._aggregation_query().sum(field_ref, alias=alias)

def avg(self, field_ref: str | FieldPath, alias=None):
"""
Adds an avg over the nested query.
:type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath]
:param field_ref: The field to aggregate across.
:type alias: Optional[str]
:param alias: Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
"""
return self._aggregation_query().avg(field_ref, alias=alias)


def _auto_id() -> str:
"""Generate a "random" automatically generated ID.
Expand Down
14 changes: 14 additions & 0 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,15 @@
Type,
TypeVar,
Union,
TYPE_CHECKING,
)

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.field_path import FieldPath

_BAD_DIR_STRING: str
_BAD_OP_NAN_NULL: str
_BAD_OP_STRING: str
Expand Down Expand Up @@ -970,6 +974,16 @@ def count(
) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
raise NotImplementedError

def sum(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
raise NotImplementedError

def avg(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
raise NotImplementedError

def get(
self,
transaction=None,
Expand Down
40 changes: 37 additions & 3 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@

from google.cloud.firestore_v1 import document
from google.cloud.firestore_v1.watch import Watch
from typing import Any, Callable, Generator, List, Optional, Type
from typing import Any, Callable, Generator, List, Optional, Type, TYPE_CHECKING

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.field_path import FieldPath


class Query(BaseQuery):
Expand Down Expand Up @@ -242,11 +245,42 @@ def count(
"""
Adds a count over the query.
:type alias: str
:param alias: (Optional) The alias for the count
:type alias: Optional[str]
:param alias: Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
"""
return aggregation.AggregationQuery(self).count(alias=alias)

def sum(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.aggregation.AggregationQuery"]:
"""
Adds a sum over the query.
:type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath]
:param field_ref: The field to aggregate across.
:type alias: Optional[str]
:param alias: Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
"""
return aggregation.AggregationQuery(self).sum(field_ref, alias=alias)

def avg(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.aggregation.AggregationQuery"]:
"""
Adds an avg over the query.
:type field_ref: [Union[str, google.cloud.firestore_v1.field_path.FieldPath]
:param field_ref: The field to aggregate across.
:type alias: Optional[str]
:param alias: Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
"""
return aggregation.AggregationQuery(self).avg(field_ref, alias=alias)

def stream(
self,
transaction=None,
Expand Down
Loading

0 comments on commit 443475b

Please sign in to comment.