Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Sum/Avg aggregation queries #715

Merged
merged 40 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
608464a
Feat: Sum/Avg Feature
Mariatta May 10, 2023
7e4d335
Merge branch 'main' into feat-sum-avg-pr
daniel-sanche Aug 16, 2023
f1322b5
fixed proto sum attribute name
daniel-sanche Aug 18, 2023
ed247b9
added query tests with alias unset
daniel-sanche Aug 18, 2023
b3d1d7a
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Aug 18, 2023
eaac103
added async tests
daniel-sanche Aug 18, 2023
19c2c4b
added missing decorators
daniel-sanche Aug 21, 2023
20ad46e
fixed wrong expected values in tests
daniel-sanche Aug 21, 2023
c2a804d
fixed empty avg aggregations
daniel-sanche Aug 21, 2023
59a02c2
ran blacken
daniel-sanche Aug 21, 2023
91dbc48
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Aug 21, 2023
3901dd7
Merge branch 'feat-sum-avg-pr' of https://github.com/googleapis/pytho…
gcf-owl-bot[bot] Aug 21, 2023
dff8459
aggregation test should cover all aggregations
daniel-sanche Aug 24, 2023
3e66842
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Aug 24, 2023
ed06271
fixed async test
daniel-sanche Aug 24, 2023
47d45da
improved transaction tests
daniel-sanche Aug 30, 2023
90b006c
cleaned up new tests
daniel-sanche Aug 31, 2023
5f00661
removed test logic that belongs in unit tests
daniel-sanche Aug 31, 2023
2340fea
ran blacken
daniel-sanche Aug 31, 2023
63f252a
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Aug 31, 2023
109e8b8
Merge branch 'feat-sum-avg-pr' of https://github.com/googleapis/pytho…
gcf-owl-bot[bot] Aug 31, 2023
85c8dbd
reverted removed line
daniel-sanche Sep 6, 2023
93ff4bb
Merge branch 'main' into feat-sum-avg-pr
kolea2 Oct 3, 2023
e60aa75
fix docstrings
daniel-sanche Oct 3, 2023
6a0950a
accept FieldPath for aggregations
daniel-sanche Oct 3, 2023
e03871e
fixed docstrings
daniel-sanche Oct 3, 2023
bb9e8bb
made test changes to avoid index requirements
daniel-sanche Oct 4, 2023
a488a6d
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Oct 4, 2023
b258fde
fixed lint issues
daniel-sanche Oct 4, 2023
6d3b154
added field path to collections
daniel-sanche Oct 5, 2023
4b6de1a
Merge branch 'main' into feat-sum-avg-pr
daniel-sanche Oct 10, 2023
7b40aa7
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Oct 10, 2023
c06aa19
fixed docs issue
daniel-sanche Oct 10, 2023
176ce2d
added tests with start_at
daniel-sanche Oct 10, 2023
1b16357
add no cover marks to TYPE_CHECKING
daniel-sanche Oct 11, 2023
d42bf07
Merge branch 'main' into feat-sum-avg-pr
daniel-sanche Oct 19, 2023
0e969a6
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Oct 19, 2023
82bde42
skip cursor aggregations
daniel-sanche Oct 19, 2023
9a10dc2
import query type
daniel-sanche Oct 19, 2023
3bb4bf5
fixed no cover comments
daniel-sanche Oct 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@
)

from google.cloud.firestore_v1 import async_document
from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from typing import AsyncGenerator, List, Optional, Type

# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction
from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING

from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery
if TYPE_CHECKING: # pragma: no cover
# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction
from google.cloud.firestore_v1.field_path import FieldPath


