diff --git a/google/__init__.py b/google/__init__.py index aa5aeae602..175169ed36 100644 --- a/google/__init__.py +++ b/google/__init__.py @@ -19,4 +19,4 @@ except ImportError: import pkgutil - __path__ = pkgutil.extend_path(__path__, __name__) + __path__ = pkgutil.extend_path(__path__, __name__) # type: ignore diff --git a/google/cloud/__init__.py b/google/cloud/__init__.py index aa5aeae602..175169ed36 100644 --- a/google/cloud/__init__.py +++ b/google/cloud/__init__.py @@ -19,4 +19,4 @@ except ImportError: import pkgutil - __path__ = pkgutil.extend_path(__path__, __name__) + __path__ = pkgutil.extend_path(__path__, __name__) # type: ignore diff --git a/google/cloud/firestore_bundle/bundle.py b/google/cloud/firestore_bundle/bundle.py index 73a53aadb5..1bf72552ca 100644 --- a/google/cloud/firestore_bundle/bundle.py +++ b/google/cloud/firestore_bundle/bundle.py @@ -16,30 +16,25 @@ import datetime import json +from typing import Dict, List, Optional, Union + +from google.cloud._helpers import _datetime_to_pb_timestamp +from google.cloud._helpers import UTC +from google.protobuf.timestamp_pb2 import Timestamp +from google.protobuf import json_format -from google.cloud.firestore_bundle.types.bundle import ( - BundledDocumentMetadata, - BundledQuery, - BundleElement, - BundleMetadata, - NamedQuery, -) -from google.cloud._helpers import _datetime_to_pb_timestamp, UTC # type: ignore from google.cloud.firestore_bundle._helpers import limit_type_of_query +from google.cloud.firestore_bundle.types.bundle import BundleElement +from google.cloud.firestore_bundle.types.bundle import BundleMetadata +from google.cloud.firestore_bundle.types.bundle import BundledDocumentMetadata +from google.cloud.firestore_bundle.types.bundle import BundledQuery +from google.cloud.firestore_bundle.types.bundle import NamedQuery +from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.base_client import BaseClient 
from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.base_query import BaseQuery from google.cloud.firestore_v1.document import DocumentReference -from google.cloud.firestore_v1 import _helpers -from google.protobuf.timestamp_pb2 import Timestamp # type: ignore -from google.protobuf import json_format # type: ignore -from typing import ( - Dict, - List, - Optional, - Union, -) class FirestoreBundle: @@ -333,8 +328,10 @@ def build(self) -> str: BundleElement(document_metadata=bundled_document.metadata) ) document_count += 1 + document_msg = bundled_document.snapshot._to_protobuf() + assert document_msg is not None buffer += self._compile_bundle_element( - BundleElement(document=bundled_document.snapshot._to_protobuf()._pb,) + BundleElement(document=document_msg._pb,) ) metadata: BundleElement = BundleElement( diff --git a/google/cloud/firestore_bundle/py.typed b/google/cloud/firestore_bundle/py.typed index e2987f2963..bfcd092b29 100644 --- a/google/cloud/firestore_bundle/py.typed +++ b/google/cloud/firestore_bundle/py.typed @@ -1,2 +1,2 @@ # Marker file for PEP 561. -# The google-cloud-bundle package uses inline types. +# The google-cloud-firestore_bundle package uses inline types. 
diff --git a/google/cloud/firestore_v1/__init__.py b/google/cloud/firestore_v1/__init__.py index e6100331a4..0fade3f08f 100644 --- a/google/cloud/firestore_v1/__init__.py +++ b/google/cloud/firestore_v1/__init__.py @@ -18,13 +18,12 @@ """Python idiomatic client for Google Cloud Firestore.""" - import pkg_resources try: __version__ = pkg_resources.get_distribution("google-cloud-firestore").version except pkg_resources.DistributionNotFound: - __version__ = None + __version__ = None # type: ignore from google.cloud.firestore_v1 import types from google.cloud.firestore_v1._helpers import GeoPoint diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index 05e8c26790..34c57002a5 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -41,7 +41,6 @@ Generator, Iterator, List, - NoReturn, Optional, Tuple, Union, @@ -496,7 +495,7 @@ def __init__(self, document_data) -> None: self.increments = {} self.minimums = {} self.maximums = {} - self.set_fields = {} + self.set_fields: Dict[str, Any] = {} self.empty_document = False prefix_path = FieldPath() @@ -559,13 +558,16 @@ def transform_paths(self): + list(self.minimums) ) - def _get_update_mask(self, allow_empty_mask=False) -> None: + def _get_update_mask( + self, allow_empty_mask=False + ) -> Union[types.common.DocumentMask, None]: return None def get_update_pb( self, document_path, exists=None, allow_empty_mask=False ) -> types.write.Write: + current_document: Union[common.Precondition, None] if exists is not None: current_document = common.Precondition(exists=exists) else: @@ -725,9 +727,9 @@ class DocumentExtractorForMerge(DocumentExtractor): def __init__(self, document_data) -> None: super(DocumentExtractorForMerge, self).__init__(document_data) - self.data_merge = [] - self.transform_merge = [] - self.merge = [] + self.data_merge: List[str] = [] + self.transform_merge: List[FieldPath] = [] + self.merge: List[FieldPath] = [] def 
_apply_merge_all(self) -> None: self.data_merge = sorted(self.field_paths + self.deleted_fields) @@ -783,7 +785,7 @@ def _apply_merge_paths(self, merge) -> None: self.data_merge.append(field_path) # Clear out data for fields not merged. - merged_set_fields = {} + merged_set_fields: Dict[str, Any] = {} for field_path in self.data_merge: value = get_field_value(self.document_data, field_path) set_field_value(merged_set_fields, field_path, value) @@ -834,7 +836,7 @@ def apply_merge(self, merge) -> None: def _get_update_mask( self, allow_empty_mask=False - ) -> Optional[types.common.DocumentMask]: + ) -> Union[types.common.DocumentMask, None]: # Mask uses dotted / quoted paths. mask_paths = [ field_path.to_api_repr() @@ -903,7 +905,9 @@ def _get_document_iterator( ) -> Generator[Tuple[Any, Any], Any, None]: return extract_fields(self.document_data, prefix_path, expand_dots=True) - def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask: + def _get_update_mask( + self, allow_empty_mask=False + ) -> Union[types.common.DocumentMask, None]: mask_paths = [] for field_path in self.top_level_paths: if field_path not in self.transform_paths: @@ -1017,7 +1021,7 @@ def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]: class WriteOption(object): """Option used to assert a condition on a write operation.""" - def modify_write(self, write, no_create_msg=None) -> NoReturn: + def modify_write(self, write) -> None: """Modify a ``Write`` protobuf based on the state of this write option. This is a virtual method intended to be implemented by subclasses. @@ -1026,8 +1030,6 @@ def modify_write(self, write, no_create_msg=None) -> NoReturn: write (google.cloud.firestore_v1.types.Write): A ``Write`` protobuf instance to be modified with a precondition determined by the state of this option. - no_create_msg (Optional[str]): A message to use to indicate that - a create operation is not allowed. Raises: NotImplementedError: Always, this method is virtual. 
@@ -1233,6 +1235,8 @@ def deserialize_bundle( raise ValueError("Unexpected end to serialized FirestoreBundle") # Now, finally add the metadata element + assert bundle is not None + bundle._add_bundle_element( metadata_bundle_element, client=client, type="metadata", # type: ignore ) @@ -1296,7 +1300,8 @@ def _get_documents_from_bundle( def _get_document_from_bundle( bundle, *, document_id: str, -) -> Optional["google.cloud.firestore.DocumentSnapshot"]: # type: ignore +) -> Union["google.cloud.firestore.DocumentSnapshot", None]: # type: ignore bundled_doc = bundle.documents.get(document_id) if bundled_doc: return bundled_doc.snapshot + return None diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index 87033d73ba..d70ad10921 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -16,9 +16,9 @@ from google.api_core import gapic_v1 -from google.api_core import retry as retries from google.cloud.firestore_v1.base_batch import BaseWriteBatch +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry class AsyncWriteBatch(BaseWriteBatch): @@ -37,7 +37,7 @@ def __init__(self, client) -> None: super(AsyncWriteBatch, self).__init__(client=client) async def commit( - self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, + self, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> list: """Commit the changes accumulated in this batch. 
diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 275f904fb9..631c907e16 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -24,36 +24,39 @@ :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` """ -from google.api_core import gapic_v1 -from google.api_core import retry as retries - -from google.cloud.firestore_v1.base_client import ( - BaseClient, - DEFAULT_DATABASE, - _CLIENT_INFO, - _parse_batch_get, # type: ignore - _path_helper, +from typing import ( + Any, + cast, + AsyncGenerator, + Iterable, + List, + Optional, + Union, ) -from google.cloud.firestore_v1.async_query import AsyncCollectionGroup +from google.api_core import gapic_v1 + from google.cloud.firestore_v1.async_batch import AsyncWriteBatch from google.cloud.firestore_v1.async_collection import AsyncCollectionReference -from google.cloud.firestore_v1.async_document import ( - AsyncDocumentReference, - DocumentSnapshot, -) +from google.cloud.firestore_v1.async_document import AsyncDocumentReference +from google.cloud.firestore_v1.async_document import DocumentSnapshot +from google.cloud.firestore_v1.async_query import AsyncCollectionGroup +from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.async_transaction import AsyncTransaction +from google.cloud.firestore_v1.base_client import BaseClient +from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE +from google.cloud.firestore_v1.base_client import _CLIENT_INFO +from google.cloud.firestore_v1.base_client import _parse_batch_get # type: ignore +from google.cloud.firestore_v1.base_client import _path_helper +from google.cloud.firestore_v1.bulk_writer import BulkWriter from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.services.firestore import ( async_client as firestore_client, ) +from 
google.cloud.firestore_v1.services.firestore.async_client import OptionalRetry from google.cloud.firestore_v1.services.firestore.transports import ( grpc_asyncio as firestore_grpc_transport, ) -from typing import Any, AsyncGenerator, Iterable, List, Optional, Union, TYPE_CHECKING - -if TYPE_CHECKING: - from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER class AsyncClient(BaseClient): @@ -84,6 +87,8 @@ class AsyncClient(BaseClient): should be set through client_options. """ + _API_CLIENT_CLASS = firestore_client.FirestoreAsyncClient + def __init__( self, project=None, @@ -126,16 +131,6 @@ def _firestore_api(self): firestore_client, ) - @property - def _target(self): - """Return the target (where the API is). - Eg. "firestore.googleapis.com" - - Returns: - str: The location of the API. - """ - return self._target_helper(firestore_client.FirestoreAsyncClient) - def collection(self, *collection_path: str) -> AsyncCollectionReference: """Get a reference to a collection. @@ -189,6 +184,27 @@ def collection_group(self, collection_id: str) -> AsyncCollectionGroup: """ return AsyncCollectionGroup(self._get_collection_reference(collection_id)) + def _get_collection_reference(self, collection_id: str) -> AsyncCollectionReference: + """Checks validity of collection_id and then uses subclasses collection implementation. + + Args: + collection_id (str) Identifies the collections to query over. + + Every collection or subcollection with this ID as the last segment of its + path will be included. Cannot contain a slash. + + Returns: + The created collection. + """ + if "/" in collection_id: + raise ValueError( + "Invalid collection_id " + + collection_id + + ". Collection IDs must not contain '/'." + ) + + return self.collection(collection_id) + def document(self, *document_path: str) -> AsyncDocumentReference: """Get a reference to a document in a collection. 
@@ -229,7 +245,7 @@ async def get_all( references: List[AsyncDocumentReference], field_paths: Iterable[str] = None, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieve a batch of documents. @@ -282,7 +298,7 @@ async def get_all( yield _parse_batch_get(get_doc_response, reference_map, self) async def collections( - self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, + self, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> AsyncGenerator[AsyncCollectionReference, Any]: """List top-level collections of the client's database. @@ -309,7 +325,7 @@ async def recursive_delete( reference: Union[AsyncCollectionReference, AsyncDocumentReference], *, bulk_writer: Optional["BulkWriter"] = None, - chunk_size: Optional[int] = 5000, + chunk_size: int = 5000, ): """Deletes documents and their subcollections, regardless of collection name. 
@@ -343,8 +359,8 @@ async def _recursive_delete( reference: Union[AsyncCollectionReference, AsyncDocumentReference], bulk_writer: "BulkWriter", *, - chunk_size: Optional[int] = 5000, - depth: Optional[int] = 0, + chunk_size: int = 5000, + depth: int = 0, ) -> int: """Recursion helper for `recursive_delete.""" @@ -352,9 +368,8 @@ async def _recursive_delete( if isinstance(reference, AsyncCollectionReference): chunk: List[DocumentSnapshot] - async for chunk in reference.recursive().select( - [FieldPath.document_id()] - )._chunkify(chunk_size): + query = reference.recursive().select([FieldPath.document_id()]) + async for chunk in cast(AsyncQuery, query)._chunkify(chunk_size): doc_snap: DocumentSnapshot for doc_snap in chunk: num_deleted += 1 diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index f16992e887..2f4edbc358 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -14,25 +14,18 @@ """Classes for representing collections for the Google Cloud Firestore API.""" -from google.api_core import gapic_v1 -from google.api_core import retry as retries +from typing import Any, AsyncGenerator, AsyncIterator, cast, Tuple -from google.cloud.firestore_v1.base_collection import ( - BaseCollectionReference, - _item_to_document_ref, -) -from google.cloud.firestore_v1 import ( - async_query, - async_document, -) +from google.api_core import gapic_v1 +from google.cloud.firestore_v1 import async_document +from google.cloud.firestore_v1 import async_query +from google.cloud.firestore_v1.base_collection import BaseCollectionReference +from google.cloud.firestore_v1.base_collection import _item_to_document_ref from google.cloud.firestore_v1.document import DocumentReference - -from typing import AsyncIterator -from typing import Any, AsyncGenerator, Tuple - -# Types needed only for Type Hints +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry 
from google.cloud.firestore_v1.transaction import Transaction +from google.cloud.firestore_v1.types import write class AsyncCollectionReference(BaseCollectionReference): @@ -80,7 +73,7 @@ async def add( self, document_data: dict, document_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Tuple[Any, Any]: """Create a document in the Firestore database with the provided data. @@ -113,7 +106,8 @@ document_ref, kwargs = self._prep_add( document_data, document_id, retry, timeout, ) - write_result = await document_ref.create(document_data, **kwargs) + document_ref = cast(async_document.AsyncDocumentReference, document_ref) + write_result: write.WriteResult = await document_ref.create(document_data, **kwargs) return write_result.update_time, document_ref def document( self, document_id: str = None ) -> async_document.AsyncDocumentReference: """Create a sub-document underneath the current collection. Args: document_id (Optional[str]): The document identifier within the current collection. If not provided, will default to a random 20 character string composed of digits, uppercase and lowercase letters. Returns: :class:`~google.cloud.firestore_v1.document.async_document.AsyncDocumentReference`: The child document. """ - return super(AsyncCollectionReference, self).document(document_id) + document = super().document(document_id) + return cast(async_document.AsyncDocumentReference, document) async def list_documents( self, page_size: int = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> AsyncGenerator[DocumentReference, None]: """List all subdocuments of the current collection. @@ -167,7 +162,7 @@ async def get( self, transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> list: """Read the documents in this collection. 
@@ -198,7 +193,7 @@ async def get( async def stream( self, transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> AsyncIterator[async_document.DocumentSnapshot]: """Read the documents in this collection. diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index c11e6db2d4..5b11a51cd7 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -15,20 +15,18 @@ """Classes for representing documents for the Google Cloud Firestore API.""" import datetime import logging +from typing import Any, AsyncGenerator, Coroutine, Iterable, Union from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore - -from google.cloud.firestore_v1.base_document import ( - BaseDocumentReference, - DocumentSnapshot, - _first_write_result, -) +from google.cloud._helpers import _datetime_to_pb_timestamp +from google.protobuf.timestamp_pb2 import Timestamp + from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_document import _first_write_result +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry from google.cloud.firestore_v1.types import write -from google.protobuf.timestamp_pb2 import Timestamp -from typing import Any, AsyncGenerator, Coroutine, Iterable, Union logger = logging.getLogger(__name__) @@ -65,7 +63,7 @@ def __init__(self, *path, **kwargs) -> None: async def create( self, document_data: dict, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> write.WriteResult: """Create the current document in the 
Firestore database. @@ -95,7 +93,7 @@ async def set( self, document_data: dict, merge: bool = False, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> write.WriteResult: """Replace the current document in the Firestore database. @@ -135,7 +133,7 @@ async def update( self, field_updates: dict, option: _helpers.WriteOption = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> write.WriteResult: """Update an existing document in the Firestore database. @@ -292,7 +290,7 @@ async def update( async def delete( self, option: _helpers.WriteOption = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Timestamp: """Delete the current document in the Firestore database. @@ -325,7 +323,7 @@ async def get( self, field_paths: Iterable[str] = None, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Union[DocumentSnapshot, Coroutine[Any, Any, DocumentSnapshot]]: """Retrieve a snapshot of the current document. @@ -391,7 +389,7 @@ async def get( async def collections( self, page_size: int = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> AsyncGenerator: """List subcollections of the current document. diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 418f4f157c..3a64ed81fd 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -19,24 +19,22 @@ a more common way to create a query than direct usage of the constructor. 
""" +from typing import AsyncGenerator, cast, List, Optional, Type + from google.api_core import gapic_v1 -from google.api_core import retry as retries from google.cloud import firestore_v1 +from google.cloud.firestore_v1 import async_document +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_query import BaseCollectionGroup +from google.cloud.firestore_v1.base_query import BaseQuery +from google.cloud.firestore_v1.base_query import QueryPartition from google.cloud.firestore_v1.base_query import ( - BaseCollectionGroup, - BaseQuery, - QueryPartition, - _query_response_to_snapshot, _collection_group_query_response_to_snapshot, - _enum_from_direction, ) - -from google.cloud.firestore_v1 import async_document -from google.cloud.firestore_v1.base_document import DocumentSnapshot -from typing import AsyncGenerator, List, Optional, Type - -# Types needed only for Type Hints +from google.cloud.firestore_v1.base_query import _enum_from_direction +from google.cloud.firestore_v1.base_query import _query_response_to_snapshot +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry from google.cloud.firestore_v1.transaction import Transaction @@ -132,7 +130,7 @@ async def _chunkify( ) -> AsyncGenerator[List[DocumentSnapshot], None]: max_to_return: Optional[int] = self._limit num_returned: int = 0 - original: AsyncQuery = self._copy() + original: AsyncQuery = cast(AsyncQuery, self._copy()) last_document: Optional[DocumentSnapshot] = None while True: @@ -147,7 +145,7 @@ async def _chunkify( if last_document: _q = _q.start_after(last_document) - snapshots = await _q.get() + snapshots = await cast(AsyncQuery, _q).get() if snapshots: last_document = snapshots[-1] @@ -168,7 +166,7 @@ async def _chunkify( async def get( self, transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> list: """Read the 
documents in the collection that match this query. @@ -206,17 +204,21 @@ async def get( ) self._limit_to_last = False - result = self.stream(transaction=transaction, retry=retry, timeout=timeout) - result = [d async for d in result] + result = [] + async for doc in self.stream( + transaction=transaction, retry=retry, timeout=timeout + ): + result.append(doc) + if is_limited_to_last: - result = list(reversed(result)) + result.reverse() return result async def stream( self, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> AsyncGenerator[async_document.DocumentSnapshot, None]: """Read the documents in the collection that match this query. @@ -325,7 +327,7 @@ def _get_query_class(): async def get_partitions( self, partition_count, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> AsyncGenerator[QueryPartition, None]: """Partition a query for parallelization. 
diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index f4ecf32d34..8cc5f0b4c8 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -14,39 +14,33 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" - import asyncio import random +from typing import Any, AsyncGenerator, Callable, Coroutine, Sequence +from google.api_core import exceptions from google.api_core import gapic_v1 -from google.api_core import retry as retries - -from google.cloud.firestore_v1.base_transaction import ( - _BaseTransactional, - BaseTransaction, - MAX_ATTEMPTS, - _CANT_BEGIN, - _CANT_ROLLBACK, - _CANT_COMMIT, - _WRITE_READ_ONLY, - _INITIAL_SLEEP, - _MAX_SLEEP, - _MULTIPLIER, - _EXCEED_ATTEMPTS_TEMPLATE, -) -from google.api_core import exceptions -from google.cloud.firestore_v1 import async_batch from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import async_batch from google.cloud.firestore_v1 import types - from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_document import DocumentSnapshot from google.cloud.firestore_v1.async_query import AsyncQuery -from typing import Any, AsyncGenerator, Callable, Coroutine - -# Types needed only for Type Hints +from google.cloud.firestore_v1.base_transaction import BaseTransaction +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS +from google.cloud.firestore_v1.base_transaction import _BaseTransactional +from google.cloud.firestore_v1.base_transaction import _CANT_BEGIN +from google.cloud.firestore_v1.base_transaction import _CANT_COMMIT +from google.cloud.firestore_v1.base_transaction import _CANT_ROLLBACK +from google.cloud.firestore_v1.base_transaction import _EXCEED_ATTEMPTS_TEMPLATE +from google.cloud.firestore_v1.base_transaction import _INITIAL_SLEEP +from 
google.cloud.firestore_v1.base_transaction import _MAX_SLEEP +from google.cloud.firestore_v1.base_transaction import _MULTIPLIER +from google.cloud.firestore_v1.base_transaction import _WRITE_READ_ONLY from google.cloud.firestore_v1.client import Client +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry +from google.cloud.firestore_v1.types import write class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): @@ -67,7 +61,7 @@ def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: super(AsyncTransaction, self).__init__(client) BaseTransaction.__init__(self, max_attempts, read_only) - def _add_write_pbs(self, write_pbs: list) -> None: + def _add_write_pbs(self, write_pbs: Sequence[write.Write]) -> None: """Add `Write`` protobufs to this transaction. Args: @@ -141,6 +135,7 @@ async def _commit(self) -> list: if not self.in_progress: raise ValueError(_CANT_COMMIT) + assert self._id is not None commit_response = await _commit_with_retry( self._client, self._write_pbs, self._id ) @@ -151,7 +146,7 @@ async def _commit(self) -> list: async def get_all( self, references: list, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieves multiple documents from Firestore. @@ -174,7 +169,7 @@ async def get_all( async def get( self, ref_or_query, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """ @@ -197,7 +192,7 @@ async def get( [ref_or_query], transaction=self, **kwargs ) elif isinstance(ref_or_query, AsyncQuery): - return await ref_or_query.stream(transaction=self, **kwargs) + return await ref_or_query.stream(transaction=self, **kwargs) # type: ignore else: raise ValueError( 'Value for argument "ref_or_query" must be a AsyncDocumentReference or a AsyncQuery.' 
diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index ca3a66c897..52de1c9076 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -15,12 +15,12 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" import abc -from typing import Dict, Union +from typing import Dict, List, Optional, Sequence, Union -# Types needed only for Type Hints -from google.api_core import retry as retries from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry +from google.cloud.firestore_v1.types import write class BaseBatch(metaclass=abc.ABCMeta): @@ -37,9 +37,9 @@ class BaseBatch(metaclass=abc.ABCMeta): def __init__(self, client) -> None: self._client = client - self._write_pbs = [] + self._write_pbs: List[write.Write] = [] self._document_references: Dict[str, BaseDocumentReference] = {} - self.write_results = None + self.write_results: Union[list, None] = None self.commit_time = None def __len__(self): @@ -48,7 +48,7 @@ def __len__(self): def __contains__(self, reference: BaseDocumentReference): return reference._document_path in self._document_references - def _add_write_pbs(self, write_pbs: list) -> None: + def _add_write_pbs(self, write_pbs: Sequence[write.Write]) -> None: """Add `Write`` protobufs to this transaction. This method intended to be over-ridden by subclasses. @@ -59,12 +59,6 @@ def _add_write_pbs(self, write_pbs: list) -> None: """ self._write_pbs.extend(write_pbs) - @abc.abstractmethod - def commit(self): - """Sends all accumulated write operations to the server. The details of this - write depend on the implementing class.""" - raise NotImplementedError() - def create(self, reference: BaseDocumentReference, document_data: dict) -> None: """Add a "change" to this batch to create a document. 
@@ -170,7 +164,7 @@ class BaseWriteBatch(BaseBatch): """Base class for a/sync implementations of the `commit` RPC. `commit` is useful for lower volumes or when the order of write operations is important.""" - def _prep_commit(self, retry: retries.Retry, timeout: float): + def _prep_commit(self, retry: OptionalRetry, timeout: Optional[float]): """Shared setup for async/sync :meth:`commit`.""" request = { "database": self._client._database_string, diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 87c01deef5..7895dad54f 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -25,39 +25,32 @@ """ import os -import grpc # type: ignore - -from google.auth.credentials import AnonymousCredentials -import google.api_core.client_options -import google.api_core.path_template -from google.api_core import retry as retries -from google.api_core.gapic_v1 import client_info -from google.cloud.client import ClientWithProject # type: ignore - -from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1 import __version__ -from google.cloud.firestore_v1 import types -from google.cloud.firestore_v1.base_document import DocumentSnapshot - -from google.cloud.firestore_v1.field_path import render_field_path -from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from typing import ( Any, - AsyncGenerator, - Generator, Iterable, List, Optional, Tuple, + Type, Union, ) -# Types needed only for Type Hints -from google.cloud.firestore_v1.base_collection import BaseCollectionReference +import google.api_core.client_options +import google.api_core.path_template +from google.api_core.gapic_v1 import client_info +from google.auth.credentials import AnonymousCredentials # type: ignore +from google.cloud.client import ClientWithProject # type: ignore +import grpc # type: ignore + +from google.cloud.firestore_v1 import _helpers +from 
google.cloud.firestore_v1 import __version__ +from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.base_transaction import BaseTransaction -from google.cloud.firestore_v1.base_batch import BaseWriteBatch -from google.cloud.firestore_v1.base_query import BaseQuery +from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions +from google.cloud.firestore_v1.field_path import render_field_path +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry DEFAULT_DATABASE = "(default)" @@ -71,7 +64,9 @@ ) _ACTIVE_TXN: str = "There is already an active transaction." _INACTIVE_TXN: str = "There is no active transaction." -_CLIENT_INFO: Any = client_info.ClientInfo(client_library_version=__version__) +_CLIENT_INFO: Any = client_info.ClientInfo( + client_library_version=__version__ # type: ignore +) _FIRESTORE_EMULATOR_HOST: str = "FIRESTORE_EMULATOR_HOST" @@ -109,6 +104,7 @@ class BaseClient(ClientWithProject): ) """The scopes required for authenticating with the Firestore service.""" + _API_CLIENT_CLASS: Optional[Type] = None _firestore_api_internal = None _database_string_internal = None _rpc_metadata_internal = None @@ -215,6 +211,16 @@ def _target_helper(self, client_class) -> str: else: return client_class.DEFAULT_ENDPOINT + @property + def _target(self): + """Return the target (where the API is). + Eg. "firestore.googleapis.com" + + Returns: + str: The location of the API. + """ + return self._target_helper(self._API_CLIENT_CLASS) + @property def _database_string(self): """The database string corresponding to this client's project. 
@@ -261,36 +267,6 @@ def _rpc_metadata(self): return self._rpc_metadata_internal - def collection(self, *collection_path) -> BaseCollectionReference: - raise NotImplementedError - - def collection_group(self, collection_id: str) -> BaseQuery: - raise NotImplementedError - - def _get_collection_reference(self, collection_id: str) -> BaseCollectionReference: - """Checks validity of collection_id and then uses subclasses collection implementation. - - Args: - collection_id (str) Identifies the collections to query over. - - Every collection or subcollection with this ID as the last segment of its - path will be included. Cannot contain a slash. - - Returns: - The created collection. - """ - if "/" in collection_id: - raise ValueError( - "Invalid collection_id " - + collection_id - + ". Collection IDs must not contain '/'." - ) - - return self.collection(collection_id) - - def document(self, *document_path) -> BaseDocumentReference: - raise NotImplementedError - def bulk_writer(self, options: Optional[BulkWriterOptions] = None) -> BulkWriter: """Get a BulkWriter instance from this client. @@ -306,7 +282,7 @@ def bulk_writer(self, options: Optional[BulkWriterOptions] = None) -> BulkWriter """ return BulkWriter(client=self, options=options) - def _document_path_helper(self, *document_path) -> List[str]: + def _document_path_helper(self, *document_path: str) -> List[str]: """Standardize the format of path to tuple of path segments and strip the database string from path if present. 
Args: @@ -322,13 +298,6 @@ def _document_path_helper(self, *document_path) -> List[str]: joined_path = joined_path[len(base_path) :] return joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER) - def recursive_delete( - self, - reference: Union[BaseCollectionReference, BaseDocumentReference], - bulk_writer: Optional["BulkWriter"] = None, # type: ignore - ) -> int: - raise NotImplementedError - @staticmethod def field_path(*field_names: str) -> str: """Create a **field path** from a list of nested field names. @@ -416,7 +385,7 @@ def _prep_get_all( references: list, field_paths: Iterable[str] = None, transaction: BaseTransaction = None, - retry: retries.Retry = None, + retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[dict, dict, dict]: """Shared setup for async/sync :meth:`get_all`.""" @@ -432,20 +401,8 @@ def _prep_get_all( return request, reference_map, kwargs - def get_all( - self, - references: list, - field_paths: Iterable[str] = None, - transaction: BaseTransaction = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> Union[ - AsyncGenerator[DocumentSnapshot, Any], Generator[DocumentSnapshot, Any, Any] - ]: - raise NotImplementedError - def _prep_collections( - self, retry: retries.Retry = None, timeout: float = None, + self, retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`collections`.""" request = {"parent": "{}/documents".format(self._database_string)} @@ -453,20 +410,6 @@ def _prep_collections( return request, kwargs - def collections( - self, retry: retries.Retry = None, timeout: float = None, - ) -> Union[ - AsyncGenerator[BaseCollectionReference, Any], - Generator[BaseCollectionReference, Any, Any], - ]: - raise NotImplementedError - - def batch(self) -> BaseWriteBatch: - raise NotImplementedError - - def transaction(self, **kwargs) -> BaseTransaction: - raise NotImplementedError - def _reference_info(references: list) -> Tuple[list, dict]: """Get information 
about document references. @@ -575,7 +518,9 @@ def _parse_batch_get( return snapshot -def _get_doc_mask(field_paths: Iterable[str]) -> Optional[types.common.DocumentMask]: +def _get_doc_mask( + field_paths: Optional[Iterable[str]], +) -> Optional[types.common.DocumentMask]: """Get a document mask if field paths are provided. Args: @@ -593,7 +538,7 @@ def _get_doc_mask(field_paths: Iterable[str]) -> Optional[types.common.DocumentM return types.DocumentMask(field_paths=field_paths) -def _path_helper(path: tuple) -> Tuple[str]: +def _path_helper(path: Tuple[str, ...]) -> Tuple[str, ...]: """Standardize path into a tuple of path segments. Args: @@ -603,6 +548,6 @@ def _path_helper(path: tuple) -> Tuple[str]: * A tuple of path segments """ if len(path) == 1: - return path[0].split(_helpers.DOCUMENT_PATH_DELIMITER) + return tuple(path[0].split(_helpers.DOCUMENT_PATH_DELIMITER)) else: return path diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index c3091e75aa..d4fd310893 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -13,30 +13,17 @@ # limitations under the License. 
"""Classes for representing collections for the Google Cloud Firestore API.""" + import random import sys - -from google.api_core import retry as retries +from typing import Any, Iterable, Tuple, Union from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1.document import DocumentReference -from typing import ( - Any, - AsyncGenerator, - Coroutine, - Generator, - AsyncIterator, - Iterator, - Iterable, - NoReturn, - Tuple, - Union, -) - -# Types needed only for Type Hints +from google.cloud.firestore_v1.base_document import BaseDocumentReference from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.base_query import BaseQuery -from google.cloud.firestore_v1.transaction import Transaction +from google.cloud.firestore_v1.document import DocumentReference +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry _AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" @@ -108,7 +95,7 @@ def parent(self): def _query(self) -> BaseQuery: raise NotImplementedError - def document(self, document_id: str = None) -> DocumentReference: + def document(self, document_id: str = None) -> BaseDocumentReference: """Create a sub-document underneath the current collection. Args: @@ -118,7 +105,7 @@ def document(self, document_id: str = None) -> DocumentReference: uppercase and lowercase and letters. Returns: - :class:`~google.cloud.firestore_v1.document.DocumentReference`: + :class:`~google.cloud.firestore_v1.document.BaseDocumentReference`: The child document. 
""" if document_id is None: @@ -156,9 +143,9 @@ def _prep_add( self, document_data: dict, document_id: str = None, - retry: retries.Retry = None, + retry: OptionalRetry = None, timeout: float = None, - ) -> Tuple[DocumentReference, dict]: + ) -> Tuple[BaseDocumentReference, dict]: """Shared setup for async / sync :method:`add`""" if document_id is None: document_id = _auto_id() @@ -168,17 +155,8 @@ def _prep_add( return document_ref, kwargs - def add( - self, - document_data: dict, - document_id: str = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> Union[Tuple[Any, Any], Coroutine[Any, Any, Tuple[Any, Any]]]: - raise NotImplementedError - def _prep_list_documents( - self, page_size: int = None, retry: retries.Retry = None, timeout: float = None, + self, page_size: int = None, retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[dict, dict]: """Shared setup for async / sync :method:`list_documents`""" parent, _ = self._parent_info() @@ -196,14 +174,7 @@ def _prep_list_documents( return request, kwargs - def list_documents( - self, page_size: int = None, retry: retries.Retry = None, timeout: float = None, - ) -> Union[ - Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any] - ]: - raise NotImplementedError - - def recursive(self) -> "BaseQuery": + def recursive(self) -> BaseQuery: return self._query().recursive() def select(self, field_paths: Iterable[str]) -> BaseQuery: @@ -438,7 +409,7 @@ def end_at( return query.end_at(document_fields) def _prep_get_or_stream( - self, retry: retries.Retry = None, timeout: float = None, + self, retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[Any, dict]: """Shared setup for async / sync :meth:`get` / :meth:`stream`""" query = self._query() @@ -446,27 +417,6 @@ def _prep_get_or_stream( return query, kwargs - def get( - self, - transaction: Transaction = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> Union[ - Generator[DocumentSnapshot, Any, 
Any], AsyncGenerator[DocumentSnapshot, Any] - ]: - raise NotImplementedError - - def stream( - self, - transaction: Transaction = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> Union[Iterator[DocumentSnapshot], AsyncIterator[DocumentSnapshot]]: - raise NotImplementedError - - def on_snapshot(self, callback) -> NoReturn: - raise NotImplementedError - def _auto_id() -> str: """Generate a "random" automatically generated ID. diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index a4ab469df6..d0d4b3b312 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -15,18 +15,15 @@ """Classes for representing documents for the Google Cloud Firestore API.""" import copy +from typing import Any, Dict, Iterable, Optional, Union, Tuple -from google.api_core import retry as retries - -from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import field_path as field_path_module +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry +from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1.types import common - -# Types needed only for Type Hints from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import write -from typing import Any, Dict, Iterable, NoReturn, Optional, Union, Tuple class BaseDocumentReference(object): @@ -187,7 +184,7 @@ def collection(self, collection_id: str) -> Any: return self._client.collection(*child_path) def _prep_create( - self, document_data: dict, retry: retries.Retry = None, timeout: float = None, + self, document_data: dict, retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[Any, dict]: batch = self._client.batch() batch.create(self, document_data) @@ -195,16 +192,11 @@ def _prep_create( return batch, kwargs - def create( - self, document_data: dict, 
retry: retries.Retry = None, timeout: float = None, - ) -> NoReturn: - raise NotImplementedError - def _prep_set( self, document_data: dict, merge: bool = False, - retry: retries.Retry = None, + retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[Any, dict]: batch = self._client.batch() @@ -213,20 +205,11 @@ def _prep_set( return batch, kwargs - def set( - self, - document_data: dict, - merge: bool = False, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: - raise NotImplementedError - def _prep_update( self, field_updates: dict, option: _helpers.WriteOption = None, - retry: retries.Retry = None, + retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[Any, dict]: batch = self._client.batch() @@ -235,19 +218,10 @@ def _prep_update( return batch, kwargs - def update( - self, - field_updates: dict, - option: _helpers.WriteOption = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: - raise NotImplementedError - def _prep_delete( self, option: _helpers.WriteOption = None, - retry: retries.Retry = None, + retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`delete`.""" @@ -261,25 +235,18 @@ def _prep_delete( return request, kwargs - def delete( - self, - option: _helpers.WriteOption = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: - raise NotImplementedError - def _prep_batch_get( self, field_paths: Iterable[str] = None, transaction=None, - retry: retries.Retry = None, + retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`get`.""" if isinstance(field_paths, str): raise ValueError("'field_paths' must be a sequence of paths, not a string.") + mask: Union[common.DocumentMask, None] if field_paths is not None: mask = common.DocumentMask(field_paths=sorted(field_paths)) else: @@ -295,17 +262,8 @@ def _prep_batch_get( return request, kwargs - def get( - 
self, - field_paths: Iterable[str] = None, - transaction=None, - retry: retries.Retry = None, - timeout: float = None, - ) -> "DocumentSnapshot": - raise NotImplementedError - def _prep_collections( - self, page_size: int = None, retry: retries.Retry = None, timeout: float = None, + self, page_size: int = None, retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`collections`.""" request = {"parent": self._document_path, "page_size": page_size} @@ -313,14 +271,6 @@ def _prep_collections( return request, kwargs - def collections( - self, page_size: int = None, retry: retries.Retry = None, timeout: float = None, - ) -> None: - raise NotImplementedError - - def on_snapshot(self, callback) -> None: - raise NotImplementedError - class DocumentSnapshot(object): """A snapshot of document data in a Firestore database. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 537288d160..9e1dfb250d 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -20,33 +20,31 @@ """ import copy import math +from typing import ( + Any, + Dict, + Iterable, + Optional, + Tuple, + Type, + Union, +) -from google.api_core import retry as retries +from google.api_core import gapic_v1 from google.protobuf import wrappers_pb2 from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import document from google.cloud.firestore_v1 import field_path as field_path_module from google.cloud.firestore_v1 import transforms +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry from google.cloud.firestore_v1.types import StructuredQuery from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import Cursor from google.cloud.firestore_v1.types import RunQueryResponse from google.cloud.firestore_v1.order import 
Order -from typing import ( - Any, - Dict, - Generator, - Iterable, - NoReturn, - Optional, - Tuple, - Type, - Union, -) -# Types needed only for Type Hints -from google.cloud.firestore_v1.base_document import DocumentSnapshot _BAD_DIR_STRING: str _BAD_OP_NAN_NULL: str @@ -95,7 +93,24 @@ ) _MISMATCH_CURSOR_W_ORDER_BY = "The cursor {!r} does not match the order fields {!r}." -_not_passed = object() + +class _NotPassed: + """Marker for optional parameters + + Used where ``None`` is a possible explicit value. + """ + + +_not_passed = _NotPassed() +OptionalProjection = Union[query.StructuredQuery.Projection, _NotPassed] +OptionalFieldFilters = Union[Tuple[query.StructuredQuery.FieldFilter], _NotPassed] +OptionalOrders = Union[Tuple[query.StructuredQuery.Order], _NotPassed] +OptionalInt = Union[int, _NotPassed] +OptionalBool = Union[bool, _NotPassed] +CursorParamStripped = Tuple[Union[tuple, dict, list], bool] +CursorArg = Union[DocumentSnapshot, dict, list, tuple, None] +CursorParam = Tuple[CursorArg, bool] +OptionalCursorParam = Union[CursorParam, _NotPassed] class BaseQuery(object): @@ -253,16 +268,16 @@ def select(self, field_paths: Iterable[str]) -> "BaseQuery": def _copy( self, *, - projection: Optional[query.StructuredQuery.Projection] = _not_passed, - field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] = _not_passed, - orders: Optional[Tuple[query.StructuredQuery.Order]] = _not_passed, - limit: Optional[int] = _not_passed, - limit_to_last: Optional[bool] = _not_passed, - offset: Optional[int] = _not_passed, - start_at: Optional[Tuple[dict, bool]] = _not_passed, - end_at: Optional[Tuple[dict, bool]] = _not_passed, - all_descendants: Optional[bool] = _not_passed, - recursive: Optional[bool] = _not_passed, + projection: OptionalProjection = _not_passed, + field_filters: OptionalFieldFilters = _not_passed, + orders: OptionalOrders = _not_passed, + limit: OptionalInt = _not_passed, + limit_to_last: OptionalBool = _not_passed, + offset: OptionalInt =
_not_passed, + start_at: OptionalCursorParam = _not_passed, + end_at: OptionalCursorParam = _not_passed, + all_descendants: OptionalBool = _not_passed, + recursive: OptionalBool = _not_passed, ) -> "BaseQuery": return self.__class__( self._parent, @@ -461,10 +476,7 @@ def _check_snapshot(self, document_snapshot) -> None: raise ValueError("Cannot use snapshot from another collection as a cursor.") def _cursor_helper( - self, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], - before: bool, - start: bool, + self, document_fields_or_snapshot: CursorArg, before: bool, start: bool, ) -> "BaseQuery": """Set values to be used for a ``start_at`` or ``end_at`` cursor. @@ -517,9 +529,7 @@ def _cursor_helper( return self._copy(**query_kwargs) - def start_at( - self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] - ) -> "BaseQuery": + def start_at(self, document_fields_or_snapshot: CursorArg) -> "BaseQuery": """Start query results at a particular document value. The result set will **include** the document specified by @@ -549,9 +559,7 @@ def start_at( """ return self._cursor_helper(document_fields_or_snapshot, before=True, start=True) - def start_after( - self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] - ) -> "BaseQuery": + def start_after(self, document_fields_or_snapshot: CursorArg) -> "BaseQuery": """Start query results after a particular document value. The result set will **exclude** the document specified by @@ -582,9 +590,7 @@ def start_after( document_fields_or_snapshot, before=False, start=True ) - def end_before( - self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] - ) -> "BaseQuery": + def end_before(self, document_fields_or_snapshot: CursorArg) -> "BaseQuery": """End query results before a particular document value. 
The result set will **exclude** the document specified by @@ -615,9 +621,7 @@ def end_before( document_fields_or_snapshot, before=True, start=False ) - def end_at( - self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple] - ) -> "BaseQuery": + def end_at(self, document_fields_or_snapshot: CursorArg) -> "BaseQuery": """End query results at a particular document value. The result set will **include** the document specified by @@ -648,7 +652,7 @@ def end_at( document_fields_or_snapshot, before=False, start=False ) - def _filters_pb(self) -> StructuredQuery.Filter: + def _filters_pb(self) -> Union[StructuredQuery.Filter, None]: """Convert all the filters into a single generic Filter protobuf. This may be a lone field filter or unary filter, may be a composite @@ -720,10 +724,10 @@ def _normalize_orders(self) -> list: return orders - def _normalize_cursor(self, cursor, orders) -> Optional[Tuple[Any, Any]]: + def _normalize_cursor(self, cursor, orders) -> Union[CursorParamStripped, None]: """Helper: convert cursor to a list of values based on orders.""" if cursor is None: - return + return None if not orders: raise ValueError(_NO_ORDERS_FOR_CURSOR) @@ -806,13 +810,8 @@ def _to_protobuf(self) -> StructuredQuery: query_kwargs["limit"] = wrappers_pb2.Int32Value(value=self._limit) return query.StructuredQuery(**query_kwargs) - def get( - self, transaction=None, retry: retries.Retry = None, timeout: float = None, - ) -> Iterable[DocumentSnapshot]: - raise NotImplementedError - def _prep_stream( - self, transaction=None, retry: retries.Retry = None, timeout: float = None, + self, transaction=None, retry: OptionalRetry = None, timeout: float = None, ) -> Tuple[dict, str, dict]: """Shared setup for async / sync :meth:`stream`""" if self._limit_to_last: @@ -831,14 +830,6 @@ def _prep_stream( return request, expected_prefix, kwargs - def stream( - self, transaction=None, retry: retries.Retry = None, timeout: float = None, - ) -> 
Generator[document.DocumentSnapshot, Any, None]: - raise NotImplementedError - - def on_snapshot(self, callback) -> NoReturn: - raise NotImplementedError - def recursive(self) -> "BaseQuery": """Returns a copy of this query whose iterator will yield all matching documents as well as each of their descendent subcollections and documents. @@ -925,6 +916,10 @@ def _comparator(self, doc1, doc2) -> int: return 0 + @staticmethod + def _get_collection_reference_class() -> Type: + raise NotImplementedError + def _enum_from_op_string(op_string: str) -> int: """Convert a string representation of a binary operator to an enum. @@ -1021,7 +1016,7 @@ def _filter_pb(field_or_unary) -> StructuredQuery.Filter: raise ValueError("Unexpected filter type", type(field_or_unary), field_or_unary) -def _cursor_pb(cursor_pair: Tuple[list, bool]) -> Optional[Cursor]: +def _cursor_pb(cursor_pair: Optional[CursorParamStripped]) -> Union[Cursor, None]: """Convert a cursor pair to a protobuf. If ``cursor_pair`` is :data:`None`, just returns :data:`None`. 
@@ -1041,6 +1036,8 @@ def _cursor_pb(cursor_pair: Tuple[list, bool]) -> Optional[Cursor]: value_pbs = [_helpers.encode_value(value) for value in data] return query.Cursor(values=value_pbs, before=before) + return None + def _query_response_to_snapshot( response_pb: RunQueryResponse, collection, expected_prefix: str @@ -1175,7 +1172,10 @@ def _get_query_class(self): raise NotImplementedError def _prep_get_partitions( - self, partition_count, retry: retries.Retry = None, timeout: float = None, + self, + partition_count, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, ) -> Tuple[dict, dict]: self._validate_partition_query() parent_path, expected_prefix = self._parent._parent_info() @@ -1196,13 +1196,8 @@ def _prep_get_partitions( return request, kwargs - def get_partitions( - self, partition_count, retry: retries.Retry = None, timeout: float = None, - ) -> NoReturn: - raise NotImplementedError - @staticmethod - def _get_collection_reference_class() -> Type["BaseCollectionGroup"]: + def _get_collection_reference_class() -> Type: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index 7774a3f03d..83b2c1433f 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -14,21 +14,10 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" -from google.api_core import retry as retries +from typing import List, Union from google.cloud.firestore_v1 import types -from typing import Any, Coroutine, NoReturn, Optional, Union - -_CANT_BEGIN: str -_CANT_COMMIT: str -_CANT_RETRY_READ_ONLY: str -_CANT_ROLLBACK: str -_EXCEED_ATTEMPTS_TEMPLATE: str -_INITIAL_SLEEP: float -_MAX_SLEEP: float -_MISSING_ID_TEMPLATE: str -_MULTIPLIER: float -_WRITE_READ_ONLY: str +from google.cloud.firestore_v1.types import write MAX_ATTEMPTS = 5 @@ -52,10 +41,10 @@ class BaseTransaction(object): """Accumulate 
read-and-write operations to be sent in a transaction. Args: - max_attempts (Optional[int]): The maximum number of attempts for + max_attempts (int): The maximum number of attempts for the transaction (i.e. allowing retries). Defaults to :attr:`~google.cloud.firestore_v1.transaction.MAX_ATTEMPTS`. - read_only (Optional[bool]): Flag indicating if the transaction + read_only (bool): Flag indicating if the transaction should be read-only or should allow writes. Defaults to :data:`False`. """ @@ -64,13 +53,11 @@ def __init__(self, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: self._max_attempts = max_attempts self._read_only = read_only self._id = None - - def _add_write_pbs(self, write_pbs) -> NoReturn: - raise NotImplementedError + self._write_pbs: List[write.Write] = [] def _options_protobuf( self, retry_id: Union[bytes, None] - ) -> Optional[types.common.TransactionOptions]: + ) -> Union[types.common.TransactionOptions, None]: """Convert the current object to protobuf. The ``retry_id`` value is used when retrying a transaction that @@ -82,7 +69,6 @@ def _options_protobuf( to be retried. Returns: - Optional[google.cloud.firestore_v1.types.TransactionOptions]: The protobuf ``TransactionOptions`` if ``read_only==True`` or if there is a transaction ID to be retried, else :data:`None`. @@ -120,7 +106,7 @@ def id(self): """Get the current transaction ID. Returns: - Optional[bytes]: The transaction ID (or :data:`None` if the + Union[bytes, None]: The transaction ID (or :data:`None` if the current transaction is not in progress). 
""" return self._id @@ -133,25 +119,6 @@ def _clean_up(self) -> None: self._write_pbs = [] self._id = None - def _begin(self, retry_id=None) -> NoReturn: - raise NotImplementedError - - def _rollback(self) -> NoReturn: - raise NotImplementedError - - def _commit(self) -> Union[list, Coroutine[Any, Any, list]]: - raise NotImplementedError - - def get_all( - self, references: list, retry: retries.Retry = None, timeout: float = None, - ) -> NoReturn: - raise NotImplementedError - - def get( - self, ref_or_query, retry: retries.Retry = None, timeout: float = None, - ) -> NoReturn: - raise NotImplementedError - class _BaseTransactional(object): """Provide a callable object to use as a transactional decorater. @@ -167,20 +134,11 @@ class _BaseTransactional(object): def __init__(self, to_wrap) -> None: self.to_wrap = to_wrap self.current_id = None - """Optional[bytes]: The current transaction ID.""" + """Union[bytes, None]: The current transaction ID.""" self.retry_id = None - """Optional[bytes]: The ID of the first attempted transaction.""" + """Union[bytes, None]: The ID of the first attempted transaction.""" def _reset(self) -> None: """Unset the transaction IDs.""" self.current_id = None self.retry_id = None - - def _pre_commit(self, transaction, *args, **kwargs) -> NoReturn: - raise NotImplementedError - - def _maybe_commit(self, transaction) -> NoReturn: - raise NotImplementedError - - def __call__(self, transaction, *args, **kwargs): - raise NotImplementedError diff --git a/google/cloud/firestore_v1/batch.py b/google/cloud/firestore_v1/batch.py index 2621efc205..136e9e4359 100644 --- a/google/cloud/firestore_v1/batch.py +++ b/google/cloud/firestore_v1/batch.py @@ -15,9 +15,9 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" from google.api_core import gapic_v1 -from google.api_core import retry as retries from google.cloud.firestore_v1.base_batch import BaseWriteBatch +from google.cloud.firestore_v1.services.firestore.client import 
OptionalRetry class WriteBatch(BaseWriteBatch): @@ -38,7 +38,7 @@ def __init__(self, client) -> None: super(WriteBatch, self).__init__(client=client) def commit( - self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None + self, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None ) -> list: """Commit the changes accumulated in this batch. diff --git a/google/cloud/firestore_v1/bulk_batch.py b/google/cloud/firestore_v1/bulk_batch.py index a525a09620..f7488c02e8 100644 --- a/google/cloud/firestore_v1/bulk_batch.py +++ b/google/cloud/firestore_v1/bulk_batch.py @@ -13,12 +13,14 @@ # limitations under the License. """Helpers for batch requests to the Google Cloud Firestore API.""" +from typing import Optional + from google.api_core import gapic_v1 -from google.api_core import retry as retries from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.base_batch import BaseBatch from google.cloud.firestore_v1.types.firestore import BatchWriteResponse +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry class BulkWriteBatch(BaseBatch): @@ -46,7 +48,9 @@ def __init__(self, client) -> None: super(BulkWriteBatch, self).__init__(client=client) def commit( - self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None + self, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, ) -> BatchWriteResponse: """Writes the changes accumulated in this batch. 
@@ -79,7 +83,7 @@ def commit( return save_response - def _prep_commit(self, retry: retries.Retry, timeout: float): + def _prep_commit(self, retry: OptionalRetry, timeout: Optional[float]): request = { "database": self._client._database_string, "writes": self._write_pbs, diff --git a/google/cloud/firestore_v1/bulk_writer.py b/google/cloud/firestore_v1/bulk_writer.py index e52061c03d..84bb44d06f 100644 --- a/google/cloud/firestore_v1/bulk_writer.py +++ b/google/cloud/firestore_v1/bulk_writer.py @@ -23,21 +23,19 @@ import functools import logging import time - -from typing import Callable, Dict, List, Optional, Union, TYPE_CHECKING +from typing import Callable, Deque, Dict, List, Optional, Union from google.rpc import status_pb2 # type: ignore from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import async_client +from google.cloud.firestore_v1 import base_client from google.cloud.firestore_v1.base_document import BaseDocumentReference from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch from google.cloud.firestore_v1.rate_limiter import RateLimiter from google.cloud.firestore_v1.types.firestore import BatchWriteResponse from google.cloud.firestore_v1.types.write import WriteResult -if TYPE_CHECKING: - from google.cloud.firestore_v1.base_client import BaseClient # pragma: NO COVER - logger = logging.getLogger(__name__) @@ -82,6 +80,17 @@ class AsyncBulkWriterMixin: wrapped in a decorator which ensures that the `SendMode` is honored. """ + _in_flight_documents: int = 0 + _total_batches_sent: int = 0 + _total_write_operations: int = 0 + + _success_callback: Callable + _batch_callback: Callable + _error_callback: Callable + + _options: "BulkWriterOptions" + _retries: collections.deque + def _with_send_mode(fn): """Decorates a method to ensure it is only called via the executor (IFF the SendMode value is SendMode.parallel!). 
@@ -116,7 +125,7 @@ def wrapper(self, *args, **kwargs): return wrapper - @_with_send_mode + @_with_send_mode # type: ignore def _send_batch( self, batch: BulkWriteBatch, operations: List["BulkWriterOperation"] ): @@ -180,7 +189,7 @@ def _process_response( def _retry_operation( self, operation: "BulkWriterOperation", - ) -> concurrent.futures.Future: + ): delay: int = 0 if self._options.retry == BulkRetry.exponential: @@ -252,7 +261,7 @@ class BulkWriter(AsyncBulkWriterMixin): def __init__( self, - client: "BaseClient" = None, + client: "base_client.BaseClient", options: Optional["BulkWriterOptions"] = None, ): # Because `BulkWriter` instances are all synchronous/blocking on the @@ -261,9 +270,9 @@ def __init__( # `BulkWriter` parallelizes all of its network I/O without the developer # having to worry about awaiting async methods, so we must convert an # AsyncClient instance into a plain Client instance. - self._client = ( - client._to_sync_copy() if type(client).__name__ == "AsyncClient" else client - ) + if isinstance(client, async_client.AsyncClient): + client = client._to_sync_copy() + self._client = client self._options = options or BulkWriterOptions() self._send_mode = self._options.mode @@ -279,9 +288,9 @@ def __init__( # the raw operation with the `datetime` of its next scheduled attempt. # `self._retries` must always remain sorted for efficient reads, so it is # required to only ever add elements via `bisect.insort`. 
- self._retries: collections.deque["OperationRetry"] = collections.deque([]) + self._retries: Deque["OperationRetry"] = collections.deque([]) - self._queued_batches = collections.deque([]) + self._queued_batches: Deque = collections.deque([]) self._is_open: bool = True # This list will go on to store the future returned from each submission @@ -299,16 +308,11 @@ def __init__( [BulkWriteFailure, BulkWriter], bool ] = BulkWriter._default_on_error - self._in_flight_documents: int = 0 self._rate_limiter = RateLimiter( initial_tokens=self._options.initial_ops_per_second, global_max_tokens=self._options.max_ops_per_second, ) - # Keep track of progress as batches and write operations are completed - self._total_batches_sent: int = 0 - self._total_write_operations: int = 0 - self._ensure_executor() @staticmethod @@ -500,7 +504,7 @@ def _schedule_ready_retries(self): def _request_send(self, batch_size: int) -> bool: # Set up this boolean to avoid repeatedly taking tokens if we're only # waiting on the `max_in_flight` limit. - have_received_tokens: bool = False + got_tokens: bool = False while True: # To avoid bottlenecks on the server, an additional limit is that no @@ -512,10 +516,9 @@ def _request_send(self, batch_size: int) -> bool: ) # Ask for tokens each pass through this loop until they are granted, # and then stop. - have_received_tokens = ( - have_received_tokens or self._rate_limiter.take_tokens(batch_size) - ) - if not under_threshold or not have_received_tokens: + got_tokens = got_tokens or bool(self._rate_limiter.take_tokens(batch_size)) + + if not under_threshold or not got_tokens: # Try again until both checks are true. # Note that this sleep is helpful to prevent the main BulkWriter # thread from spinning through this loop as fast as possible and @@ -725,6 +728,8 @@ class BulkWriterOperation: similar writes to the same document. 
""" + attempts: int + def add_to_batch(self, batch: BulkWriteBatch): """Adds `self` to the supplied batch.""" assert isinstance(batch, BulkWriteBatch) @@ -763,6 +768,9 @@ class BaseOperationRetry: Python 3.6 is dropped and `dataclasses` becomes universal. """ + operation: BulkWriterOperation + run_at: datetime.datetime + def __lt__(self, other: "OperationRetry"): """Allows use of `bisect` to maintain a sorted list of `OperationRetry` instances, which in turn allows us to cheaply grab all that are ready to @@ -882,7 +890,7 @@ class BulkWriterDeleteOperation(BulkWriterOperation): # versions above. Additonally, the methods on `BaseOperationRetry` can be added # directly to `OperationRetry` and `BaseOperationRetry` can be deleted. - class BulkWriterOptions: + class BulkWriterOptions: # type: ignore def __init__( self, initial_ops_per_second: int = 500, @@ -900,7 +908,7 @@ def __eq__(self, other): return NotImplemented return self.__dict__ == other.__dict__ - class BulkWriteFailure: + class BulkWriteFailure: # type: ignore def __init__( self, operation: BulkWriterOperation, @@ -916,7 +924,7 @@ def __init__( def attempts(self) -> int: return self.operation.attempts - class OperationRetry(BaseOperationRetry): + class OperationRetry(BaseOperationRetry): # type: ignore """Container for an additional attempt at an operation, scheduled for the future.""" @@ -926,7 +934,7 @@ def __init__( self.operation = operation self.run_at = run_at - class BulkWriterCreateOperation(BulkWriterOperation): + class BulkWriterCreateOperation(BulkWriterOperation): # type: ignore """Container for BulkWriter.create() operations.""" def __init__( @@ -939,7 +947,7 @@ def __init__( self.document_data = document_data self.attempts = attempts - class BulkWriterUpdateOperation(BulkWriterOperation): + class BulkWriterUpdateOperation(BulkWriterOperation): # type: ignore """Container for BulkWriter.update() operations.""" def __init__( @@ -954,7 +962,7 @@ def __init__( self.option = option self.attempts = 
attempts - class BulkWriterSetOperation(BulkWriterOperation): + class BulkWriterSetOperation(BulkWriterOperation): # type: ignore """Container for BulkWriter.set() operations.""" def __init__( @@ -969,7 +977,7 @@ def __init__( self.merge = merge self.attempts = attempts - class BulkWriterDeleteOperation(BulkWriterOperation): + class BulkWriterDeleteOperation(BulkWriterOperation): # type: ignore """Container for BulkWriter.delete() operations.""" def __init__( diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 345f833c98..c50ae3b098 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -23,9 +23,9 @@ * a :class:`~google.cloud.firestore_v1.client.Client` owns a :class:`~google.cloud.firestore_v1.document.DocumentReference` """ +from typing import Any, cast, Generator, Iterable, List, Optional, Union from google.api_core import gapic_v1 -from google.api_core import retry as retries from google.cloud.firestore_v1.base_client import ( BaseClient, @@ -35,23 +35,20 @@ _path_helper, ) -from google.cloud.firestore_v1.query import CollectionGroup +from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.batch import WriteBatch +from google.cloud.firestore_v1.bulk_writer import BulkWriter from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.field_path import FieldPath -from google.cloud.firestore_v1.transaction import Transaction +from google.cloud.firestore_v1.query import CollectionGroup +from google.cloud.firestore_v1.query import Query from google.cloud.firestore_v1.services.firestore import client as firestore_client +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry from google.cloud.firestore_v1.services.firestore.transports import ( grpc as firestore_grpc_transport, ) -from typing import Any, Generator, 
Iterable, List, Optional, Union, TYPE_CHECKING - -# Types needed only for Type Hints -from google.cloud.firestore_v1.base_document import DocumentSnapshot - -if TYPE_CHECKING: - from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER +from google.cloud.firestore_v1.transaction import Transaction class Client(BaseClient): @@ -82,6 +79,8 @@ class Client(BaseClient): should be set through client_options. """ + _API_CLIENT_CLASS = firestore_client.FirestoreClient + def __init__( self, project=None, @@ -111,16 +110,6 @@ def _firestore_api(self): firestore_client, ) - @property - def _target(self): - """Return the target (where the API is). - Eg. "firestore.googleapis.com" - - Returns: - str: The location of the API. - """ - return self._target_helper(firestore_client.FirestoreClient) - def collection(self, *collection_path: str) -> CollectionReference: """Get a reference to a collection. @@ -209,12 +198,33 @@ def document(self, *document_path: str) -> DocumentReference: *self._document_path_helper(*document_path), client=self ) + def _get_collection_reference(self, collection_id: str) -> CollectionReference: + """Checks validity of collection_id and then uses subclasses collection implementation. + + Args: + collection_id (str) Identifies the collections to query over. + + Every collection or subcollection with this ID as the last segment of its + path will be included. Cannot contain a slash. + + Returns: + The created collection. + """ + if "/" in collection_id: + raise ValueError( + "Invalid collection_id " + + collection_id + + ". Collection IDs must not contain '/'." + ) + + return self.collection(collection_id) + def get_all( self, references: list, field_paths: Iterable[str] = None, transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Generator[DocumentSnapshot, Any, None]: """Retrieve a batch of documents. 
@@ -267,7 +277,7 @@ def get_all( yield _parse_batch_get(get_doc_response, reference_map, self) def collections( - self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, + self, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Generator[Any, Any, None]: """List top-level collections of the client's database. @@ -294,8 +304,8 @@ def recursive_delete( self, reference: Union[CollectionReference, DocumentReference], *, - bulk_writer: Optional["BulkWriter"] = None, - chunk_size: Optional[int] = 5000, + bulk_writer: Optional[BulkWriter] = None, + chunk_size: int = 5000, ) -> int: """Deletes documents and their subcollections, regardless of collection name. @@ -319,17 +329,17 @@ def recursive_delete( """ if bulk_writer is None: - bulk_writer or self.bulk_writer() + bulk_writer = self.bulk_writer() return self._recursive_delete(reference, bulk_writer, chunk_size=chunk_size,) def _recursive_delete( self, reference: Union[CollectionReference, DocumentReference], - bulk_writer: "BulkWriter", + bulk_writer: BulkWriter, *, - chunk_size: Optional[int] = 5000, - depth: Optional[int] = 0, + chunk_size: int = 5000, + depth: int = 0, ) -> int: """Recursion helper for `recursive_delete.""" @@ -337,11 +347,8 @@ def _recursive_delete( if isinstance(reference, CollectionReference): chunk: List[DocumentSnapshot] - for chunk in ( - reference.recursive() - .select([FieldPath.document_id()]) - ._chunkify(chunk_size) - ): + query = reference.recursive().select([FieldPath.document_id()]) + for chunk in cast(Query, query)._chunkify(chunk_size): doc_snap: DocumentSnapshot for doc_snap in chunk: num_deleted += 1 diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 3488275dd7..858e58e312 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -14,20 +14,18 @@ """Classes for representing collections for the Google Cloud Firestore API.""" +from typing import 
Any, Callable, cast, Generator, Tuple + from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.cloud.firestore_v1.base_collection import ( - BaseCollectionReference, - _item_to_document_ref, -) -from google.cloud.firestore_v1 import query as query_mod -from google.cloud.firestore_v1.watch import Watch from google.cloud.firestore_v1 import document -from typing import Any, Callable, Generator, Tuple - -# Types needed only for Type Hints +from google.cloud.firestore_v1 import query as query_mod +from google.cloud.firestore_v1.base_collection import BaseCollectionReference +from google.cloud.firestore_v1.base_collection import _item_to_document_ref +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry +from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transaction import Transaction +from google.cloud.firestore_v1.watch import Watch class CollectionReference(BaseCollectionReference): @@ -71,7 +69,7 @@ def add( self, document_data: dict, document_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Tuple[Any, Any]: """Create a document in the Firestore database with the provided data. @@ -104,13 +102,14 @@ def add( document_ref, kwargs = self._prep_add( document_data, document_id, retry, timeout, ) - write_result = document_ref.create(document_data, **kwargs) + document_ref = cast(document.DocumentReference, document_ref) + write_result: write.Write = document_ref.create(document_data, **kwargs) return write_result.update_time, document_ref def list_documents( self, page_size: int = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Generator[Any, Any, None]: """List all subdocuments of the current collection. 
@@ -143,7 +142,7 @@ def _chunkify(self, chunk_size: int): def get( self, transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> list: """Read the documents in this collection. @@ -174,7 +173,7 @@ def get( def stream( self, transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Generator[document.DocumentSnapshot, Any, None]: """Read the documents in this collection. diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index acdab69e7a..4219d2f25e 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -15,21 +15,19 @@ """Classes for representing documents for the Google Cloud Firestore API.""" import datetime import logging +from typing import Any, Callable, Generator, Iterable from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore - -from google.cloud.firestore_v1.base_document import ( - BaseDocumentReference, - DocumentSnapshot, - _first_write_result, -) +from google.cloud._helpers import _datetime_to_pb_timestamp +from google.protobuf.timestamp_pb2 import Timestamp + from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_document import _first_write_result +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.watch import Watch -from google.protobuf.timestamp_pb2 import Timestamp -from typing import Any, Callable, Generator, Iterable logger = logging.getLogger(__name__) @@ -66,7 +64,7 @@ def 
__init__(self, *path, **kwargs) -> None: def create( self, document_data: dict, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> write.WriteResult: """Create a document in the Firestore database. @@ -103,7 +101,7 @@ def set( self, document_data: dict, merge: bool = False, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> write.WriteResult: """Create / replace / merge a document in the Firestore database. @@ -171,7 +169,7 @@ def update( self, field_updates: dict, option: _helpers.WriteOption = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> write.WriteResult: """Update an existing document in the Firestore database. @@ -328,7 +326,7 @@ def update( def delete( self, option: _helpers.WriteOption = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Timestamp: """Delete the current document in the Firestore database. @@ -361,7 +359,7 @@ def get( self, field_paths: Iterable[str] = None, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> DocumentSnapshot: """Retrieve a snapshot of the current document. @@ -428,7 +426,7 @@ def get( def collections( self, page_size: int = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Generator[Any, Any, None]: """List subcollections of the current document. 
diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index 24683fb843..cc4eb6a45c 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -56,7 +56,7 @@ def _tokenize_field_path(path: str): match = get_token(path) while match is not None: type_ = match.lastgroup - value = match.group(type_) + value = match.group(type_) # type: ignore yield value pos = match.end() match = get_token(path, pos) diff --git a/google/cloud/firestore_v1/order.py b/google/cloud/firestore_v1/order.py index 37052f9f57..257caf539b 100644 --- a/google/cloud/firestore_v1/order.py +++ b/google/cloud/firestore_v1/order.py @@ -13,10 +13,12 @@ # limitations under the License. from enum import Enum -from google.cloud.firestore_v1._helpers import decode_value import math from typing import Any +from google.cloud.firestore_v1._helpers import decode_value +from google.cloud.firestore_v1._helpers import GeoPoint + class TypeOrder(Enum): # NOTE: This order is defined by the backend and cannot be changed. @@ -123,6 +125,8 @@ def compare_timestamps(left, right) -> Any: def compare_geo_points(left, right) -> Any: left_value = decode_value(left, None) right_value = decode_value(right, None) + assert isinstance(left_value, GeoPoint) + assert isinstance(right_value, GeoPoint) cmp = (left_value.latitude > right_value.latitude) - ( left_value.latitude < right_value.latitude ) diff --git a/google/cloud/firestore_v1/py.typed b/google/cloud/firestore_v1/py.typed index 35a48b3acc..f5e325dd02 100644 --- a/google/cloud/firestore_v1/py.typed +++ b/google/cloud/firestore_v1/py.typed @@ -1,2 +1,2 @@ # Marker file for PEP 561. -# The google-cloud-firestore package uses inline types. +# The google-cloud-firestore_v1 package uses inline types. 
diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 25ac92cc2f..40baf89ce2 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -18,24 +18,24 @@ a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be a more common way to create a query than direct usage of the constructor. """ -from google.cloud import firestore_v1 -from google.cloud.firestore_v1.base_document import DocumentSnapshot +from typing import Any, Callable, cast, Generator, List, Optional, Type + from google.api_core import exceptions from google.api_core import gapic_v1 -from google.api_core import retry as retries +from google.cloud import firestore_v1 # to break cyccles in type decls +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1 import document +from google.cloud.firestore_v1.base_query import BaseCollectionGroup +from google.cloud.firestore_v1.base_query import BaseQuery +from google.cloud.firestore_v1.base_query import QueryPartition from google.cloud.firestore_v1.base_query import ( - BaseCollectionGroup, - BaseQuery, - QueryPartition, - _query_response_to_snapshot, _collection_group_query_response_to_snapshot, - _enum_from_direction, ) - -from google.cloud.firestore_v1 import document +from google.cloud.firestore_v1.base_query import _enum_from_direction +from google.cloud.firestore_v1.base_query import _query_response_to_snapshot +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry from google.cloud.firestore_v1.watch import Watch -from typing import Any, Callable, Generator, List, Optional, Type class Query(BaseQuery): @@ -125,7 +125,7 @@ def __init__( def get( self, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> List[DocumentSnapshot]: """Read the documents in the collection that match this query. 
@@ -162,11 +162,13 @@ def get( ) self._limit_to_last = False - result = self.stream(transaction=transaction, retry=retry, timeout=timeout) + result = list( + self.stream(transaction=transaction, retry=retry, timeout=timeout) + ) if is_limited_to_last: - result = reversed(list(result)) + result.reverse() - return list(result) + return result def _chunkify( self, chunk_size: int @@ -174,7 +176,7 @@ def _chunkify( max_to_return: Optional[int] = self._limit num_returned: int = 0 - original: Query = self._copy() + original: Query = cast(Query, self._copy()) last_document: Optional[DocumentSnapshot] = None while True: @@ -189,7 +191,7 @@ def _chunkify( if last_document: _q = _q.start_after(last_document) - snapshots = _q.get() + snapshots = cast(Query, _q).get() if snapshots: last_document = snapshots[-1] @@ -233,8 +235,8 @@ def _retry_query_after_exception(self, exc, retry, transaction): def stream( self, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, ) -> Generator[document.DocumentSnapshot, Any, None]: """Read the documents in the collection that match this query. @@ -277,7 +279,7 @@ def stream( response = next(response_iterator, None) except exceptions.GoogleAPICallError as exc: if self._retry_query_after_exception(exc, retry, transaction): - new_query = self.start_after(last_snapshot) + new_query = cast(Query, self.start_after(last_snapshot)) response_iterator, _ = new_query._get_stream_iterator( transaction, retry, timeout, ) @@ -387,8 +389,8 @@ def _get_query_class(): def get_partitions( self, partition_count, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, ) -> Generator[QueryPartition, None, None]: """Partition a query for parallelization. 
diff --git a/google/cloud/firestore_v1/rate_limiter.py b/google/cloud/firestore_v1/rate_limiter.py index ee920edae0..e6aeb54569 100644 --- a/google/cloud/firestore_v1/rate_limiter.py +++ b/google/cloud/firestore_v1/rate_limiter.py @@ -13,7 +13,7 @@ # limitations under the License. import datetime -from typing import NoReturn, Optional +from typing import Optional def utcnow(): @@ -99,7 +99,7 @@ def _start_clock(self): self._start = self._start or utcnow() self._last_refill = self._last_refill or utcnow() - def take_tokens(self, num: Optional[int] = 1, allow_less: bool = False) -> int: + def take_tokens(self, num: int = 1, allow_less: bool = False) -> int: """Returns the number of available tokens, up to the amount requested.""" self._start_clock() self._check_phase() @@ -144,12 +144,12 @@ def _check_phase(self): if operations_last_phase and self._phase > previous_phase: self._increase_maximum_tokens() - def _increase_maximum_tokens(self) -> NoReturn: + def _increase_maximum_tokens(self): self._maximum_tokens = round(self._maximum_tokens * 1.5) if self._global_max_tokens is not None: self._maximum_tokens = min(self._maximum_tokens, self._global_max_tokens) - def _refill(self) -> NoReturn: + def _refill(self): """Replenishes any tokens that should have regenerated since the last operation.""" now: datetime.datetime = utcnow() diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index cfcb968c8f..b9ee75630e 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -14,37 +14,32 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" - import random import time +from typing import Any, Callable, Generator, Optional, Sequence +from google.api_core import exceptions from google.api_core import gapic_v1 -from google.api_core import retry as retries - -from google.cloud.firestore_v1.base_transaction import ( - _BaseTransactional, - BaseTransaction, - 
MAX_ATTEMPTS, - _CANT_BEGIN, - _CANT_ROLLBACK, - _CANT_COMMIT, - _WRITE_READ_ONLY, - _INITIAL_SLEEP, - _MAX_SLEEP, - _MULTIPLIER, - _EXCEED_ATTEMPTS_TEMPLATE, -) -from google.api_core import exceptions +from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import batch +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_transaction import BaseTransaction +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS +from google.cloud.firestore_v1.base_transaction import _BaseTransactional +from google.cloud.firestore_v1.base_transaction import _CANT_BEGIN +from google.cloud.firestore_v1.base_transaction import _CANT_COMMIT +from google.cloud.firestore_v1.base_transaction import _CANT_ROLLBACK +from google.cloud.firestore_v1.base_transaction import _EXCEED_ATTEMPTS_TEMPLATE +from google.cloud.firestore_v1.base_transaction import _INITIAL_SLEEP +from google.cloud.firestore_v1.base_transaction import _MAX_SLEEP +from google.cloud.firestore_v1.base_transaction import _MULTIPLIER +from google.cloud.firestore_v1.base_transaction import _WRITE_READ_ONLY from google.cloud.firestore_v1.document import DocumentReference -from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.query import Query - -# Types needed only for Type Hints -from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.services.firestore.client import OptionalRetry from google.cloud.firestore_v1.types import CommitResponse -from typing import Any, Callable, Generator, Optional +from google.cloud.firestore_v1.types import write class Transaction(batch.WriteBatch, BaseTransaction): @@ -65,7 +60,7 @@ def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: super(Transaction, self).__init__(client) BaseTransaction.__init__(self, max_attempts, read_only) - def _add_write_pbs(self, write_pbs: list) -> None: + def _add_write_pbs(self, 
write_pbs: Sequence[write.Write]) -> None: """Add `Write`` protobufs to this transaction. Args: @@ -139,6 +134,7 @@ def _commit(self) -> list: if not self.in_progress: raise ValueError(_CANT_COMMIT) + assert self._id is not None commit_response = _commit_with_retry(self._client, self._write_pbs, self._id) self._clean_up() @@ -147,7 +143,7 @@ def _commit(self) -> list: def get_all( self, references: list, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Generator[DocumentSnapshot, Any, None]: """Retrieves multiple documents from Firestore. @@ -170,7 +166,7 @@ def get_all( def get( self, ref_or_query, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, ) -> Generator[DocumentSnapshot, Any, None]: """Retrieve a document or a query result from the database. @@ -331,7 +327,7 @@ def transactional(to_wrap: Callable) -> _Transactional: def _commit_with_retry( - client, write_pbs: list, transaction_id: bytes + client, write_pbs: Sequence[write.Write], transaction_id: bytes ) -> CommitResponse: """Call ``Commit`` on the GAPIC client with retry / sleep. 
diff --git a/mypy.ini b/mypy.ini index 4505b48543..de30420d80 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,24 @@ [mypy] python_version = 3.6 namespace_packages = True + +[mypy-grpc.*] +ignore_missing_imports = True + +[mypy-google.auth.*] +ignore_missing_imports = True + +[mypy-google.oauth2.*] +ignore_missing_imports = True + +[mypy-google.rpc.*] +ignore_missing_imports = True + +[mypy-google.type.*] +ignore_missing_imports = True + +[mypy-proto.*] +ignore_missing_imports = True + +[mypy-pytest] +ignore_missing_imports = True diff --git a/noxfile.py b/noxfile.py index b388f2797b..5b91b8d7d7 100644 --- a/noxfile.py +++ b/noxfile.py @@ -85,9 +85,12 @@ def pytype(session): def mypy(session): """Verify type hints are mypy compatible.""" session.install("-e", ".") - session.install("mypy", "types-setuptools") - # TODO: also verify types on tests, all of google package - session.run("mypy", "-p", "google.cloud.firestore", "--no-incremental") + session.install( + "mypy", "types-setuptools", "types-protobuf", "types-dataclasses", "types-mock", + ) + # Note: getenerated tests (in 'tests/unit/gapic') are not yet + # mypy-safe + session.run("mypy", "google/", "tests/unit/v1/", "tests/system/") @nox.session(python=DEFAULT_PYTHON_VERSION) diff --git a/owlbot.py b/owlbot.py index 1b86d222e7..0cfa9c2bcd 100644 --- a/owlbot.py +++ b/owlbot.py @@ -305,9 +305,12 @@ def pytype(session): def mypy(session): """Verify type hints are mypy compatible.""" session.install("-e", ".") - session.install("mypy", "types-setuptools") - # TODO: also verify types on tests, all of google package - session.run("mypy", "-p", "google.cloud.firestore", "--no-incremental") + session.install( + "mypy", "types-setuptools", "types-protobuf", "types-dataclasses", "types-mock", + ) + # Note: getenerated tests (in 'tests/unit/gapic') are not yet + # mypy-safe + session.run("mypy", "google/", "tests/unit/v1/", "tests/system/") @nox.session(python=DEFAULT_PYTHON_VERSION) diff --git 
a/tests/system/test__helpers.py b/tests/system/test__helpers.py index f5541fd8a2..b5d304a5ac 100644 --- a/tests/system/test__helpers.py +++ b/tests/system/test__helpers.py @@ -1,8 +1,9 @@ import os import re + from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST -from test_utils.system import unique_resource_id -from test_utils.system import EmulatorCreds +from test_utils.system import unique_resource_id # type: ignore +from test_utils.system import EmulatorCreds # type: ignore FIRESTORE_CREDS = os.environ.get("FIRESTORE_APPLICATION_CREDENTIALS") FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT") diff --git a/tests/system/test_system.py b/tests/system/test_system.py index b0bf4d5406..6aa58c8fad 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -16,9 +16,8 @@ import itertools import math import operator - -from google.oauth2 import service_account -import pytest +from time import sleep +from typing import Callable, cast, Dict, List, Optional from google.api_core.exceptions import AlreadyExists from google.api_core.exceptions import FailedPrecondition @@ -26,20 +25,17 @@ from google.api_core.exceptions import NotFound from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud._helpers import UTC -from google.cloud import firestore_v1 as firestore +from google.oauth2 import service_account # type: ignore +import pytest # type: ignore -from time import sleep -from typing import Callable, Dict, List, Optional - -from tests.system.test__helpers import ( - FIRESTORE_CREDS, - FIRESTORE_PROJECT, - RANDOM_ID_REGEX, - MISSING_DOCUMENT, - UNIQUE_RESOURCE_ID, - EMULATOR_CREDS, - FIRESTORE_EMULATOR, -) +from google.cloud import firestore_v1 as firestore +from tests.system.test__helpers import FIRESTORE_CREDS +from tests.system.test__helpers import FIRESTORE_PROJECT +from tests.system.test__helpers import RANDOM_ID_REGEX +from tests.system.test__helpers import MISSING_DOCUMENT +from tests.system.test__helpers 
import UNIQUE_RESOURCE_ID +from tests.system.test__helpers import EMULATOR_CREDS +from tests.system.test__helpers import FIRESTORE_EMULATOR def _get_credentials_and_project(): @@ -1269,7 +1265,7 @@ def _persist_documents( for block in documents: col_ref = client.collection(collection_name) document_id: str = block["data"]["name"] - doc_ref = col_ref.document(document_id) + doc_ref = cast(firestore.DocumentReference, col_ref.document(document_id)) doc_ref.set(block["data"]) if cleanup is not None: cleanup(doc_ref.delete) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index b4f8dddbf8..3fe1ae57c3 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -16,11 +16,11 @@ import datetime import itertools import math -import pytest import operator from typing import Callable, Dict, List, Optional -from google.oauth2 import service_account +from google.oauth2 import service_account # type: ignore +import pytest # type: ignore from google.api_core.exceptions import AlreadyExists from google.api_core.exceptions import FailedPrecondition @@ -30,15 +30,13 @@ from google.cloud._helpers import UTC from google.cloud import firestore_v1 as firestore -from tests.system.test__helpers import ( - FIRESTORE_CREDS, - FIRESTORE_PROJECT, - RANDOM_ID_REGEX, - MISSING_DOCUMENT, - UNIQUE_RESOURCE_ID, - EMULATOR_CREDS, - FIRESTORE_EMULATOR, -) +from tests.system.test__helpers import FIRESTORE_CREDS +from tests.system.test__helpers import FIRESTORE_PROJECT +from tests.system.test__helpers import RANDOM_ID_REGEX +from tests.system.test__helpers import MISSING_DOCUMENT +from tests.system.test__helpers import UNIQUE_RESOURCE_ID +from tests.system.test__helpers import EMULATOR_CREDS +from tests.system.test__helpers import FIRESTORE_EMULATOR _test_event_loop = asyncio.new_event_loop() pytestmark = pytest.mark.asyncio diff --git a/tests/unit/v1/_test_helpers.py b/tests/unit/v1/_test_helpers.py index 92d20b7ece..b23676473d 100644 
--- a/tests/unit/v1/_test_helpers.py +++ b/tests/unit/v1/_test_helpers.py @@ -89,7 +89,7 @@ class FakeThreadPoolExecutor: def __init__(self, *args, **kwargs): self._shutdown = False - def submit(self, callable) -> typing.NoReturn: + def submit(self, callable): if self._shutdown: raise RuntimeError( "cannot schedule new futures after shutdown" diff --git a/tests/unit/v1/test_base_client.py b/tests/unit/v1/test_base_client.py index 42f9b25ca4..e7cb04c959 100644 --- a/tests/unit/v1/test_base_client.py +++ b/tests/unit/v1/test_base_client.py @@ -91,9 +91,10 @@ def test_baseclient__firestore_api_helper_w_already(): def test_baseclient__firestore_api_helper_wo_emulator(): + endpoint = "https://example.com/api" client = _make_default_base_client() client_options = client._client_options = mock.Mock() - target = client._target = mock.Mock() + client_options.api_endpoint = endpoint assert client._firestore_api_internal is None transport_class = mock.Mock() @@ -106,10 +107,10 @@ def test_baseclient__firestore_api_helper_wo_emulator(): assert client._firestore_api_internal is api channel_options = {"grpc.keepalive_time_ms": 30000} transport_class.create_channel.assert_called_once_with( - target, credentials=client._credentials, options=channel_options.items() + endpoint, credentials=client._credentials, options=channel_options.items() ) transport_class.assert_called_once_with( - host=target, channel=transport_class.create_channel.return_value, + host=endpoint, channel=transport_class.create_channel.return_value, ) client_class.assert_called_once_with( transport=transport_class.return_value, client_options=client_options @@ -123,7 +124,6 @@ def test_baseclient__firestore_api_helper_w_emulator(): client = _make_default_base_client() client_options = client._client_options = mock.Mock() - target = client._target = mock.Mock() emulator_channel = client._emulator_channel = mock.Mock() assert client._firestore_api_internal is None @@ -138,7 +138,7 @@ def 
test_baseclient__firestore_api_helper_w_emulator(): emulator_channel.assert_called_once_with(transport_class) transport_class.assert_called_once_with( - host=target, channel=emulator_channel.return_value, + host=emulator_host, channel=emulator_channel.return_value, ) client_class.assert_called_once_with( transport=transport_class.return_value, client_options=client_options diff --git a/tests/unit/v1/test_bulk_writer.py b/tests/unit/v1/test_bulk_writer.py index dc185d387e..aee790522e 100644 --- a/tests/unit/v1/test_bulk_writer.py +++ b/tests/unit/v1/test_bulk_writer.py @@ -13,7 +13,7 @@ # limitations under the License. import datetime -from typing import List, NoReturn, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type import aiounittest # type: ignore import mock @@ -21,7 +21,6 @@ from google.cloud.firestore_v1 import async_client from google.cloud.firestore_v1 import client -from google.cloud.firestore_v1 import base_client def _make_no_send_bulk_writer(*args, **kwargs): @@ -42,7 +41,7 @@ class NoSendBulkWriter(BulkWriter): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._responses: List[ - Tuple[BulkWriteBatch, BatchWriteResponse, BulkWriterOperation] + Tuple[BulkWriteBatch, BatchWriteResponse, List[BulkWriterOperation]] ] = [] self._fail_indices: List[int] = [] @@ -68,7 +67,7 @@ def _process_response( batch: BulkWriteBatch, response: BatchWriteResponse, operations: List[BulkWriterOperation], - ) -> NoReturn: + ): super()._process_response(batch, response, operations) self._responses.append((batch, response, operations)) @@ -145,28 +144,18 @@ def _doc_iter(self, client, num: int, ids: Optional[List[str]] = None): yield _get_document_reference(client, id=id), {"id": _} def _verify_bw_activity(self, bw, counts: List[Tuple[int, int]]): - """ - Args: - bw: (BulkWriter) - The BulkWriter instance to inspect. 
- counts: (tuple) A sequence of integer pairs, with 0-index integers - representing the size of sent batches, and 1-index integers - representing the number of times batches of that size should - have been sent. - """ from google.cloud.firestore_v1.types.firestore import BatchWriteResponse total_batches = sum([el[1] for el in counts]) assert len(bw._responses) == total_batches - docs_count = {} + expected_counts = dict(counts) + docs_count: Dict[int, int] = {} resp: BatchWriteResponse for _, resp, ops in bw._responses: docs_count.setdefault(len(resp.write_results), 0) docs_count[len(resp.write_results)] += 1 - assert len(docs_count) == len(counts) - for size, num_sent in counts: - assert docs_count[size] == num_sent + assert docs_count == expected_counts # Assert flush leaves no operation behind assert len(bw._operations) == 0 @@ -700,8 +689,6 @@ def test_scheduling_operation_retry_scheduling(): def _get_document_reference( - client: base_client.BaseClient, - collection_name: Optional[str] = "col", - id: Optional[str] = None, + client, collection_name: str = "col", id: Optional[str] = None, ) -> Type: return client.collection(collection_name).document(id) diff --git a/tests/unit/v1/test_rate_limiter.py b/tests/unit/v1/test_rate_limiter.py index e5068b3590..85b2147dc6 100644 --- a/tests/unit/v1/test_rate_limiter.py +++ b/tests/unit/v1/test_rate_limiter.py @@ -22,7 +22,7 @@ fake_now = datetime.datetime.utcnow() -def now_plus_n(seconds: int = 0, microseconds: int = 0) -> datetime.timedelta: +def now_plus_n(seconds: int = 0, microseconds: int = 0) -> datetime.datetime: return fake_now + datetime.timedelta(seconds=seconds, microseconds=microseconds,)