Skip to content

Commit

Permalink
async aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
Linchin committed Sep 9, 2024
1 parent edd655a commit b56e43c
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 27 deletions.
61 changes: 51 additions & 10 deletions google/cloud/firestore_v1/async_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

from google.api_core import gapic_v1
from google.api_core import retry_async as retries
Expand All @@ -32,9 +32,12 @@
BaseAggregationQuery,
_query_response_to_result,
)
from google.cloud.firestore_v1.query_results import QueryResultsList

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_aggregation import AggregationResult
from google.cloud.firestore_v1.query_profile import ExplainMetrics, ExplainOptions
import google.cloud.firestore_v1.types.query_profile as query_profile_pb


class AsyncAggregationQuery(BaseAggregationQuery):
Expand All @@ -53,7 +56,9 @@ async def get(
retries.AsyncRetry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: float | None = None,
) -> List[List[AggregationResult]]:
*,
explain_options: Optional[ExplainOptions] = None,
) -> QueryResultsList[List[AggregationResult]]:
"""Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages.
Expand All @@ -69,23 +74,41 @@ async def get(
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Returns:
List[List[AggregationResult]]: The aggregation query results
QueryResultsList[List[AggregationResult]]: The aggregation query results.
"""
explain_metrics: ExplainMetrics | None = None

stream_result = self.stream(
transaction=transaction, retry=retry, timeout=timeout
transaction=transaction,
retry=retry,
timeout=timeout,
explain_options=explain_options,
)
result = [aggregation async for aggregation in stream_result]
return result # type: ignore

if explain_options is None:
explain_metrics = None
else:
explain_metrics = await stream_result.get_explain_metrics()

return QueryResultsList(result, explain_options, explain_metrics)

async def _make_stream(
self,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> Union[AsyncGenerator[List[AggregationResult], None]]:
explain_options: Optional[ExplainOptions] = None,
) -> AsyncStreamGenerator[
List[AggregationResult] | query_profile_pb.ExplainMetrics
]:
"""Internal method for stream(). Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and then returns a generator which
Expand All @@ -105,15 +128,22 @@ async def _make_stream(
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Yields:
:class:`~google.cloud.firestore_v1.base_aggregation.AggregationResult`:
List[AggregationResult] | query_profile_pb.ExplainMetrics:
The result of aggregations of this query
"""
metrics: query_profile_pb.ExplainMetrics | None = None

request, kwargs = self._prep_stream(
transaction,
retry,
timeout,
explain_options,
)

