Skip to content

[DRAFT] feat: query to pipeline conversion #1071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: pipeline_queries_3_stable_stages
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion google/cloud/firestore_v1/base_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
from __future__ import annotations

import abc
import itertools

from abc import ABC
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union, Iterable

from google.api_core import gapic_v1
from google.api_core import retry as retries
Expand All @@ -33,6 +34,10 @@
from google.cloud.firestore_v1.types import (
StructuredAggregationQuery,
)
from google.cloud.firestore_v1.pipeline_expressions import Accumulator
from google.cloud.firestore_v1.pipeline_expressions import Count
from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias
from google.cloud.firestore_v1.pipeline_expressions import Field

# Types needed only for Type Hints
if TYPE_CHECKING: # pragma: NO COVER
Expand Down Expand Up @@ -66,6 +71,9 @@ def __init__(self, alias: str, value: float, read_time=None):
def __repr__(self):
    """Debug representation showing the result's alias, value, and read time."""
    return f"<Aggregation alias={self.alias}, value={self.value}, readtime={self.read_time}>"

def _to_dict(self):
return {self.alias: self.value}


class BaseAggregation(ABC):
def __init__(self, alias: str | None = None):
Expand All @@ -75,6 +83,27 @@ def __init__(self, alias: str | None = None):
def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""

@abc.abstractmethod
def _to_pipeline_expr(
    self, autoindexer: Iterable[int]
) -> ExprWithAlias[Accumulator]:
    """
    Convert this instance to a pipeline expression for use with pipeline.aggregate()

    Args:
        autoindexer: If an alias isn't supplied, one should be created with the format "field_n".
            The autoindexer is an iterable that provides the ``n`` value to use for each expression.
    """

def _pipeline_alias(self, autoindexer):
"""
Helper to build the alias for the pipeline expression
"""
if self.alias is not None:
return self.alias
else:
return f"field_{next(autoindexer)}"


class CountAggregation(BaseAggregation):
def __init__(self, alias: str | None = None):
Expand All @@ -88,6 +117,9 @@ def _to_protobuf(self):
aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count()
return aggregation_pb

def _to_pipeline_expr(self, autoindexer: Iterable[int]):
    """Build the aliased pipeline ``count()`` expression for this aggregation."""
    alias = self._pipeline_alias(autoindexer)
    return Count().as_(alias)


class SumAggregation(BaseAggregation):
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
Expand All @@ -107,6 +139,9 @@ def _to_protobuf(self):
aggregation_pb.sum.field.field_path = self.field_ref
return aggregation_pb

def _to_pipeline_expr(self, autoindexer: Iterable[int]):
    """Build the aliased pipeline ``sum()`` expression over this aggregation's field."""
    alias = self._pipeline_alias(autoindexer)
    field_expr = Field.of(self.field_ref)
    return field_expr.sum().as_(alias)


class AvgAggregation(BaseAggregation):
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
Expand All @@ -126,6 +161,9 @@ def _to_protobuf(self):
aggregation_pb.avg.field.field_path = self.field_ref
return aggregation_pb

def _to_pipeline_expr(self, autoindexer: Iterable[int]):
    """Build the aliased pipeline ``avg()`` expression over this aggregation's field."""
    alias = self._pipeline_alias(autoindexer)
    field_expr = Field.of(self.field_ref)
    return field_expr.avg().as_(alias)


def _query_response_to_result(
response_pb,
Expand Down Expand Up @@ -317,3 +355,20 @@ def stream(
StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]:
A generator of the query results.
"""

def pipeline(self):
    """
    Convert this aggregation query into a Pipeline.

    Queries containing a `cursor` or `limit_to_last` are not currently supported.

    Raises:
        - ValueError: raised if Query wasn't created with an associated client
        - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
    Returns:
        a Pipeline representing the query
    """
    # The counter supplies sequential n values so that un-aliased
    # aggregations receive distinct auto-generated "field_n" names.
    field_counter = itertools.count(start=1)
    expressions = [agg._to_pipeline_expr(field_counter) for agg in self._aggregations]
    base_pipeline = self._nested_query.pipeline()
    return base_pipeline.aggregate(*expressions)
13 changes: 13 additions & 0 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,19 @@ def find_nearest(
distance_threshold=distance_threshold,
)

def pipeline(self):
    """
    Convert this collection into a Pipeline by way of its implicit query.

    Queries containing a `cursor` or `limit_to_last` are not currently supported.

    Raises:
        - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
    Returns:
        a Pipeline representing the query
    """
    base_query = self._query()
    return base_query.pipeline()


def _auto_id() -> str:
"""Generate a "random" automatically generated ID.
Expand Down
69 changes: 69 additions & 0 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
query,
)
from google.cloud.firestore_v1.vector import Vector
from google.cloud.firestore_v1 import pipeline_expressions

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
Expand Down Expand Up @@ -1128,6 +1129,74 @@ def recursive(self: QueryType) -> QueryType:

return copied

def pipeline(self):
    """
    Convert this query into a Pipeline.

    Queries containing a `cursor` or `limit_to_last` are not currently supported.

    Raises:
        - ValueError: raised if Query wasn't created with an associated client
        - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
    Returns:
        a Pipeline representing the query
    """
    if not self._client:
        raise ValueError("Query does not have an associated client")

    # Source stage: collection group vs. single collection.
    if self._all_descendants:
        result = self._client.pipeline().collection_group(self._parent.id)
    else:
        result = self._client.pipeline().collection(self._parent._path)

    # Translate each field filter into a pipeline `where` stage.
    for pb_filter in self._field_filters:
        condition = pipeline_expressions.FilterCondition._from_query_filter_pb(
            pb_filter, self._client
        )
        result = result.where(condition)

    # Projection becomes a `select` stage over the requested field paths.
    projection = self._projection
    if projection and projection.fields:
        selected_paths = [f.field_path for f in projection.fields]
        result = result.select(*selected_paths)

    # Ordering: each ordered field must exist (matching Query's implicit
    # orderby semantics), then the sort stage applies the directions.
    normalized_orders = self._normalize_orders()
    if normalized_orders:
        existence_checks = []
        sort_keys = []
        for order in normalized_orders:
            field_expr = pipeline_expressions.Field.of(order.field.field_path)
            existence_checks.append(field_expr.exists())
            if order.direction == StructuredQuery.Direction.ASCENDING:
                direction = "ascending"
            else:
                direction = "descending"
            sort_keys.append(pipeline_expressions.Ordering(field_expr, direction))

        if len(existence_checks) == 1:
            result = result.where(existence_checks[0])
        else:
            result = result.where(pipeline_expressions.And(*existence_checks))

        result = result.sort(*sort_keys)

    # Cursors and limit_to_last have no pipeline equivalent yet.
    if self._start_at or self._end_at or self._limit_to_last:
        raise NotImplementedError(
            "Query to Pipeline conversion: cursors and limit_to_last is not supported yet."
        )

    # Plain offset/limit (only reachable without cursors).
    if self._offset:
        result = result.offset(self._offset)
    if self._limit:
        result = result.limit(self._limit)

    return result

def _comparator(self, doc1, doc2) -> int:
_orders = self._orders

Expand Down
2 changes: 1 addition & 1 deletion tests/system/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@

# run all tests against default database, and a named database
# TODO: add enterprise mode when GA (RunQuery not currently supported)
TEST_DATABASES = [None, FIRESTORE_OTHER_DB]
TEST_DATABASES = [None, FIRESTORE_OTHER_DB, FIRESTORE_ENTERPRISE_DB]
Loading