diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index e997455092..293a1e0f5b 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -32,7 +32,7 @@ from google.cloud.firestore_v1.transaction import Transaction -class AsyncCollectionReference(BaseCollectionReference): +class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]): """A reference to a collection in a Firestore database. The collection may already exist or this class can facilitate creation diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index bed9d4c2a4..345e061428 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -262,13 +262,15 @@ def _rpc_metadata(self): return self._rpc_metadata_internal - def collection(self, *collection_path) -> BaseCollectionReference: + def collection(self, *collection_path) -> BaseCollectionReference[BaseQuery]: raise NotImplementedError def collection_group(self, collection_id: str) -> BaseQuery: raise NotImplementedError - def _get_collection_reference(self, collection_id: str) -> BaseCollectionReference: + def _get_collection_reference( + self, collection_id: str + ) -> BaseCollectionReference[BaseQuery]: """Checks validity of collection_id and then uses subclasses collection implementation. Args: @@ -325,7 +327,7 @@ def _document_path_helper(self, *document_path) -> List[str]: def recursive_delete( self, - reference: Union[BaseCollectionReference, BaseDocumentReference], + reference: Union[BaseCollectionReference[BaseQuery], BaseDocumentReference], bulk_writer: Optional["BulkWriter"] = None, # type: ignore ) -> int: raise NotImplementedError @@ -459,8 +461,8 @@ def collections( retry: retries.Retry = None, timeout: float = None, ) -> Union[ - AsyncGenerator[BaseCollectionReference, Any], - Generator[BaseCollectionReference, Any, Any], + AsyncGenerator[BaseCollectionReference[BaseQuery], Any], + Generator[BaseCollectionReference[BaseQuery], Any, Any], ]: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 3964dfa162..dd74bf1a00 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -28,6 +28,7 @@ AsyncGenerator, Coroutine, Generator, + Generic, AsyncIterator, Iterator, Iterable, @@ -38,13 +39,13 @@ # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot -from google.cloud.firestore_v1.base_query import BaseQuery +from google.cloud.firestore_v1.base_query import QueryType from google.cloud.firestore_v1.transaction import Transaction _AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" -class BaseCollectionReference(object): +class BaseCollectionReference(Generic[QueryType]): """A reference to a collection in a Firestore database. The collection may already exist or this class can facilitate creation @@ -108,7 +109,7 @@ def parent(self): parent_path = self._path[:-1] return self._client.document(*parent_path) - def _query(self) -> BaseQuery: + def _query(self) -> QueryType: raise NotImplementedError def _aggregation_query(self) -> BaseAggregationQuery: @@ -215,10 +216,10 @@ def list_documents( ]: raise NotImplementedError - def recursive(self) -> "BaseQuery": + def recursive(self) -> QueryType: return self._query().recursive() - def select(self, field_paths: Iterable[str]) -> BaseQuery: + def select(self, field_paths: Iterable[str]) -> QueryType: """Create a "select" query with this collection as parent. See @@ -244,7 +245,7 @@ def where( value=None, *, filter=None - ) -> BaseQuery: + ) -> QueryType: """Create a "where" query with this collection as parent. See @@ -290,7 +291,7 @@ def where( else: return query.where(filter=filter) - def order_by(self, field_path: str, **kwargs) -> BaseQuery: + def order_by(self, field_path: str, **kwargs) -> QueryType: """Create an "order by" query with this collection as parent. See @@ -312,7 +313,7 @@ def order_by(self, field_path: str, **kwargs) -> BaseQuery: query = self._query() return query.order_by(field_path, **kwargs) - def limit(self, count: int) -> BaseQuery: + def limit(self, count: int) -> QueryType: """Create a limited query with this collection as parent. .. note:: @@ -355,7 +356,7 @@ def limit_to_last(self, count: int): query = self._query() return query.limit_to_last(count) - def offset(self, num_to_skip: int) -> BaseQuery: + def offset(self, num_to_skip: int) -> QueryType: """Skip to an offset in a query with this collection as parent. See @@ -375,7 +376,7 @@ def offset(self, num_to_skip: int) -> BaseQuery: def start_at( self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> BaseQuery: + ) -> QueryType: """Start query at a cursor with this collection as parent. See @@ -398,7 +399,7 @@ def start_at( def start_after( self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> BaseQuery: + ) -> QueryType: """Start query after a cursor with this collection as parent. See @@ -421,7 +422,7 @@ def start_after( def end_before( self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> BaseQuery: + ) -> QueryType: """End query before a cursor with this collection as parent. See @@ -444,7 +445,7 @@ def end_before( def end_at( self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> BaseQuery: + ) -> QueryType: """End query at a cursor with this collection as parent. See diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 6c04abbcd7..c179109835 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -47,6 +47,7 @@ Optional, Tuple, Type, + TypeVar, Union, ) @@ -102,6 +103,8 @@ _not_passed = object() +QueryType = TypeVar("QueryType", bound="BaseQuery") + class BaseFilter(abc.ABC): """Base class for Filters""" @@ -319,7 +322,7 @@ def _client(self): """ return self._parent._client - def select(self, field_paths: Iterable[str]) -> "BaseQuery": + def select(self: QueryType, field_paths: Iterable[str]) -> QueryType: """Project documents matching query to a limited set of fields. See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -354,7 +357,7 @@ def select(self, field_paths: Iterable[str]) -> "BaseQuery": return self._copy(projection=new_projection) def _copy( - self, + self: QueryType, *, projection: Optional[query.StructuredQuery.Projection] = _not_passed, field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] = _not_passed, @@ -366,7 +369,7 @@ def _copy( end_at: Optional[Tuple[dict, bool]] = _not_passed, all_descendants: Optional[bool] = _not_passed, recursive: Optional[bool] = _not_passed, - ) -> "BaseQuery": + ) -> QueryType: return self.__class__( self._parent, projection=self._evaluate_param(projection, self._projection), @@ -389,13 +392,13 @@ def _evaluate_param(self, value, fallback_value): return value if value is not _not_passed else fallback_value def where( - self, + self: QueryType, field_path: Optional[str] = None, op_string: Optional[str] = None, value=None, *, filter=None, - ) -> "BaseQuery": + ) -> QueryType: """Filter the query on a field. See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -492,7 +495,9 @@ def _make_order(field_path, direction) -> StructuredQuery.Order: direction=_enum_from_direction(direction), ) - def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery": + def order_by( + self: QueryType, field_path: str, direction: str = ASCENDING + ) -> QueryType: """Modify the query to add an order clause on a specific field. See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -526,7 +531,7 @@ def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery": new_orders = self._orders + (order_pb,) return self._copy(orders=new_orders) - def limit(self, count: int) -> "BaseQuery": + def limit(self: QueryType, count: int) -> QueryType: """Limit a query to return at most `count` matching results. If the current query already has a `limit` set, this will override it. @@ -545,7 +550,7 @@ def limit(self, count: int) -> "BaseQuery": """ return self._copy(limit=count, limit_to_last=False) - def limit_to_last(self, count: int) -> "BaseQuery": + def limit_to_last(self: QueryType, count: int) -> QueryType: """Limit a query to return the last `count` matching results. If the current query already has a `limit_to_last` set, this will override it. @@ -570,7 +575,7 @@ def _resolve_chunk_size(self, num_loaded: int, chunk_size: int) -> int: return max(self._limit - num_loaded, 0) return chunk_size - def offset(self, num_to_skip: int) -> "BaseQuery": + def offset(self: QueryType, num_to_skip: int) -> QueryType: """Skip to an offset in a query. If the current query already has specified an offset, this will @@ -601,11 +606,11 @@ def _check_snapshot(self, document_snapshot) -> None: raise ValueError("Cannot use snapshot from another collection as a cursor.") def _cursor_helper( - self, + self: QueryType, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], before: bool, start: bool, - ) -> "BaseQuery": + ) -> QueryType: """Set values to be used for a ``start_at`` or ``end_at`` cursor. The values will later be used in a query protobuf. @@ -658,8 +663,9 @@ def _cursor_helper( return self._copy(**query_kwargs) def start_at( - self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] - ) -> "BaseQuery": + self: QueryType, + document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], + ) -> QueryType: """Start query results at a particular document value. The result set will **include** the document specified by @@ -690,8 +696,9 @@ def start_at( return self._cursor_helper(document_fields_or_snapshot, before=True, start=True) def start_after( - self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] - ) -> "BaseQuery": + self: QueryType, + document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], + ) -> QueryType: """Start query results after a particular document value. The result set will **exclude** the document specified by @@ -723,8 +730,9 @@ def start_after( ) def end_before( - self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] - ) -> "BaseQuery": + self: QueryType, + document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], + ) -> QueryType: """End query results before a particular document value. The result set will **exclude** the document specified by @@ -756,8 +764,9 @@ def end_before( ) def end_at( - self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] - ) -> "BaseQuery": + self: QueryType, + document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], + ) -> QueryType: """End query results at a particular document value. The result set will **include** the document specified by @@ -1003,7 +1012,7 @@ def stream( def on_snapshot(self, callback) -> NoReturn: raise NotImplementedError - def recursive(self) -> "BaseQuery": + def recursive(self: QueryType) -> QueryType: """Returns a copy of this query whose iterator will yield all matching documents as well as each of their descendent subcollections and documents. diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 12e9ec883d..f6ba1833d6 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -31,7 +31,7 @@ from google.cloud.firestore_v1.transaction import Transaction -class CollectionReference(BaseCollectionReference): +class CollectionReference(BaseCollectionReference[query_mod.Query]): """A reference to a collection in a Firestore database. The collection may already exist or this class can facilitate creation