response_iterator = await self._client._firestore_api.run_aggregation_query(
Expand All @@ -126,12 +156,18 @@ async def _make_stream(
result = _query_response_to_result(response)
yield result

if metrics is None and response.explain_metrics:
metrics = response.explain_metrics
yield metrics

def stream(
self,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> "AsyncStreamGenerator[DocumentSnapshot]":
*,
explain_options: Optional[ExplainOptions] = None,
) -> AsyncStreamGenerator[List[AggregationResult]]:
"""Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and then returns a generator
Expand All @@ -150,15 +186,20 @@ def stream(
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Returns:
`AsyncStreamGenerator[DocumentSnapshot]`:
`AsyncStreamGenerator[List[AggregationResult]]`:
A generator of the query results.
"""

inner_generator = self._make_stream(
transaction=transaction,
retry=retry,
timeout=timeout,
explain_options=explain_options,
)
return AsyncStreamGenerator(inner_generator)
188 changes: 171 additions & 17 deletions tests/unit/v1/test_async_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,7 @@
# limitations under the License.

from datetime import datetime, timedelta, timezone

import pytest

from google.cloud.firestore_v1.base_aggregation import (
AggregationResult,
AvgAggregation,
CountAggregation,
SumAggregation,
)
from tests.unit.v1._test_helpers import (
make_aggregation_query_response,
make_async_aggregation_query,
Expand All @@ -30,6 +22,17 @@
)
from tests.unit.v1.test__helpers import AsyncIter, AsyncMock

from google.cloud.firestore_v1.base_aggregation import (
AggregationResult,
AvgAggregation,
CountAggregation,
SumAggregation,
)
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.query_profile import ExplainMetrics, QueryExplainError
from google.cloud.firestore_v1.query_results import QueryResultsList


_PROJECT = "PROJECT"


Expand Down Expand Up @@ -292,8 +295,36 @@ def test_async_aggregation_query_prep_stream_with_transaction():
assert kwargs == {"retry": None}


def test_async_aggregation_query_prep_stream_with_explain_options():
from google.cloud.firestore_v1 import query_profile

client = make_async_client()
parent = client.collection("dee")
query = make_async_query(parent)
aggregation_query = make_async_aggregation_query(query)

aggregation_query.count(alias="all")
aggregation_query.sum("someref", alias="sumall")
aggregation_query.avg("anotherref", alias="avgall")

explain_options = query_profile.ExplainOptions(analyze=True)
request, kwargs = aggregation_query._prep_stream(explain_options=explain_options)

parent_path, _ = parent._parent_info()
expected_request = {
"parent": parent_path,
"structured_aggregation_query": aggregation_query._to_protobuf(),
"transaction": None,
"explain_options": explain_options._to_dict(),
}
assert request == expected_request
assert kwargs == {"retry": None}


@pytest.mark.asyncio
async def _async_aggregation_query_get_helper(retry=None, timeout=None, read_time=None):
async def _async_aggregation_query_get_helper(
retry=None, timeout=None, read_time=None, explain_options=None
):
from google.cloud._helpers import _datetime_to_pb_timestamp

from google.cloud.firestore_v1 import _helpers
Expand All @@ -312,15 +343,23 @@ async def _async_aggregation_query_get_helper(retry=None, timeout=None, read_tim
aggregation_query.count(alias="all")

aggregation_result = AggregationResult(alias="total", value=5, read_time=read_time)

if explain_options is not None:
explain_metrics = {"execution_stats": {"results_returned": 1}}
else:
explain_metrics = None

response_pb = make_aggregation_query_response(
[aggregation_result], read_time=read_time
[aggregation_result],
read_time=read_time,
explain_metrics=explain_metrics,
)
firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb])
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

# Execute the query and check the response.
returned = await aggregation_query.get(**kwargs)
assert isinstance(returned, list)
returned = await aggregation_query.get(**kwargs, explain_options=explain_options)
assert isinstance(returned, QueryResultsList)
assert len(returned) == 1

for result in returned:
Expand All @@ -331,14 +370,25 @@ async def _async_aggregation_query_get_helper(retry=None, timeout=None, read_tim
result_datetime = _datetime_to_pb_timestamp(r.read_time)
assert result_datetime == read_time

if explain_options is None:
with pytest.raises(QueryExplainError, match="explain_options not set"):
returned.get_explain_metrics()
else:
explain_metrics = returned.get_explain_metrics()
assert isinstance(explain_metrics, ExplainMetrics)
assert explain_metrics.execution_stats.results_returned == 1

# Verify the mock call.
parent_path, _ = parent._parent_info()
expected_request = {
"parent": parent_path,
"structured_aggregation_query": aggregation_query._to_protobuf(),
"transaction": None,
}
if explain_options is not None:
expected_request["explain_options"] = explain_options._to_dict()
firestore_api.run_aggregation_query.assert_called_once_with(
request={
"parent": parent_path,
"structured_aggregation_query": aggregation_query._to_protobuf(),
"transaction": None,
},
request=expected_request,
metadata=client._rpc_metadata,
**kwargs,
)
Expand All @@ -358,6 +408,14 @@ async def test_async_aggregation_query_get_with_readtime():
await _async_aggregation_query_get_helper(read_time=read_time)


@pytest.mark.asyncio
async def test_async_aggregation_query_get_with_explain_options():
from google.cloud.firestore_v1.query_profile import ExplainOptions

explain_options = ExplainOptions(analyze=True)
await _async_aggregation_query_get_helper(explain_options=explain_options)


@pytest.mark.asyncio
async def test_async_aggregation_query_get_retry_timeout():
from google.api_core.retry import Retry
Expand Down Expand Up @@ -481,3 +539,99 @@ async def test_async_aggregation_from_query():
metadata=client._rpc_metadata,
**kwargs,
)


async def _async_aggregation_query_stream_helper(
retry=None,
timeout=None,
read_time=None,
explain_options=None,
):
from google.cloud._helpers import _datetime_to_pb_timestamp

from google.cloud.firestore_v1 import _helpers

# Create a minimal fake GAPIC.
firestore_api = AsyncMock(spec=["run_aggregation_query"])

# Attach the fake GAPIC to a real client.
client = make_async_client()
client._firestore_api_internal = firestore_api

# Make a **real** collection reference as parent.
parent = client.collection("dee")
query = make_async_query(parent)
aggregation_query = make_async_aggregation_query(query)
aggregation_query.count(alias="all")

if explain_options is not None and explain_options.analyze is False:
results_list = []
else:
aggregation_result = AggregationResult(
alias="total", value=5, read_time=read_time
)
results_list = [aggregation_result]

if explain_options is not None:
explain_metrics = {"execution_stats": {"results_returned": 1}}
else:
explain_metrics = None
response_pb = make_aggregation_query_response(
results_list,
read_time=read_time,
explain_metrics=explain_metrics,
)
firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb])
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

# Execute the query and check the response.
returned = aggregation_query.stream(**kwargs, explain_options=explain_options)
assert isinstance(returned, AsyncStreamGenerator)

results = []
async for result in returned:
for r in result:
assert r.alias == aggregation_result.alias
assert r.value == aggregation_result.value
if read_time is not None:
result_datetime = _datetime_to_pb_timestamp(r.read_time)
assert result_datetime == read_time
results.append(result)
assert len(results) == len(results_list)

if explain_options is None:
with pytest.raises(QueryExplainError, match="explain_options not set"):
await returned.get_explain_metrics()
else:
explain_metrics = await returned.get_explain_metrics()
assert isinstance(explain_metrics, ExplainMetrics)
assert explain_metrics.execution_stats.results_returned == 1

parent_path, _ = parent._parent_info()
expected_request = {
"parent": parent_path,
"structured_aggregation_query": aggregation_query._to_protobuf(),
"transaction": None,
}
if explain_options is not None:
expected_request["explain_options"] = explain_options._to_dict()

# Verify the mock call.
firestore_api.run_aggregation_query.assert_called_once_with(
request=expected_request,
metadata=client._rpc_metadata,
**kwargs,
)


@pytest.mark.asyncio
async def test_aggregation_query_stream():
await _async_aggregation_query_stream_helper()


@pytest.mark.asyncio
async def test_aggregation_query_stream_w_explain_options():
from google.cloud.firestore_v1.query_profile import ExplainOptions

explain_options = ExplainOptions(analyze=True)
await _async_aggregation_query_stream_helper(explain_options=explain_options)

0 comments on commit b56e43c

Please sign in to comment.