class AsyncQuery(BaseQuery):
Expand Down Expand Up @@ -222,15 +223,47 @@ def count(
"""Adds a count over the nested query.

Args:
alias
(Optional[str]): The alias for the count
alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.

Returns:
:class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
An instance of an AsyncAggregationQuery object
"""
return AsyncAggregationQuery(self).count(alias=alias)

def sum(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Although I used the same constructor with an additional accepted type, which is how Python handles this kind of thing

self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]:
"""Adds a sum over the nested query.

Args:
field_ref(Union[str, google.cloud.firestore_v1.field_path.FieldPath]): The field to aggregate across.
alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.

Returns:
:class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
An instance of an AsyncAggregationQuery object
"""
return AsyncAggregationQuery(self).sum(field_ref, alias=alias)

def avg(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]:
"""Adds an avg over the nested query.

Args:
field_ref(Union[str, google.cloud.firestore_v1.field_path.FieldPath]): The field to aggregate across.
alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.

Returns:
:class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
An instance of an AsyncAggregationQuery object
"""
return AsyncAggregationQuery(self).avg(field_ref, alias=alias)

async def stream(
self,
transaction=None,
Expand Down
66 changes: 59 additions & 7 deletions google/cloud/firestore_v1/base_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from google.api_core import retry as retries


from google.cloud.firestore_v1.field_path import FieldPath
from google.cloud.firestore_v1.types import RunAggregationQueryResponse

from google.cloud.firestore_v1.types import StructuredAggregationQuery
from google.cloud.firestore_v1 import _helpers

Expand All @@ -60,14 +60,17 @@ def __repr__(self):


class BaseAggregation(ABC):
def __init__(self, alias: str | None = None):
self.alias = alias

@abc.abstractmethod
def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""


class CountAggregation(BaseAggregation):
def __init__(self, alias: str | None = None):
self.alias = alias
super(CountAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
Expand All @@ -77,13 +80,48 @@ def _to_protobuf(self):
return aggregation_pb


class SumAggregation(BaseAggregation):
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
if isinstance(field_ref, FieldPath):
# convert field path to string
field_ref = field_ref.to_api_repr()
self.field_ref = field_ref
super(SumAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
aggregation_pb = StructuredAggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum()
aggregation_pb.sum.field.field_path = self.field_ref
return aggregation_pb


class AvgAggregation(BaseAggregation):
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
if isinstance(field_ref, FieldPath):
# convert field path to string
field_ref = field_ref.to_api_repr()
self.field_ref = field_ref
super(AvgAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
aggregation_pb = StructuredAggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg()
aggregation_pb.avg.field.field_path = self.field_ref
return aggregation_pb


def _query_response_to_result(
response_pb: RunAggregationQueryResponse,
) -> List[AggregationResult]:
results = [
AggregationResult(
alias=key,
value=response_pb.result.aggregate_fields[key].integer_value,
value=response_pb.result.aggregate_fields[key].integer_value
or response_pb.result.aggregate_fields[key].double_value,
read_time=response_pb.read_time,
)
for key in response_pb.result.aggregate_fields.pb.keys()
Expand All @@ -95,11 +133,9 @@ def _query_response_to_result(
class BaseAggregationQuery(ABC):
"""Represents an aggregation query to the Firestore API."""

def __init__(
self,
nested_query,
) -> None:
def __init__(self, nested_query, alias: str | None = None) -> None:
self._nested_query = nested_query
self._alias = alias
self._collection_ref = nested_query._parent
self._aggregations: List[BaseAggregation] = []

Expand All @@ -115,6 +151,22 @@ def count(self, alias: str | None = None):
self._aggregations.append(count_aggregation)
return self

def sum(self, field_ref: str | FieldPath, alias: str | None = None):
"""
Adds a sum over the nested query
"""
sum_aggregation = SumAggregation(field_ref, alias=alias)
self._aggregations.append(sum_aggregation)
return self

def avg(self, field_ref: str | FieldPath, alias: str | None = None):
"""
Adds an avg over the nested query
"""
avg_aggregation = AvgAggregation(field_ref, alias=alias)
self._aggregations.append(avg_aggregation)
return self

def add_aggregation(self, aggregation: BaseAggregation) -> None:
"""
Adds an aggregation operation to the nested query
Expand Down
42 changes: 37 additions & 5 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Classes for representing collections for the Google Cloud Firestore API."""
from __future__ import annotations
import random

from google.api_core import retry as retries
Expand All @@ -34,12 +35,16 @@
NoReturn,
Tuple,
Union,
TYPE_CHECKING,
)

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_query import BaseQuery
from google.cloud.firestore_v1.transaction import Transaction
if TYPE_CHECKING: # pragma: no cover
# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_query import BaseQuery
from google.cloud.firestore_v1.transaction import Transaction
from google.cloud.firestore_v1.field_path import FieldPath


_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

Expand Down Expand Up @@ -243,7 +248,7 @@ def where(
op_string: Optional[str] = None,
value=None,
*,
filter=None
filter=None,
) -> BaseQuery:
"""Create a "where" query with this collection as parent.

Expand Down Expand Up @@ -506,6 +511,33 @@ def count(self, alias=None):
"""
return self._aggregation_query().count(alias=alias)

def sum(self, field_ref: str | FieldPath, alias=None):
"""
Adds a sum over the nested query.

:type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath]
:param field_ref: The field to aggregate across.

:type alias: Optional[str]
:param alias: Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.

"""
return self._aggregation_query().sum(field_ref, alias=alias)

def avg(self, field_ref: str | FieldPath, alias=None):
"""
Adds an avg over the nested query.

:type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath]
:param field_ref: The field to aggregate across.

:type alias: Optional[str]
:param alias: Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
"""
return self._aggregation_query().avg(field_ref, alias=alias)


def _auto_id() -> str:
"""Generate a "random" automatically generated ID.
Expand Down
14 changes: 14 additions & 0 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,15 @@
Tuple,
Type,
Union,
TYPE_CHECKING,
)

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot

if TYPE_CHECKING: # pragma: no cover
from google.cloud.firestore_v1.field_path import FieldPath

_BAD_DIR_STRING: str
_BAD_OP_NAN_NULL: str
_BAD_OP_STRING: str
Expand Down Expand Up @@ -961,6 +965,16 @@ def count(
) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
raise NotImplementedError

def sum(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
raise NotImplementedError

def avg(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
raise NotImplementedError

def get(
self,
transaction=None,
Expand Down
40 changes: 37 additions & 3 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@

from google.cloud.firestore_v1 import document
from google.cloud.firestore_v1.watch import Watch
from typing import Any, Callable, Generator, List, Optional, Type
from typing import Any, Callable, Generator, List, Optional, Type, TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
from google.cloud.firestore_v1.field_path import FieldPath


class Query(BaseQuery):
Expand Down Expand Up @@ -242,11 +245,42 @@ def count(
"""
Adds a count over the query.

:type alias: str
:param alias: (Optional) The alias for the count
:type alias: Optional[str]
:param alias: Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
"""
return aggregation.AggregationQuery(self).count(alias=alias)

def sum(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.aggregation.AggregationQuery"]:
"""
Adds a sum over the query.

:type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath]
:param field_ref: The field to aggregate across.

:type alias: Optional[str]
:param alias: Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
"""
return aggregation.AggregationQuery(self).sum(field_ref, alias=alias)

def avg(
self, field_ref: str | FieldPath, alias: str | None = None
) -> Type["firestore_v1.aggregation.AggregationQuery"]:
"""
Adds an avg over the query.

:type field_ref: [Union[str, google.cloud.firestore_v1.field_path.FieldPath]
:param field_ref: The field to aggregate across.

:type alias: Optional[str]
:param alias: Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
"""
return aggregation.AggregationQuery(self).avg(field_ref, alias=alias)

def stream(
self,
transaction=None,
Expand Down
Loading
Loading