
feat: declare typing for 'firestore_v1' and 'firestore_bundle' packages #507

Closed · wants to merge 9 commits
2 changes: 1 addition & 1 deletion google/__init__.py
@@ -19,4 +19,4 @@
 except ImportError:
     import pkgutil
 
-    __path__ = pkgutil.extend_path(__path__, __name__)
+    __path__ = pkgutil.extend_path(__path__, __name__)  # type: ignore
2 changes: 1 addition & 1 deletion google/cloud/__init__.py
@@ -19,4 +19,4 @@
 except ImportError:
     import pkgutil
 
-    __path__ = pkgutil.extend_path(__path__, __name__)
+    __path__ = pkgutil.extend_path(__path__, __name__)  # type: ignore
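Note: both `google/__init__.py` and `google/cloud/__init__.py` are pkgutil-style namespace-package shims, and mypy flags the module-level `__path__` reassignment, hence the blanket ignore. For context, a minimal sketch of the full shim pattern such an `__init__.py` follows (the try/except layout is the standard one; the truncated hunks above only show its tail):

# Sketch of a pkgutil-style namespace shim (standard pattern, for context).
try:
    import pkg_resources

    pkg_resources.declare_namespace(__name__)
except ImportError:
    import pkgutil

    # extend_path scans sys.path for other portions of this namespace so that
    # e.g. google.cloud and google.protobuf can ship in separate distributions.
    # mypy flags module-level __path__ manipulation, hence the ignore.
    __path__ = pkgutil.extend_path(__path__, __name__)  # type: ignore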
33 changes: 15 additions & 18 deletions google/cloud/firestore_bundle/bundle.py
@@ -16,30 +16,25 @@
 
 import datetime
 import json
-from typing import Dict, List, Optional, Union
 
-from google.cloud._helpers import _datetime_to_pb_timestamp
-from google.cloud._helpers import UTC
-from google.protobuf.timestamp_pb2 import Timestamp
-from google.protobuf import json_format
-
-from google.cloud.firestore_bundle.types.bundle import (
-    BundledDocumentMetadata,
-    BundledQuery,
-    BundleElement,
-    BundleMetadata,
-    NamedQuery,
-)
+from google.cloud._helpers import _datetime_to_pb_timestamp, UTC  # type: ignore
+from google.cloud.firestore_bundle._helpers import limit_type_of_query
+from google.cloud.firestore_bundle.types.bundle import BundleElement
+from google.cloud.firestore_bundle.types.bundle import BundleMetadata
+from google.cloud.firestore_bundle.types.bundle import BundledDocumentMetadata
+from google.cloud.firestore_bundle.types.bundle import BundledQuery
+from google.cloud.firestore_bundle.types.bundle import NamedQuery
+from google.cloud.firestore_v1 import _helpers
 from google.cloud.firestore_v1.async_query import AsyncQuery
 from google.cloud.firestore_v1.base_client import BaseClient
 from google.cloud.firestore_v1.base_document import DocumentSnapshot
 from google.cloud.firestore_v1.base_query import BaseQuery
 from google.cloud.firestore_v1.document import DocumentReference
-from google.cloud.firestore_v1 import _helpers
+from google.protobuf.timestamp_pb2 import Timestamp  # type: ignore
+from google.protobuf import json_format  # type: ignore
+from typing import (
+    Dict,
+    List,
+    Optional,
+    Union,
+)
 
 
 class FirestoreBundle:
@@ -333,8 +328,10 @@ def build(self) -> str:
                 BundleElement(document_metadata=bundled_document.metadata)
             )
             document_count += 1
+            document_msg = bundled_document.snapshot._to_protobuf()
+            assert document_msg is not None
             buffer += self._compile_bundle_element(
-                BundleElement(document=bundled_document.snapshot._to_protobuf()._pb,)
+                BundleElement(document=document_msg._pb,)
             )
 
         metadata: BundleElement = BundleElement(
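Note: hoisting `_to_protobuf()` into a local and asserting it is not `None` lets mypy narrow what is evidently an Optional return type before `._pb` is touched. A standalone toy of that narrowing, with a hypothetical `Document` class standing in for the real message type:

from typing import Optional


class Document:
    _pb: str = "pb-payload"


def to_protobuf(have_data: bool) -> Optional[Document]:
    return Document() if have_data else None


msg = to_protobuf(True)
# Accessing msg._pb here would fail mypy:
#   Item "None" of "Optional[Document]" has no attribute "_pb"
assert msg is not None  # narrows Optional[Document] to Document
print(msg._pb)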
2 changes: 1 addition & 1 deletion google/cloud/firestore_bundle/py.typed
@@ -1,2 +1,2 @@
 # Marker file for PEP 561.
-# The google-cloud-bundle package uses inline types.
+# The google-cloud-firestore_bundle package uses inline types.
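Note: `py.typed` is the PEP 561 marker file that tells type checkers to trust a package's inline annotations; this hunk only corrects the package name in its comment. For the marker to take effect it must also ship in the built distribution. A sketch of the usual setuptools wiring, assumed here and not part of this diff:

# Sketch (assumed setup.py fragment): without package_data, the py.typed
# marker stays out of the wheel and type checkers ignore the annotations.
from setuptools import setup

setup(
    name="google-cloud-firestore",
    packages=["google.cloud.firestore_v1", "google.cloud.firestore_bundle"],
    package_data={
        "google.cloud.firestore_v1": ["py.typed"],
        "google.cloud.firestore_bundle": ["py.typed"],
    },
)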
3 changes: 1 addition & 2 deletions google/cloud/firestore_v1/__init__.py
@@ -18,13 +18,12 @@
 
 """Python idiomatic client for Google Cloud Firestore."""
 
-
 import pkg_resources
 
 try:
     __version__ = pkg_resources.get_distribution("google-cloud-firestore").version
 except pkg_resources.DistributionNotFound:
-    __version__ = None
+    __version__ = None  # type: ignore
 
 from google.cloud.firestore_v1 import types
 from google.cloud.firestore_v1._helpers import GeoPoint
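Note: typeshed types `pkg_resources.get_distribution(...).version` as `str`, so the `None` fallback trips mypy and gets an ignore. An alternative sketch that avoids the ignore by declaring the attribute Optional up front (not what this PR does):

from typing import Optional

import pkg_resources

# Declaring the wider type first makes both assignments below type-check.
__version__: Optional[str]
try:
    __version__ = pkg_resources.get_distribution("google-cloud-firestore").version
except pkg_resources.DistributionNotFound:
    __version__ = None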
31 changes: 18 additions & 13 deletions google/cloud/firestore_v1/_helpers.py
@@ -41,7 +41,6 @@
     Generator,
     Iterator,
     List,
-    NoReturn,
     Optional,
     Tuple,
     Union,
@@ -496,7 +495,7 @@ def __init__(self, document_data) -> None:
         self.increments = {}
         self.minimums = {}
         self.maximums = {}
-        self.set_fields = {}
+        self.set_fields: Dict[str, Any] = {}
         self.empty_document = False
 
         prefix_path = FieldPath()
@@ -559,13 +558,16 @@ def transform_paths(self):
             + list(self.minimums)
         )
 
-    def _get_update_mask(self, allow_empty_mask=False) -> None:
+    def _get_update_mask(
+        self, allow_empty_mask=False
+    ) -> Union[types.common.DocumentMask, None]:
         return None
 
     def get_update_pb(
         self, document_path, exists=None, allow_empty_mask=False
     ) -> types.write.Write:
 
+        current_document: Union[common.Precondition, None]
         if exists is not None:
             current_document = common.Precondition(exists=exists)
         else:
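Note: annotating `current_document` once before the `if`/`else` lets mypy give both branch assignments a single declared type, instead of inferring `Precondition` from the first branch and rejecting the `None` arm (which the truncated hunk hides but presumably assigns). A minimal standalone illustration with a toy `Precondition`:

from typing import Union


class Precondition:
    def __init__(self, exists: bool = False) -> None:
        self.exists = exists


def make_precondition(exists: Union[bool, None]) -> Union[Precondition, None]:
    # Pre-declare the variable's type so mypy accepts both assignments below.
    current_document: Union[Precondition, None]
    if exists is not None:
        current_document = Precondition(exists=exists)
    else:
        current_document = None
    return current_document


assert make_precondition(None) is None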
@@ -725,9 +727,9 @@ class DocumentExtractorForMerge(DocumentExtractor):
 
     def __init__(self, document_data) -> None:
         super(DocumentExtractorForMerge, self).__init__(document_data)
-        self.data_merge = []
-        self.transform_merge = []
-        self.merge = []
+        self.data_merge: List[str] = []
+        self.transform_merge: List[FieldPath] = []
+        self.merge: List[FieldPath] = []
 
     def _apply_merge_all(self) -> None:
         self.data_merge = sorted(self.field_paths + self.deleted_fields)
@@ -783,7 +785,7 @@ def _apply_merge_paths(self, merge) -> None:
             self.data_merge.append(field_path)
 
         # Clear out data for fields not merged.
-        merged_set_fields = {}
+        merged_set_fields: Dict[str, Any] = {}
         for field_path in self.data_merge:
             value = get_field_value(self.document_data, field_path)
             set_field_value(merged_set_fields, field_path, value)
@@ -834,7 +836,7 @@ def apply_merge(self, merge) -> None:
 
     def _get_update_mask(
         self, allow_empty_mask=False
-    ) -> Optional[types.common.DocumentMask]:
+    ) -> Union[types.common.DocumentMask, None]:
         # Mask uses dotted / quoted paths.
         mask_paths = [
             field_path.to_api_repr()
@@ -903,7 +905,9 @@ def _get_document_iterator(
     ) -> Generator[Tuple[Any, Any], Any, None]:
         return extract_fields(self.document_data, prefix_path, expand_dots=True)
 
-    def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask:
+    def _get_update_mask(
+        self, allow_empty_mask=False
+    ) -> Union[types.common.DocumentMask, None]:
         mask_paths = []
         for field_path in self.top_level_paths:
             if field_path not in self.transform_paths:
@@ -1017,7 +1021,7 @@ def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]:
 class WriteOption(object):
     """Option used to assert a condition on a write operation."""
 
-    def modify_write(self, write, no_create_msg=None) -> NoReturn:
+    def modify_write(self, write) -> None:
         """Modify a ``Write`` protobuf based on the state of this write option.
 
         This is a virtual method intended to be implemented by subclasses.
@@ -1026,8 +1030,6 @@ def modify_write(self, write, no_create_msg=None) -> NoReturn:
             write (google.cloud.firestore_v1.types.Write): A
                 ``Write`` protobuf instance to be modified with a precondition
                 determined by the state of this option.
-            no_create_msg (Optional[str]): A message to use to indicate that
-                a create operation is not allowed.
 
         Raises:
             NotImplementedError: Always, this method is virtual.
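Note: the base `modify_write` always raises, but annotating it `-> NoReturn` makes every subclass override that returns normally an incompatible override under mypy, since `NoReturn` is the bottom type; `-> None` is the compatible choice for a virtual method. The unused `no_create_msg` parameter is dropped from the signature and docstring together. A compact illustration with toy classes:

class WriteOption:
    def modify_write(self, write: object) -> None:
        # Annotating this -> NoReturn would be accurate for the base (it
        # always raises) but would make every None-returning override an
        # incompatible-override error under mypy.
        raise NotImplementedError


class LastUpdateOption(WriteOption):
    def modify_write(self, write: object) -> None:  # compatible override
        pass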
@@ -1233,6 +1235,8 @@ def deserialize_bundle(
         raise ValueError("Unexpected end to serialized FirestoreBundle")
 
     # Now, finally add the metadata element
+    assert bundle is not None
+
     bundle._add_bundle_element(
         metadata_bundle_element, client=client, type="metadata",  # type: ignore
     )
@@ -1296,7 +1300,8 @@ def _get_documents_from_bundle(
 
 def _get_document_from_bundle(
     bundle, *, document_id: str,
-) -> Optional["google.cloud.firestore.DocumentSnapshot"]:  # type: ignore
+) -> Union["google.cloud.firestore.DocumentSnapshot", None]:  # type: ignore
+    bundled_doc = bundle.documents.get(document_id)
     if bundled_doc:
        return bundled_doc.snapshot
    return None
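Note: the recurring `Optional[X]` to `Union[X, None]` swaps in this file are purely stylistic from the type checker's point of view, since `Optional` is defined in terms of `Union`:

from typing import Optional, Union

# Optional[X] is literally Union[X, None]; the two spellings are equivalent,
# and the equality even holds at runtime.
assert Optional[int] == Union[int, None]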
4 changes: 2 additions & 2 deletions google/cloud/firestore_v1/async_batch.py
@@ -16,9 +16,9 @@
 
 
 from google.api_core import gapic_v1
-from google.api_core import retry as retries
 
 from google.cloud.firestore_v1.base_batch import BaseWriteBatch
+from google.cloud.firestore_v1.services.firestore.client import OptionalRetry
 
 
 class AsyncWriteBatch(BaseWriteBatch):
@@ -37,7 +37,7 @@ def __init__(self, client) -> None:
         super(AsyncWriteBatch, self).__init__(client=client)
 
     async def commit(
-        self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None,
+        self, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None,
     ) -> list:
         """Commit the changes accumulated in this batch.
 
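Note: `gapic_v1.method.DEFAULT` is a sentinel object rather than a `retries.Retry` instance, so the old `retry: retries.Retry = gapic_v1.method.DEFAULT` default was itself a typing violation; `OptionalRetry`, imported from the generated service client, is a Union alias that also admits the sentinel. A toy reconstruction of the idea with hypothetical stand-in names, not the real google-api-core definitions:

from typing import Union


class Retry:
    """Stand-in for google.api_core.retry.Retry."""


class _MethodDefault:
    """Stand-in for the DEFAULT sentinel's private type."""


DEFAULT = _MethodDefault()
OptionalRetry = Union[Retry, _MethodDefault]


def commit(retry: OptionalRetry = DEFAULT) -> str:
    # A plain `retry: Retry` annotation would reject the DEFAULT sentinel;
    # the Union alias admits it, and the sentinel is swapped at call time.
    if isinstance(retry, _MethodDefault):
        retry = Retry()
    return type(retry).__name__


print(commit())  # Retry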
87 changes: 51 additions & 36 deletions google/cloud/firestore_v1/async_client.py
@@ -24,36 +24,39 @@
 :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`
 """
 
-from google.api_core import gapic_v1
-from google.api_core import retry as retries
 
-from google.cloud.firestore_v1.base_client import (
-    BaseClient,
-    DEFAULT_DATABASE,
-    _CLIENT_INFO,
-    _parse_batch_get,  # type: ignore
-    _path_helper,
+from typing import (
+    Any,
+    cast,
+    AsyncGenerator,
+    Iterable,
+    List,
+    Optional,
+    Union,
 )
 
-from google.cloud.firestore_v1.async_query import AsyncCollectionGroup
+from google.api_core import gapic_v1
 
 from google.cloud.firestore_v1.async_batch import AsyncWriteBatch
 from google.cloud.firestore_v1.async_collection import AsyncCollectionReference
-from google.cloud.firestore_v1.async_document import (
-    AsyncDocumentReference,
-    DocumentSnapshot,
-)
+from google.cloud.firestore_v1.async_document import AsyncDocumentReference
+from google.cloud.firestore_v1.async_document import DocumentSnapshot
+from google.cloud.firestore_v1.async_query import AsyncCollectionGroup
+from google.cloud.firestore_v1.async_query import AsyncQuery
 from google.cloud.firestore_v1.async_transaction import AsyncTransaction
+from google.cloud.firestore_v1.base_client import BaseClient
+from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE
+from google.cloud.firestore_v1.base_client import _CLIENT_INFO
+from google.cloud.firestore_v1.base_client import _parse_batch_get  # type: ignore
+from google.cloud.firestore_v1.base_client import _path_helper
+from google.cloud.firestore_v1.bulk_writer import BulkWriter
 from google.cloud.firestore_v1.field_path import FieldPath
 from google.cloud.firestore_v1.services.firestore import (
     async_client as firestore_client,
 )
+from google.cloud.firestore_v1.services.firestore.async_client import OptionalRetry
 from google.cloud.firestore_v1.services.firestore.transports import (
     grpc_asyncio as firestore_grpc_transport,
 )
-from typing import Any, AsyncGenerator, Iterable, List, Optional, Union, TYPE_CHECKING
 
-if TYPE_CHECKING:
-    from google.cloud.firestore_v1.bulk_writer import BulkWriter  # pragma: NO COVER
 
 
 class AsyncClient(BaseClient):
@@ -84,6 +87,8 @@ class AsyncClient(BaseClient):
     should be set through client_options.
     """
 
+    _API_CLIENT_CLASS = firestore_client.FirestoreAsyncClient
+
     def __init__(
         self,
         project=None,
@@ -126,16 +131,6 @@ def _firestore_api(self):
             firestore_client,
         )
 
-    @property
-    def _target(self):
-        """Return the target (where the API is).
-        Eg. "firestore.googleapis.com"
-
-        Returns:
-            str: The location of the API.
-        """
-        return self._target_helper(firestore_client.FirestoreAsyncClient)
-
     def collection(self, *collection_path: str) -> AsyncCollectionReference:
         """Get a reference to a collection.
@@ -189,6 +184,27 @@ def collection_group(self, collection_id: str) -> AsyncCollectionGroup:
         """
         return AsyncCollectionGroup(self._get_collection_reference(collection_id))
 
+    def _get_collection_reference(self, collection_id: str) -> AsyncCollectionReference:
+        """Checks validity of collection_id and then uses subclasses collection implementation.
+
+        Args:
+            collection_id (str) Identifies the collections to query over.
+
+                Every collection or subcollection with this ID as the last segment of its
+                path will be included. Cannot contain a slash.
+
+        Returns:
+            The created collection.
+        """
+        if "/" in collection_id:
+            raise ValueError(
+                "Invalid collection_id "
+                + collection_id
+                + ". Collection IDs must not contain '/'."
+            )
+
+        return self.collection(collection_id)
+
     def document(self, *document_path: str) -> AsyncDocumentReference:
         """Get a reference to a document in a collection.
@@ -229,7 +245,7 @@ async def get_all(
         references: List[AsyncDocumentReference],
         field_paths: Iterable[str] = None,
         transaction=None,
-        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        retry: OptionalRetry = gapic_v1.method.DEFAULT,
         timeout: float = None,
     ) -> AsyncGenerator[DocumentSnapshot, Any]:
         """Retrieve a batch of documents.
@@ -282,7 +298,7 @@ async def get_all(
         yield _parse_batch_get(get_doc_response, reference_map, self)
 
     async def collections(
-        self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None,
+        self, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None,
     ) -> AsyncGenerator[AsyncCollectionReference, Any]:
         """List top-level collections of the client's database.
@@ -309,7 +325,7 @@ async def recursive_delete(
         reference: Union[AsyncCollectionReference, AsyncDocumentReference],
         *,
         bulk_writer: Optional["BulkWriter"] = None,
-        chunk_size: Optional[int] = 5000,
+        chunk_size: int = 5000,
     ):
         """Deletes documents and their subcollections, regardless of collection
         name.
@@ -343,18 +359,17 @@ async def _recursive_delete(
         reference: Union[AsyncCollectionReference, AsyncDocumentReference],
         bulk_writer: "BulkWriter",
         *,
-        chunk_size: Optional[int] = 5000,
-        depth: Optional[int] = 0,
+        chunk_size: int = 5000,
+        depth: int = 0,
     ) -> int:
         """Recursion helper for `recursive_delete."""
 
         num_deleted: int = 0
 
         if isinstance(reference, AsyncCollectionReference):
             chunk: List[DocumentSnapshot]
-            async for chunk in reference.recursive().select(
-                [FieldPath.document_id()]
-            )._chunkify(chunk_size):
+            query = reference.recursive().select([FieldPath.document_id()])
+            async for chunk in cast(AsyncQuery, query)._chunkify(chunk_size):
                 doc_snap: DocumentSnapshot
                 for doc_snap in chunk:
                     num_deleted += 1
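Note: `reference.recursive().select(...)` is declared as returning the shared `BaseQuery` ancestor, which does not expose the async `_chunkify` helper, so the result is `cast` back to `AsyncQuery` before iteration; `typing.cast` is a checker-only hint with no runtime effect. A standalone illustration with toy classes:

from typing import cast


class BaseQuery:
    def select(self) -> "BaseQuery":
        return self


class AsyncQuery(BaseQuery):
    def _chunkify(self, n: int) -> str:
        return f"chunks of {n}"


q: BaseQuery = AsyncQuery().select()
# q._chunkify(10) would fail mypy: "BaseQuery" has no attribute "_chunkify".
print(cast(AsyncQuery, q)._chunkify(10))  # cast is a no-op at runtime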