Skip to content

Commit

Permalink
fix: improve AsyncQuery typing (#782)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiterupri authored Oct 17, 2023
1 parent d07eebf commit ae1247b
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 40 deletions.
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from google.cloud.firestore_v1.transaction import Transaction


class AsyncCollectionReference(BaseCollectionReference):
class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]):
"""A reference to a collection in a Firestore database.
The collection may already exist or this class can facilitate creation
Expand Down
12 changes: 7 additions & 5 deletions google/cloud/firestore_v1/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,15 @@ def _rpc_metadata(self):

return self._rpc_metadata_internal

def collection(self, *collection_path) -> BaseCollectionReference:
def collection(self, *collection_path) -> BaseCollectionReference[BaseQuery]:
raise NotImplementedError

def collection_group(self, collection_id: str) -> BaseQuery:
raise NotImplementedError

def _get_collection_reference(self, collection_id: str) -> BaseCollectionReference:
def _get_collection_reference(
self, collection_id: str
) -> BaseCollectionReference[BaseQuery]:
"""Checks validity of collection_id and then uses subclasses collection implementation.
Args:
Expand Down Expand Up @@ -325,7 +327,7 @@ def _document_path_helper(self, *document_path) -> List[str]:

def recursive_delete(
self,
reference: Union[BaseCollectionReference, BaseDocumentReference],
reference: Union[BaseCollectionReference[BaseQuery], BaseDocumentReference],
bulk_writer: Optional["BulkWriter"] = None, # type: ignore
) -> int:
raise NotImplementedError
Expand Down Expand Up @@ -459,8 +461,8 @@ def collections(
retry: retries.Retry = None,
timeout: float = None,
) -> Union[
AsyncGenerator[BaseCollectionReference, Any],
Generator[BaseCollectionReference, Any, Any],
AsyncGenerator[BaseCollectionReference[BaseQuery], Any],
Generator[BaseCollectionReference[BaseQuery], Any, Any],
]:
raise NotImplementedError

Expand Down
27 changes: 14 additions & 13 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AsyncGenerator,
Coroutine,
Generator,
Generic,
AsyncIterator,
Iterator,
Iterable,
Expand All @@ -38,13 +39,13 @@

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_query import BaseQuery
from google.cloud.firestore_v1.base_query import QueryType
from google.cloud.firestore_v1.transaction import Transaction

_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"


class BaseCollectionReference(object):
class BaseCollectionReference(Generic[QueryType]):
"""A reference to a collection in a Firestore database.
The collection may already exist or this class can facilitate creation
Expand Down Expand Up @@ -108,7 +109,7 @@ def parent(self):
parent_path = self._path[:-1]
return self._client.document(*parent_path)

def _query(self) -> BaseQuery:
def _query(self) -> QueryType:
raise NotImplementedError

def _aggregation_query(self) -> BaseAggregationQuery:
Expand Down Expand Up @@ -215,10 +216,10 @@ def list_documents(
]:
raise NotImplementedError

def recursive(self) -> "BaseQuery":
def recursive(self) -> QueryType:
return self._query().recursive()

def select(self, field_paths: Iterable[str]) -> BaseQuery:
def select(self, field_paths: Iterable[str]) -> QueryType:
"""Create a "select" query with this collection as parent.
See
Expand All @@ -244,7 +245,7 @@ def where(
value=None,
*,
filter=None
) -> BaseQuery:
) -> QueryType:
"""Create a "where" query with this collection as parent.
See
Expand Down Expand Up @@ -290,7 +291,7 @@ def where(
else:
return query.where(filter=filter)

def order_by(self, field_path: str, **kwargs) -> BaseQuery:
def order_by(self, field_path: str, **kwargs) -> QueryType:
"""Create an "order by" query with this collection as parent.
See
Expand All @@ -312,7 +313,7 @@ def order_by(self, field_path: str, **kwargs) -> BaseQuery:
query = self._query()
return query.order_by(field_path, **kwargs)

def limit(self, count: int) -> BaseQuery:
def limit(self, count: int) -> QueryType:
"""Create a limited query with this collection as parent.
.. note::
Expand Down Expand Up @@ -355,7 +356,7 @@ def limit_to_last(self, count: int):
query = self._query()
return query.limit_to_last(count)

def offset(self, num_to_skip: int) -> BaseQuery:
def offset(self, num_to_skip: int) -> QueryType:
"""Skip to an offset in a query with this collection as parent.
See
Expand All @@ -375,7 +376,7 @@ def offset(self, num_to_skip: int) -> BaseQuery:

def start_at(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""Start query at a cursor with this collection as parent.
See
Expand All @@ -398,7 +399,7 @@ def start_at(

def start_after(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""Start query after a cursor with this collection as parent.
See
Expand All @@ -421,7 +422,7 @@ def start_after(

def end_before(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""End query before a cursor with this collection as parent.
See
Expand All @@ -444,7 +445,7 @@ def end_before(

def end_at(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""End query at a cursor with this collection as parent.
See
Expand Down
49 changes: 29 additions & 20 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
Optional,
Tuple,
Type,
TypeVar,
Union,
)

Expand Down Expand Up @@ -102,6 +103,8 @@

_not_passed = object()

QueryType = TypeVar("QueryType", bound="BaseQuery")


class BaseFilter(abc.ABC):
"""Base class for Filters"""
Expand Down Expand Up @@ -319,7 +322,7 @@ def _client(self):
"""
return self._parent._client

def select(self, field_paths: Iterable[str]) -> "BaseQuery":
def select(self: QueryType, field_paths: Iterable[str]) -> QueryType:
"""Project documents matching query to a limited set of fields.
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
Expand Down Expand Up @@ -354,7 +357,7 @@ def select(self, field_paths: Iterable[str]) -> "BaseQuery":
return self._copy(projection=new_projection)

def _copy(
self,
self: QueryType,
*,
projection: Optional[query.StructuredQuery.Projection] = _not_passed,
field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] = _not_passed,
Expand All @@ -366,7 +369,7 @@ def _copy(
end_at: Optional[Tuple[dict, bool]] = _not_passed,
all_descendants: Optional[bool] = _not_passed,
recursive: Optional[bool] = _not_passed,
) -> "BaseQuery":
) -> QueryType:
return self.__class__(
self._parent,
projection=self._evaluate_param(projection, self._projection),
Expand All @@ -389,13 +392,13 @@ def _evaluate_param(self, value, fallback_value):
return value if value is not _not_passed else fallback_value

def where(
self,
self: QueryType,
field_path: Optional[str] = None,
op_string: Optional[str] = None,
value=None,
*,
filter=None,
) -> "BaseQuery":
) -> QueryType:
"""Filter the query on a field.
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
Expand Down Expand Up @@ -492,7 +495,9 @@ def _make_order(field_path, direction) -> StructuredQuery.Order:
direction=_enum_from_direction(direction),
)

def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery":
def order_by(
self: QueryType, field_path: str, direction: str = ASCENDING
) -> QueryType:
"""Modify the query to add an order clause on a specific field.
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
Expand Down Expand Up @@ -526,7 +531,7 @@ def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery":
new_orders = self._orders + (order_pb,)
return self._copy(orders=new_orders)

def limit(self, count: int) -> "BaseQuery":
def limit(self: QueryType, count: int) -> QueryType:
"""Limit a query to return at most `count` matching results.
If the current query already has a `limit` set, this will override it.
Expand All @@ -545,7 +550,7 @@ def limit(self, count: int) -> "BaseQuery":
"""
return self._copy(limit=count, limit_to_last=False)

def limit_to_last(self, count: int) -> "BaseQuery":
def limit_to_last(self: QueryType, count: int) -> QueryType:
"""Limit a query to return the last `count` matching results.
If the current query already has a `limit_to_last`
set, this will override it.
Expand All @@ -570,7 +575,7 @@ def _resolve_chunk_size(self, num_loaded: int, chunk_size: int) -> int:
return max(self._limit - num_loaded, 0)
return chunk_size

def offset(self, num_to_skip: int) -> "BaseQuery":
def offset(self: QueryType, num_to_skip: int) -> QueryType:
"""Skip to an offset in a query.
If the current query already has specified an offset, this will
Expand Down Expand Up @@ -601,11 +606,11 @@ def _check_snapshot(self, document_snapshot) -> None:
raise ValueError("Cannot use snapshot from another collection as a cursor.")

def _cursor_helper(
self,
self: QueryType,
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
before: bool,
start: bool,
) -> "BaseQuery":
) -> QueryType:
"""Set values to be used for a ``start_at`` or ``end_at`` cursor.
The values will later be used in a query protobuf.
Expand Down Expand Up @@ -658,8 +663,9 @@ def _cursor_helper(
return self._copy(**query_kwargs)

def start_at(
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
) -> "BaseQuery":
self: QueryType,
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
) -> QueryType:
"""Start query results at a particular document value.
The result set will **include** the document specified by
Expand Down Expand Up @@ -690,8 +696,9 @@ def start_at(
return self._cursor_helper(document_fields_or_snapshot, before=True, start=True)

def start_after(
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
) -> "BaseQuery":
self: QueryType,
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
) -> QueryType:
"""Start query results after a particular document value.
The result set will **exclude** the document specified by
Expand Down Expand Up @@ -723,8 +730,9 @@ def start_after(
)

def end_before(
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
) -> "BaseQuery":
self: QueryType,
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
) -> QueryType:
"""End query results before a particular document value.
The result set will **exclude** the document specified by
Expand Down Expand Up @@ -756,8 +764,9 @@ def end_before(
)

def end_at(
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
) -> "BaseQuery":
self: QueryType,
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
) -> QueryType:
"""End query results at a particular document value.
The result set will **include** the document specified by
Expand Down Expand Up @@ -1003,7 +1012,7 @@ def stream(
def on_snapshot(self, callback) -> NoReturn:
raise NotImplementedError

def recursive(self) -> "BaseQuery":
def recursive(self: QueryType) -> QueryType:
"""Returns a copy of this query whose iterator will yield all matching
documents as well as each of their descendent subcollections and documents.
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from google.cloud.firestore_v1.transaction import Transaction


class CollectionReference(BaseCollectionReference):
class CollectionReference(BaseCollectionReference[query_mod.Query]):
"""A reference to a collection in a Firestore database.
The collection may already exist or this class can facilitate creation
Expand Down

0 comments on commit ae1247b

Please sign in to comment.