diff --git a/.github/workflows/github-actions.yml b/.github/workflows/github-actions.yml index df8d6b25b..3ebd2ddf8 100644 --- a/.github/workflows/github-actions.yml +++ b/.github/workflows/github-actions.yml @@ -49,6 +49,18 @@ jobs: - run: bash .github/workflows/install_ci_python_dep.sh - run: pre-commit run -a + type-check: + # Can be moved to pre-commit, separate step for now. + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.9' + check-latest: true + - run: bash .github/workflows/install_ci_python_typing_deps.sh + - run: mypy mongoengine tests + test: # Test suite run against recent python versions # and against a few combination of MongoDB and pymongo diff --git a/.github/workflows/install_ci_python_typing_deps.sh b/.github/workflows/install_ci_python_typing_deps.sh new file mode 100644 index 000000000..b25a523a1 --- /dev/null +++ b/.github/workflows/install_ci_python_typing_deps.sh @@ -0,0 +1,4 @@ +#!/bin/bash +pip install --upgrade pip +pip install mypy==1.13.0 typing-extensions mongomock types-Pygments types-cffi types-colorama types-pyOpenSSL types-python-dateutil types-requests types-setuptools +pip install -e '.[test]' diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 3b2a884b6..5b16270ce 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -20,12 +20,12 @@ from mongoengine.signals import * # noqa: F401 __all__ = ( - list(document.__all__) - + list(fields.__all__) - + list(connection.__all__) - + list(queryset.__all__) - + list(signals.__all__) - + list(errors.__all__) + document.__all__ + + fields.__all__ + + connection.__all__ + + queryset.__all__ + + signals.__all__ + + errors.__all__ ) diff --git a/mongoengine/_typing.py b/mongoengine/_typing.py new file mode 100644 index 000000000..0e379e017 --- /dev/null +++ b/mongoengine/_typing.py @@ -0,0 +1,6 @@ +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from mongoengine.queryset.queryset import QuerySet + +QS = TypeVar("QS", bound="QuerySet") diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py index fe631a40e..8b97044e5 100644 --- a/mongoengine/base/common.py +++ b/mongoengine/base/common.py @@ -1,9 +1,14 @@ +from __future__ import annotations + import warnings +from typing import TYPE_CHECKING from mongoengine.errors import NotRegistered -__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry") +if TYPE_CHECKING: + from mongoengine.document import Document +__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry") UPDATE_OPERATORS = { "set", @@ -24,7 +29,7 @@ } -_document_registry = {} +_document_registry: dict[str, type[Document]] = {} class _DocumentRegistry: @@ -33,7 +38,7 @@ class _DocumentRegistry: """ @staticmethod - def get(name): + def get(name: str) -> type[Document]: doc = _document_registry.get(name, None) if not doc: # Possible old style name @@ -58,7 +63,7 @@ def get(name): return doc @staticmethod - def register(DocCls): + def register(DocCls: type[Document]) -> None: ExistingDocCls = _document_registry.get(DocCls._class_name) if ( ExistingDocCls is not None @@ -76,7 +81,7 @@ def register(DocCls): _document_registry[DocCls._class_name] = DocCls @staticmethod - def unregister(doc_cls_name): + def unregister(doc_cls_name: str): _document_registry.pop(doc_cls_name) diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index dcb8438c7..4c3ba299a 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -1,10 +1,18 @@ +from __future__ import annotations + import weakref +from typing import TYPE_CHECKING, Any, Generic, TypeVar -from bson import DBRef +from bson import DBRef, ObjectId from mongoengine.common import _import_class from mongoengine.errors import DoesNotExist, MultipleObjectsReturned +if TYPE_CHECKING: + from mongoengine import Document + +_T = TypeVar("_T", bound="Document") + __all__ = ( "BaseDict", "StrictDict", @@ -356,7 +364,7 @@ def update(self, **update): class StrictDict: __slots__ = () _special_fields = {"get", "pop", "iteritems", "items", "keys", "create"} - _classes = {} + _classes: dict[str, Any] = {} def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -435,7 +443,7 @@ def __repr__(self): return cls._classes[allowed_keys] -class LazyReference(DBRef): +class LazyReference(Generic[_T], DBRef): __slots__ = ("_cached_doc", "passthrough", "document_type") def fetch(self, force=False): @@ -449,19 +457,21 @@ def fetch(self, force=False): def pk(self): return self.id - def __init__(self, document_type, pk, cached_doc=None, passthrough=False): + def __init__( + self, document_type: type[_T], pk: ObjectId, cached_doc=None, passthrough=False + ): self.document_type = document_type self._cached_doc = cached_doc self.passthrough = passthrough - super().__init__(self.document_type._get_collection_name(), pk) + super().__init__(self.document_type._get_collection_name(), pk) # type: ignore[arg-type] - def __getitem__(self, name): + def __getitem__(self, name: str) -> Any: if not self.passthrough: raise KeyError() document = self.fetch() return document[name] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if not object.__getattribute__(self, "passthrough"): raise AttributeError() document = self.fetch() diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index ea3962ad7..3afff8afe 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -1,10 +1,15 @@ +# mypy: disable-error-code="attr-defined,union-attr,assignment" +from __future__ import annotations + import copy import numbers import warnings from functools import partial +from typing import TYPE_CHECKING, Any import pymongo from bson import SON, DBRef, ObjectId, json_util +from typing_extensions import Self from mongoengine import signals from mongoengine.base.common import _DocumentRegistry @@ -15,7 +20,7 @@ LazyReference, StrictDict, ) -from mongoengine.base.fields import ComplexBaseField +from mongoengine.base.fields import BaseField, ComplexBaseField from mongoengine.common import _import_class from mongoengine.errors import ( FieldDoesNotExist, @@ -26,12 +31,15 @@ ) from mongoengine.pymongo_support import LEGACY_JSON_OPTIONS +if TYPE_CHECKING: + from mongoengine.fields import DynamicField + __all__ = ("BaseDocument", "NON_FIELD_ERRORS") NON_FIELD_ERRORS = "__all__" try: - GEOHAYSTACK = pymongo.GEOHAYSTACK + GEOHAYSTACK = pymongo.GEOHAYSTACK # type: ignore[attr-defined] except AttributeError: GEOHAYSTACK = None @@ -62,7 +70,12 @@ class BaseDocument: _dynamic_lock = True STRICT = False - def __init__(self, *args, **values): + # Fields, added by metaclass + _class_name: str + _fields: dict[str, BaseField] + _meta: dict[str, Any] + + def __init__(self, *args, **values) -> None: """ Initialise a document or an embedded document. @@ -103,7 +116,7 @@ def __init__(self, *args, **values): else: self._data = {} - self._dynamic_fields = SON() + self._dynamic_fields: SON[str, DynamicField] = SON() # Assign default values for fields # not set in the constructor @@ -329,13 +342,15 @@ def get_text_score(self): return self._data["_text_score"] - def to_mongo(self, use_db_field=True, fields=None): + def to_mongo( + self, use_db_field: bool = True, fields: list[str] | None = None + ) -> SON[Any, Any]: """ Return as SON data ready for use with MongoDB. """ fields = fields or [] - data = SON() + data: SON[str, Any] = SON() data["_id"] = None data["_cls"] = self._class_name @@ -354,7 +369,7 @@ def to_mongo(self, use_db_field=True, fields=None): if value is not None: f_inputs = field.to_mongo.__code__.co_varnames - ex_vars = {} + ex_vars: dict[str, Any] = {} if fields and "fields" in f_inputs: key = "%s." % field_name embedded_fields = [ @@ -370,7 +385,7 @@ def to_mongo(self, use_db_field=True, fields=None): # Handle self generating fields if value is None and field._auto_gen: - value = field.generate() + value = field.generate() # type: ignore[attr-defined] self._data[field_name] = value if value is not None or field.null: @@ -385,7 +400,7 @@ def to_mongo(self, use_db_field=True, fields=None): return data - def validate(self, clean=True): + def validate(self, clean: bool = True) -> None: """Ensure that all fields' values are valid and that required fields are present. @@ -439,7 +454,7 @@ def validate(self, clean=True): message = f"ValidationError ({self._class_name}:{pk}) " raise ValidationError(message, errors=errors) - def to_json(self, *args, **kwargs): + def to_json(self, *args: Any, **kwargs: Any) -> str: """Convert this document to JSON. :param use_db_field: Serialize field names as they appear in @@ -461,7 +476,7 @@ def to_json(self, *args, **kwargs): return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs) @classmethod - def from_json(cls, json_data, created=False, **kwargs): + def from_json(cls, json_data: str, created: bool = False, **kwargs: Any) -> Self: """Converts json data to a Document instance. :param str json_data: The json data to load into the Document. @@ -687,7 +702,7 @@ def _get_changed_fields(self): self._nestable_types_changed_fields(changed_fields, key, data) return changed_fields - def _delta(self): + def _delta(self) -> tuple[dict[str, Any], dict[str, Any]]: """Returns the delta (set, unset) of the changes for a document. Gets any values that have been explicitly changed. """ @@ -771,14 +786,16 @@ def _delta(self): return set_data, unset_data @classmethod - def _get_collection_name(cls): + def _get_collection_name(cls) -> str | None: """Return the collection name for this class. None for abstract class. """ return cls._meta.get("collection", None) @classmethod - def _from_son(cls, son, _auto_dereference=True, created=False): + def _from_son( + cls, son: dict[str, Any], _auto_dereference: bool = True, created: bool = False + ) -> Self: """Create an instance of a Document (subclass) from a PyMongo SON (dict)""" if son and not isinstance(son, dict): raise ValueError( diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index cead14449..5d49fe91d 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import contextlib import operator import threading import weakref +from typing import TYPE_CHECKING, Any, Callable, Iterable, NoReturn import pymongo from bson import SON, DBRef, ObjectId @@ -15,6 +18,10 @@ from mongoengine.common import _import_class from mongoengine.errors import DeprecatedError, ValidationError +if TYPE_CHECKING: + from mongoengine.document import Document + + __all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") @@ -36,8 +43,8 @@ class BaseField: may be added to subclasses of `Document` to define a document's schema. """ - name = None # set in TopLevelDocumentMetaclass - _geo_index = False + name: str = None # type: ignore[assignment] # set in TopLevelDocumentMetaclass + _geo_index: bool | str = False _auto_gen = False # Call `generate` to generate a value _thread_local_storage = threading.local() @@ -49,17 +56,17 @@ class BaseField: def __init__( self, - db_field=None, - required=False, - default=None, - unique=False, - unique_with=None, - primary_key=False, - validation=None, - choices=None, - null=False, - sparse=False, - **kwargs, + db_field: str | None = None, + required: bool = False, + default: Any | None | Callable[[], Any] = None, + unique: bool = False, + unique_with: str | Iterable[str] | None = None, + primary_key: bool = False, + validation: Callable[[Any], None] | None = None, + choices: Any = None, + null: bool = False, + sparse: bool = False, + **kwargs: Any, ): """ :param db_field: The database field to store this field in @@ -173,7 +180,7 @@ def __get__(self, instance, owner): # Get value from document instance if available return instance._data.get(self.name) - def __set__(self, instance, value): + def __set__(self, instance: Any, value: Any) -> None: """Descriptor for assigning a value to a field in a document.""" # If setting to None and there is a default value provided for this # field, then set the value to the default value. @@ -209,16 +216,21 @@ def __set__(self, instance, value): instance._data[self.name] = value - def error(self, message="", errors=None, field_name=None): + def error( + self, + message: str = "", + errors: dict[str, Any] | None = None, + field_name: str | None = None, + ) -> NoReturn: """Raise a ValidationError.""" field_name = field_name if field_name else self.name raise ValidationError(message, errors=errors, field_name=field_name) - def to_python(self, value): + def to_python(self, value: Any) -> Any: """Convert a MongoDB-compatible type to a Python type.""" return value - def to_mongo(self, value): + def to_mongo(self, value: Any) -> Any: """Convert a Python type to a MongoDB-compatible type.""" return self.to_python(value) @@ -234,13 +246,13 @@ def _to_mongo_safe_call(self, value, use_db_field=True, fields=None): return self.to_mongo(value, **ex_vars) - def prepare_query_value(self, op, value): + def prepare_query_value(self, op: str, value: Any) -> Any: """Prepare a value that is being used in a query for PyMongo.""" if op in UPDATE_OPERATORS: self.validate(value) return value - def validate(self, value, clean=True): + def validate(self, value: Any, clean: bool = True) -> None: """Perform validation on a value.""" pass @@ -292,13 +304,13 @@ def _validate(self, value, **kwargs): self.validate(value, **kwargs) @property - def owner_document(self): - return self._owner_document + def owner_document(self) -> type[Document]: + return self._owner_document # type: ignore[return-value] def _set_owner_document(self, owner_document): self._owner_document = owner_document - @owner_document.setter + @owner_document.setter # type: ignore[attr-defined,no-redef] def owner_document(self, owner_document): self._set_owner_document(owner_document) diff --git a/mongoengine/base/fields.pyi b/mongoengine/base/fields.pyi new file mode 100644 index 000000000..f6dc36cee --- /dev/null +++ b/mongoengine/base/fields.pyi @@ -0,0 +1,149 @@ +# pyright: reportIncompatibleMethodOverride=warning +from typing import ( + Any, + Callable, + Generic, + Iterable, + Literal, + NoReturn, + Optional, + Sequence, + TypedDict, + TypeVar, + Union, + overload, +) + +from bson import ObjectId +from typing_extensions import Self, TypeAlias, Unpack + +from mongoengine.document import Document + +__all__ = [ + "BaseField", + "ComplexBaseField", + "ObjectIdField", + "GeoJsonBaseField", + "_no_dereference_for_fields", +] +_ST = TypeVar("_ST") +_GT = TypeVar("_GT") +_F = TypeVar("_F", bound=BaseField) +_Choice: TypeAlias = str | tuple[str, str] +_no_dereference_for_fields: Any + +class _BaseFieldOptions(TypedDict, total=False): + db_field: str + name: str + unique: bool + unique_with: str | Iterable[str] + primary_key: bool + choices: Iterable[_Choice] + null: bool + verbose_name: str + help_text: str + +class BaseField(Generic[_ST, _GT]): + name: str + creation_counter: int + auto_creation_counter: int + db_field: str + required: bool + default: bool + unique: bool + unique_with: str | Iterable[str] | None + primary_key: bool + validation: Callable[[Any], None] | None + choices: Any + null: bool + sparse: bool + + _auto_gen: bool + + def __set__(self, instance: Any, value: _ST) -> None: ... + @overload + def __get__(self, instance: None, owner: Any) -> Self: ... + @overload + def __get__(self, instance: Any, owner: Any) -> _GT: ... + def __init___( + self, + db_field: str | None = None, + required: bool = False, + default: Any | None | Callable[[], Any] = None, + unique: bool = False, + unique_with: str | Iterable[str] | None = None, + primary_key: bool = False, + validation: Callable[[Any], None] | None = None, + choices: Any = None, + null: bool = False, + sparse: bool = False, + **kwargs: Any, + ) -> None: ... + def error( + self, + message: str = "", + errors: dict[str, Any] | None = None, + field_name: str | None = None, + ) -> NoReturn: ... + def to_python(self, value: Any) -> Any: ... + def to_mongo(self, value: Any) -> Any: ... + def prepare_query_value(self, op: str, value: Any) -> Any: ... + def validate(self, value: Any, clean: bool = True) -> None: ... + @property + def owner_document(self) -> type[Document]: ... + @owner_document.setter + def owner_document(self, owner_document: type[Document]) -> None: ... + +class ComplexBaseField(Generic[_F, _ST, _GT], BaseField[_ST, _GT]): + field: _F + def to_python(self, value): ... + def to_mongo( + self, value, use_db_field: bool = True, fields: Sequence[str] | None = None + ): ... + def validate(self, value: Any) -> None: ... # type: ignore[override] + def prepare_query_value(self, op, value): ... + def lookup_member(self, member_name): ... + +class ObjectIdField(BaseField[_ST, _GT]): + # ObjectIdField() + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> ObjectIdField[Optional[ObjectId], Optional[ObjectId]]: ... + # ObjectIdField(default=ObjectId) + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: Union[ObjectId, Callable[[], ObjectId]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> ObjectIdField[Optional[ObjectId], ObjectId]: ... + # ObjectIdField(required=True) + @overload + def __new__( + cls, + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> ObjectIdField[ObjectId, ObjectId]: ... + # ObjectIdField(required=True, default=ObjectId) + @overload + def __new__( + cls, + *, + required: Literal[True], + default: Union[ObjectId, Callable[[], ObjectId]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> ObjectIdField[Optional[ObjectId], ObjectId]: ... + def __set__(self, instance: Any, value: _ST) -> None: ... + +class GeoJsonBaseField(BaseField[dict[str, Any], dict[str, Any]]): + def __init__( + self, auto_index: bool = True, *args: Any, **kwargs: Unpack[_BaseFieldOptions] + ) -> None: ... diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index a311aa167..e9def0877 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -249,6 +249,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): collection in the database. """ + DoesNotExist: type[DoesNotExist] + MultipleObjectsReturned: type[MultipleObjectsReturned] + def __new__(mcs, name, bases, attrs): flattened_bases = mcs._get_bases(bases) super_new = super().__new__ diff --git a/mongoengine/common.py b/mongoengine/common.py index 640384ec0..290df8403 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -1,5 +1,9 @@ -_class_registry_cache = {} -_field_list_cache = [] +from __future__ import annotations + +from typing import Any + +_class_registry_cache: dict[str, Any] = {} +_field_list_cache: list[Any] = [] def _import_class(cls_name): diff --git a/mongoengine/connection.py b/mongoengine/connection.py index a24f0cc36..50747eb36 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,25 +1,29 @@ +from __future__ import annotations + import collections import threading import warnings +from typing import Any from pymongo import MongoClient, ReadPreference, uri_parser from pymongo.common import _UUID_REPRESENTATIONS +from pymongo.database import Database try: from pymongo.database_shared import _check_name except ImportError: - from pymongo.database import _check_name + from pymongo.database import _check_name # type: ignore # DriverInfo was added in PyMongo 3.7. try: from pymongo.driver_info import DriverInfo except ImportError: - DriverInfo = None + DriverInfo = None # type: ignore import mongoengine from mongoengine.pymongo_support import PYMONGO_VERSION -__all__ = [ +__all__ = ( "DEFAULT_CONNECTION_NAME", "DEFAULT_DATABASE_NAME", "ConnectionFailure", @@ -29,17 +33,17 @@ "get_connection", "get_db", "register_connection", -] +) -DEFAULT_CONNECTION_NAME = "default" -DEFAULT_DATABASE_NAME = "test" -DEFAULT_HOST = "localhost" -DEFAULT_PORT = 27017 +DEFAULT_CONNECTION_NAME: str = "default" +DEFAULT_DATABASE_NAME: str = "test" +DEFAULT_HOST: str = "localhost" +DEFAULT_PORT: int = 27017 _connection_settings = {} -_connections = {} -_dbs = {} +_connections: dict[str, MongoClient] = {} +_dbs: dict[str, Any] = {} READ_PREFERENCE = ReadPreference.PRIMARY @@ -220,18 +224,18 @@ def _get_connection_settings( def register_connection( - alias, - db=None, - name=None, - host=None, - port=None, - read_preference=READ_PREFERENCE, - username=None, - password=None, - authentication_source=None, - authentication_mechanism=None, - authmechanismproperties=None, - **kwargs, + alias: str, + db: str | None = None, + name: str | None = None, + host: str | None = None, + port: int | None = None, + read_preference: Any = READ_PREFERENCE, + username: str | None = None, + password: str | None = None, + authentication_source: str | None = None, + authentication_mechanism: str | None = None, + authmechanismproperties: Any = None, + **kwargs: Any, ): """Register the connection settings. @@ -270,7 +274,7 @@ def register_connection( _connection_settings[alias] = conn_settings -def disconnect(alias=DEFAULT_CONNECTION_NAME): +def disconnect(alias: str = DEFAULT_CONNECTION_NAME) -> None: """Close the connection with a given alias.""" from mongoengine import Document from mongoengine.base.common import _get_documents_by_db @@ -297,13 +301,15 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME): del _connection_settings[alias] -def disconnect_all(): +def disconnect_all() -> None: """Close all registered database.""" for alias in list(_connections.keys()): disconnect(alias) -def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): +def get_connection( + alias: str = DEFAULT_CONNECTION_NAME, reconnect: bool = False +) -> MongoClient[Any]: """Return a connection with a given alias.""" # Connect to the database if not already connected @@ -415,7 +421,9 @@ def _clean_settings(settings_dict): return _connections[db_alias] -def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): +def get_db( + alias: str = DEFAULT_CONNECTION_NAME, reconnect: bool = False +) -> Database[Any]: if reconnect: disconnect(alias) @@ -443,7 +451,9 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): return _dbs[alias] -def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): +def connect( + db: str | None = None, alias: str = DEFAULT_CONNECTION_NAME, **kwargs +) -> MongoClient[Any]: """Connect to the database specified by the 'db' argument. Connection settings may be provided here as well if the database is not diff --git a/mongoengine/document.py b/mongoengine/document.py index 829c07135..40698532c 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,8 +1,13 @@ +# mypy: disable-error-code="attr-defined,union-attr,assignment,return-value,arg-type" +from __future__ import annotations + import re +from typing import TYPE_CHECKING, Any, Mapping import pymongo from bson.dbref import DBRef from pymongo.read_preferences import ReadPreference +from typing_extensions import NotRequired, Self, TypedDict from mongoengine import signals from mongoengine.base import ( @@ -39,6 +44,13 @@ transform, ) +if TYPE_CHECKING: + from bson import ObjectId + from pymongo.collection import Collection + + from mongoengine.fields import ObjectIdField + from mongoengine.queryset.manager import QuerySetManager + __all__ = ( "Document", "EmbeddedDocument", @@ -180,6 +192,13 @@ class Document(BaseDocument, metaclass=TopLevelDocumentMetaclass): __slots__ = ("__objects",) + id: ObjectIdField[ObjectId, ObjectId] + objects: QuerySetManager[QuerySet[Self]] + meta: _MetaDict + _meta: _UnderMetaDict + _fields: dict[str, Any] + _collection: Collection[Any] | None + @property def pk(self): """Get the primary key.""" @@ -188,7 +207,7 @@ def pk(self): return getattr(self, self._meta["id_field"]) @pk.setter - def pk(self, value): + def pk(self, value: Any): """Set the primary key.""" return setattr(self, self._meta["id_field"], value) @@ -212,7 +231,7 @@ def _disconnect(cls): cls._collection = None @classmethod - def _get_collection(cls): + def _get_collection(cls) -> Collection[Any]: """Return the PyMongo collection corresponding to this document. Upon first call, this method: @@ -312,7 +331,7 @@ def to_mongo(self, *args, **kwargs): return data - def modify(self, query=None, **update): + def modify(self, query: object | None = None, **update) -> bool: """Perform an atomic update of the document in the database and reload the document object using updated version. @@ -652,7 +671,7 @@ def _object_key(self): select_dict["__".join(field_parts)] = val return select_dict - def update(self, **kwargs): + def update(self, **kwargs: Any) -> int: """Performs an update on the :class:`~mongoengine.Document` A convenience wrapper to :meth:`~mongoengine.QuerySet.update`. @@ -671,7 +690,7 @@ def update(self, **kwargs): # Need to add shard key to query, or you get an error return self._qs.filter(**self._object_key).update_one(**kwargs) - def delete(self, signal_kwargs=None, **write_concern): + def delete(self, signal_kwargs: object = None, **write_concern) -> None: """Delete the :class:`~mongoengine.Document` from the database. This will only take effect if the document has been previously saved. @@ -1158,3 +1177,12 @@ def object(self): self._key_object = self._document.objects.with_id(self.key) return self._key_object return self._key_object + + +_MetaDict = Mapping[str, Any] + + +class _UnderMetaDict(TypedDict): + id_field: NotRequired[str] + strict: bool + collection: str diff --git a/mongoengine/errors.py b/mongoengine/errors.py index d789b2a10..57ef0c5a5 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -83,7 +83,7 @@ class ValidationError(AssertionError): individual field. """ - errors = {} + errors: dict[str, str] = {} field_name = None _message = None @@ -114,7 +114,7 @@ def _get_message(self): def _set_message(self, message): self._message = message - message = property(_get_message, _set_message) + message: str = property(_get_message, _set_message) # type: ignore[assignment] def to_dict(self): """Returns a dictionary of all errors within a document diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 980098dfb..294cb2552 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,3 +1,7 @@ +# pyright: reportIncompatibleMethodOverride=warning,reportNoOverloadImplementation=false +# mypy: disable-error-code="override,misc" +from __future__ import annotations + import datetime import decimal import inspect @@ -9,19 +13,14 @@ from inspect import isclass from io import BytesIO from operator import itemgetter +from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar import gridfs import pymongo from bson import SON, Binary, DBRef, ObjectId from bson.decimal128 import Decimal128, create_decimal128_context from pymongo import ReturnDocument - -try: - import dateutil -except ImportError: - dateutil = None -else: - import dateutil.parser +from typing_extensions import Self from mongoengine.base import ( BaseDocument, @@ -49,6 +48,16 @@ from mongoengine.queryset.base import BaseQuerySet from mongoengine.queryset.transform import STRING_OPERATORS +if TYPE_CHECKING: + from enum import Enum + +try: + import dateutil # type: ignore[import-untyped] +except ImportError: + dateutil = None # type: ignore[assignment] +else: + import dateutil.parser # type: ignore[import-untyped] + try: from PIL import Image, ImageOps @@ -58,8 +67,8 @@ LANCZOS = Image.LANCZOS except ImportError: # pillow is optional so may not be installed - Image = None - ImageOps = None + Image = None # type: ignore[assignment] + ImageOps = None # type: ignore[assignment] __all__ = ( @@ -109,6 +118,7 @@ ) RECURSIVE_REFERENCE_CONSTANT = "self" +_T = TypeVar("_T") def _unsaved_object_error(document): @@ -122,7 +132,13 @@ def _unsaved_object_error(document): class StringField(BaseField): """A unicode string field.""" - def __init__(self, regex=None, max_length=None, min_length=None, **kwargs): + def __init__( + self, + regex: str | None = None, + max_length: int | None = None, + min_length: int | None = None, + **kwargs, + ) -> None: """ :param regex: (optional) A string pattern that will be applied during validation :param max_length: (optional) A max length that will be applied during validation @@ -205,7 +221,12 @@ class URLField(StringField): ) _URL_SCHEMES = ["http", "https", "ftp", "ftps"] - def __init__(self, url_regex=None, schemes=None, **kwargs): + def __init__( + self, + url_regex: str | None = None, + schemes: Iterable[str] | None = None, + **kwargs, + ) -> None: """ :param url_regex: (optional) Overwrite the default regex used for validation :param schemes: (optional) Overwrite the default URL schemes that are allowed @@ -257,11 +278,11 @@ class EmailField(StringField): def __init__( self, - domain_whitelist=None, - allow_utf8_user=False, - allow_ip_domain=False, - *args, - **kwargs, + domain_whitelist: list[str] | None = None, + allow_utf8_user: bool = False, + allow_ip_domain: bool = False, + *args: Any, + **kwargs: Any, ): """ :param domain_whitelist: (optional) list of valid domain names applied during validation @@ -338,7 +359,12 @@ def validate(self, value): class IntField(BaseField): """32-bit integer field.""" - def __init__(self, min_value=None, max_value=None, **kwargs): + def __init__( + self, + min_value: int | None = None, + max_value: int | None = None, + **kwargs: Any, + ): """ :param min_value: (optional) A min value that will be applied during validation :param max_value: (optional) A max value that will be applied during validation @@ -376,7 +402,12 @@ def prepare_query_value(self, op, value): class FloatField(BaseField): """Floating point number field.""" - def __init__(self, min_value=None, max_value=None, **kwargs): + def __init__( + self, + min_value: float | int | None = None, + max_value: float | int | None = None, + **kwargs, + ): """ :param min_value: (optional) A min value that will be applied during validation :param max_value: (optional) A max value that will be applied during validation @@ -425,11 +456,11 @@ class DecimalField(BaseField): def __init__( self, - min_value=None, - max_value=None, - force_string=False, - precision=2, - rounding=decimal.ROUND_HALF_UP, + min_value: decimal.Decimal | int | None = None, + max_value: decimal.Decimal | int | None = None, + force_string: bool = False, + precision: int = 2, + rounding: str = decimal.ROUND_HALF_UP, **kwargs, ): """ @@ -634,7 +665,7 @@ class ComplexDateTimeField(StringField): Note: To default the field to the current datetime, use: DateTimeField(default=datetime.utcnow) """ - def __init__(self, separator=",", **kwargs): + def __init__(self, separator: str = ",", **kwargs): """ :param separator: Allows to customize the separator used for storage (default ``,``) :param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.StringField` @@ -714,7 +745,7 @@ class EmbeddedDocumentField(BaseField): Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. """ - def __init__(self, document_type, **kwargs): + def __init__(self, document_type: type[EmbeddedDocument] | str, **kwargs: Any): if not ( isinstance(document_type, str) or issubclass(document_type, EmbeddedDocument) @@ -728,7 +759,7 @@ def __init__(self, document_type, **kwargs): super().__init__(**kwargs) @property - def document_type(self): + def document_type(self) -> type[Any]: if isinstance(self.document_type_obj, str): if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: resolved_document_type = self.owner_document @@ -924,7 +955,7 @@ def __init__(self, field=None, *, max_length=None, **kwargs): kwargs.setdefault("default", list) super().__init__(field=field, **kwargs) - def __get__(self, instance, owner): + def __get__(self, instance: Any, owner: Any) -> list[dict[str, Any]] | Self: if instance is None: # Document class being used rather than a document object return self @@ -973,7 +1004,10 @@ def prepare_query_value(self, op, value): return super().prepare_query_value(op, value) -class EmbeddedDocumentListField(ListField): +class EmbeddedDocumentListField( + ListField, + Generic[_T], +): """A :class:`~mongoengine.ListField` designed specially to hold a list of embedded documents to provide additional query helpers. @@ -1043,12 +1077,12 @@ class DictField(ComplexBaseField): Required means it cannot be empty - as the default for DictFields is {} """ - def __init__(self, field=None, *args, **kwargs): + def __init__(self, field: Any | None = None, *args, **kwargs): kwargs.setdefault("default", dict) super().__init__(*args, field=field, **kwargs) self.set_auto_dereferencing(False) - def validate(self, value): + def validate(self, value: Any): """Make sure that a list of valid fields is being used.""" if not isinstance(value, dict): self.error("Only dictionaries may be used in a DictField") @@ -1143,8 +1177,12 @@ class User(Document): """ def __init__( - self, document_type, dbref=False, reverse_delete_rule=DO_NOTHING, **kwargs - ): + self, + document_type: type[_T] | str, + dbref: bool = False, + reverse_delete_rule=DO_NOTHING, + **kwargs, + ) -> None: """Initialises the Reference Field. :param document_type: The type of Document that will be referenced @@ -1190,7 +1228,7 @@ def _lazy_load_ref(ref_cls, dbref): return ref_cls._from_son(dereferenced_son) - def __get__(self, instance, owner): + def __get__(self, instance: Any, owner: Any) -> Any: """Descriptor to allow lazy dereferencing.""" if instance is None: # Document class being used rather than a document object @@ -1276,7 +1314,13 @@ def lookup_member(self, member_name): class CachedReferenceField(BaseField): """A referencefield with cache fields to purpose pseudo-joins""" - def __init__(self, document_type, fields=None, auto_sync=True, **kwargs): + def __init__( + self, + document_type: str | type[Document], + fields: Iterable[str] | None = None, + auto_sync: bool = True, + **kwargs, + ): """Initialises the Cached Reference Field. :param document_type: The type of Document that will be referenced @@ -1542,7 +1586,7 @@ def prepare_query_value(self, op, value): class BinaryField(BaseField): """A binary data field.""" - def __init__(self, max_bytes=None, **kwargs): + def __init__(self, max_bytes: int | None = None, **kwargs): self.max_bytes = max_bytes super().__init__(**kwargs) @@ -1606,7 +1650,7 @@ class ModelWithEnum(Document): status = EnumField(Status, choices=[Status.NEW, Status.DONE]) """ - def __init__(self, enum, **kwargs): + def __init__(self, enum: type[Enum], **kwargs): self._enum_cls = enum if kwargs.get("choices"): invalid_choices = [] @@ -1636,7 +1680,7 @@ def to_python(self, value): return value return value - def __set__(self, instance, value): + def __set__(self, instance: Any, value: Any) -> None: return super().__set__(instance, self.to_python(value)) def to_mongo(self, value): @@ -1661,11 +1705,11 @@ class GridFSProxy: def __init__( self, - grid_id=None, - key=None, - instance=None, - db_alias=DEFAULT_CONNECTION_NAME, - collection_name="fs", + grid_id: ObjectId | None = None, + key: str | None = None, + instance: Any | None = None, + db_alias: str = DEFAULT_CONNECTION_NAME, + collection_name: str = "fs", ): self.grid_id = grid_id # Store GridFS id for file self.key = key @@ -1822,13 +1866,20 @@ class FileField(BaseField): proxy_class = GridFSProxy def __init__( - self, db_alias=DEFAULT_CONNECTION_NAME, collection_name="fs", **kwargs + self, + db_alias: str = DEFAULT_CONNECTION_NAME, + collection_name: str = "fs", + **kwargs, ): super().__init__(**kwargs) self.collection_name = collection_name self.db_alias = db_alias - def __get__(self, instance, owner): + def __get__( + self, + instance: Any, + owner: Any, + ) -> GridFSProxy | Self: if instance is None: return self @@ -1843,7 +1894,7 @@ def __get__(self, instance, owner): grid_file.instance = instance return grid_file - def __set__(self, instance, value): + def __set__(self, instance: Any, value: GridFSProxy): key = self.name if ( hasattr(value, "read") and not isinstance(value, GridFSProxy) @@ -1865,7 +1916,13 @@ def __set__(self, instance, value): instance._mark_as_changed(key) - def get_proxy_obj(self, key, instance, db_alias=None, collection_name=None): + def get_proxy_obj( + self, + key: str, + instance: Any, + db_alias: str | None = None, + collection_name: str | None = None, + ) -> GridFSProxy: if db_alias is None: db_alias = self.db_alias if collection_name is None: @@ -2034,7 +2091,11 @@ class ImageField(FileField): proxy_class = ImageGridFsProxy def __init__( - self, size=None, thumbnail_size=None, collection_name="images", **kwargs + self, + size: tuple[int, int, bool] | None = None, + thumbnail_size: tuple[int, int, bool] | None = None, + collection_name: str = "images", + **kwargs, ): if not Image: raise ImproperlyConfigured("PIL library was not found") @@ -2246,7 +2307,7 @@ class GeoPointField(BaseField): _geo_index = pymongo.GEO2D - def validate(self, value): + def validate(self, value: Any): """Make sure that a geo-value is of type (x, y)""" if not isinstance(value, (list, tuple)): self.error("GeoPointField can only accept tuples or lists of (x, y)") @@ -2392,10 +2453,10 @@ class LazyReferenceField(BaseField): def __init__( self, - document_type, - passthrough=False, - dbref=False, - reverse_delete_rule=DO_NOTHING, + document_type: type[EmbeddedDocument] | str, + passthrough: bool = False, + dbref: bool = False, + reverse_delete_rule: int = DO_NOTHING, **kwargs, ): """Initialises the Reference Field. @@ -2626,7 +2687,9 @@ class Decimal128Field(BaseField): DECIMAL_CONTEXT = create_decimal128_context() - def __init__(self, min_value=None, max_value=None, **kwargs): + def __init__( + self, min_value: int | None = None, max_value: int | None = None, **kwargs + ): self.min_value = min_value self.max_value = max_value super().__init__(**kwargs) diff --git a/mongoengine/fields.pyi b/mongoengine/fields.pyi new file mode 100644 index 000000000..1e18a9fa9 --- /dev/null +++ b/mongoengine/fields.pyi @@ -0,0 +1,992 @@ +# pyright: reportIncompatibleMethodOverride=false +from __future__ import annotations + +from datetime import date, datetime +from decimal import Decimal +from enum import Enum +from typing import ( + Any, + Callable, + Container, + Dict, + Iterator, + List, + Optional, + Pattern, + Tuple, + Type, + TypeVar, + Union, + overload, +) +from uuid import UUID + +from typing_extensions import Literal, TypeAlias, Unpack + +from mongoengine.base import BaseField, ComplexBaseField +from mongoengine.base.datastructures import LazyReference +from mongoengine.base.fields import ( + _F, + _GT, + _ST, + GeoJsonBaseField, + ObjectIdField, + _BaseFieldOptions, +) +from mongoengine.document import Document + +_T = TypeVar("_T") +_DT = TypeVar("_DT", bound=Document) +_Choice: TypeAlias = str | tuple[str, str] +__all__ = ( + "StringField", + "URLField", + "EmailField", + "IntField", + "FloatField", + "DecimalField", + "BooleanField", + "DateTimeField", + "DateField", + "ComplexDateTimeField", + "EmbeddedDocumentField", + "ObjectIdField", + "GenericEmbeddedDocumentField", + "DynamicField", + "ListField", + "SortedListField", + "EmbeddedDocumentListField", + "DictField", + "MapField", + "ReferenceField", + "CachedReferenceField", + "LazyReferenceField", + "GenericLazyReferenceField", + "GenericReferenceField", + "BinaryField", + "GridFSError", + "GridFSProxy", + "FileField", + "ImageGridFsProxy", + "ImproperlyConfigured", + "ImageField", + "GeoPointField", + "PointField", + "LineStringField", + "PolygonField", + "SequenceField", + "UUIDField", + "EnumField", + "MultiPointField", + "MultiLineStringField", + "MultiPolygonField", + "GeoJsonBaseField", + "Decimal128Field", +) + +class StringField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> StringField[Optional[str], Optional[str]]: ... + # StringField(default="foo") + @overload + def __new__( + cls, + *, + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[False] = ..., + default: Union[str, Callable[[], str]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> StringField[Optional[str], str]: ... + # StringField(required=True) + @overload + def __new__( + cls, + *, + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> StringField[str, str]: ... + # StringField(required=True, default="foo") + @overload + def __new__( + cls, + *, + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[True], + default: Union[str, Callable[[], str]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> StringField[Optional[str], str]: ... + +class URLField(StringField[_ST, _GT]): + @overload + def __new__( + cls, + *, + url_regex: Optional[Pattern[str]] = ..., + schemas: Optional[Container[str]] = ..., + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> URLField[Optional[str], Optional[str]]: ... + # URLField(default="foo") + @overload + def __new__( + cls, + *, + url_regex: Optional[Pattern[str]] = ..., + schemas: Optional[Container[str]] = ..., + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[False] = ..., + default: Union[str, Callable[[], str]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> URLField[Optional[str], str]: ... + # URLField(required=True) + @overload + def __new__( + cls, + *, + url_regex: Optional[Pattern[str]] = ..., + schemas: Optional[Container[str]] = ..., + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> URLField[str, str]: ... + # URLField(required=True, default="foo") + @overload + def __new__( + cls, + *, + url_regex: Optional[Pattern[str]] = ..., + schemas: Optional[Container[str]] = ..., + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[True], + default: Union[str, Callable[[], str]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> URLField[Optional[str], str]: ... + def __set__(self, instance: Any, value: _ST) -> None: ... + +class EmailField(StringField[_ST, _GT]): + @overload + def __new__( + cls, + *, + domain_whitelist: Optional[List[str]] = ..., + allow_utf8_user: bool = ..., + allow_ip_domain: bool = ..., + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> EmailField[Optional[str], Optional[str]]: ... + @overload + def __new__( + cls, + *, + domain_whitelist: Optional[List[str]] = ..., + allow_utf8_user: bool = ..., + allow_ip_domain: bool = ..., + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[False] = ..., + default: Union[str, Callable[[], str]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> EmailField[Optional[str], str]: ... + @overload + def __new__( + cls, + *, + domain_whitelist: Optional[List[str]] = ..., + allow_utf8_user: bool = ..., + allow_ip_domain: bool = ..., + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> EmailField[str, str]: ... + @overload + def __new__( + cls, + *, + domain_whitelist: Optional[List[str]] = ..., + allow_utf8_user: bool = ..., + allow_ip_domain: bool = ..., + regex: Optional[str] = ..., + max_length: Optional[int] = ..., + min_length: Optional[int] = ..., + required: Literal[True], + default: Union[str, Callable[[], str]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> EmailField[Optional[str], str]: ... + +class IntField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + min_value: int = ..., + max_value: int = ..., + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> IntField[Optional[int], Optional[int]]: ... + @overload + def __new__( + cls, + *, + min_value: int = ..., + max_value: int = ..., + required: Literal[False] = ..., + default: Union[int, Callable[[], int]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> IntField[Optional[int], int]: ... + @overload + def __new__( + cls, + *, + min_value: int = ..., + max_value: int = ..., + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> IntField[int, int]: ... + @overload + def __new__( + cls, + *, + min_value: int = ..., + max_value: int = ..., + required: Literal[True], + default: Union[int, Callable[[], int]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> IntField[Optional[int], int]: ... + +class FloatField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + min_value: float = ..., + max_value: float = ..., + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> FloatField[Optional[float], Optional[float]]: ... + @overload + def __new__( + cls, + *, + min_value: float = ..., + max_value: float = ..., + required: Literal[False] = ..., + default: Union[float, Callable[[], float]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> FloatField[Optional[float], float]: ... + @overload + def __new__( + cls, + *, + min_value: float = ..., + max_value: float = ..., + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> FloatField[float, float]: ... + @overload + def __new__( + cls, + *, + min_value: float = ..., + max_value: float = ..., + required: Literal[True], + default: Union[float, Callable[[], float]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> FloatField[Optional[float], float]: ... + +class DecimalField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + min_value: Decimal = ..., + max_value: Decimal = ..., + force_string: bool = ..., + precision: int = ..., + rounding: str = ..., + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> DecimalField[Optional[Decimal], Optional[Decimal]]: ... + @overload + def __new__( + cls, + *, + min_value: Decimal = ..., + max_value: Decimal = ..., + force_string: bool = ..., + precision: int = ..., + rounding: str = ..., + required: Literal[False] = ..., + default: Union[Decimal, Callable[[], Decimal]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> DecimalField[Optional[Decimal], Decimal]: ... + @overload + def __new__( + cls, + *, + min_value: Decimal = ..., + max_value: Decimal = ..., + force_string: bool = ..., + precision: int = ..., + rounding: str = ..., + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> DecimalField[Decimal, Decimal]: ... + @overload + def __new__( + cls, + *, + min_value: Decimal = ..., + max_value: Decimal = ..., + force_string: bool = ..., + precision: int = ..., + rounding: str = ..., + required: Literal[True], + default: Union[Decimal, Callable[[], Decimal]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> DecimalField[Optional[Decimal], Decimal]: ... + +class BooleanField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> BooleanField[Optional[bool], Optional[bool]]: ... + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: Union[bool, Callable[[], bool]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> BooleanField[Optional[bool], bool]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> BooleanField[bool, bool]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: Union[bool, Callable[[], bool]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> BooleanField[Optional[bool], bool]: ... + +class DateTimeField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> DateTimeField[Optional[datetime], Optional[datetime]]: ... + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: Union[datetime, Callable[[], datetime]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> DateTimeField[Optional[datetime], datetime]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> DateTimeField[datetime, datetime]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: Union[datetime, Callable[[], datetime]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> DateTimeField[Optional[datetime], datetime]: ... + +class EmbeddedDocumentField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + document_type: Type[_T], + required: Literal[False] = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> EmbeddedDocumentField[Optional[_T], Optional[_T]]: ... + @overload + def __new__( + cls, + document_type: Type[_T], + *, + required: Literal[True], + **kwargs: Unpack[_BaseFieldOptions], + ) -> EmbeddedDocumentField[_T, _T]: ... + +class DynamicField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> DynamicField[Optional[Any], Optional[Any]]: ... + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: Union[Any, Callable[[], Any]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> DynamicField[Optional[Any], Any]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> DynamicField[Any, Any]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: Union[Any, Callable[[], Any]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> DynamicField[Optional[Any], Any]: ... + +class ListField(ComplexBaseField[_F, _ST, _GT]): + # see: https://github.com/python/mypy/issues/4236#issuecomment-521628880 + # and probably this: + # * https://github.com/python/typing/issues/548 + # With Higher-Kinded TypeVars this could be simplfied, but it's not there yet. + @overload + def __new__( + cls, + field: BaseField[_ST, _GT], + required: bool = ..., + default: Optional[Union[List[Any], Callable[[], List[Any]]]] = ..., + ) -> ListField[BaseField[_ST, _GT], list[_ST], list[_GT]]: ... + @overload + def __new__( + cls, + field: Any | None, + required: bool = ..., + default: Optional[Union[List[Any], Callable[[], List[Any]]]] = ..., + ) -> ListField[Any, Any, Any]: ... + def __getitem__(self, arg: Any) -> _GT: ... + def __iter__(self) -> Iterator[_GT]: ... + +class DictField(ComplexBaseField[_F, _ST, _GT]): + def __new__( + cls, + field: BaseField[_ST, _GT] = ..., + required: bool = ..., + default: Union[Dict[str, Any], None, Callable[[], Dict[str, Any]]] = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> DictField[BaseField[_ST, _GT], dict[str, _ST], dict[str, _GT]]: ... + def __getitem__(self, arg: Any) -> _GT: ... + +class EmbeddedDocumentListField(ListField[_F, _ST, _GT]): + def __new__( + cls, + document_type: Type[_T] | str, + required: bool = ..., + default: Optional[Any] = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> EmbeddedDocumentListField[ + EmbeddedDocumentField[_T, _T], list[_T], list[_T] + ]: ... + +class LazyReferenceField(BaseField[_ST, _GT]): + def __new__( + cls, + document_type: type[_T] | str, + passthrough: bool = ..., + dbref: bool = ..., + reverse_delete_rule: int = ..., + required: bool = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> LazyReferenceField[_T, _T]: ... + +class UUIDField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + binary: bool = ..., + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> UUIDField[Optional[UUID], Optional[UUID]]: ... + @overload + def __new__( + cls, + *, + binary: bool = ..., + required: Literal[False] = ..., + default: Union[UUID, Callable[[], UUID]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> UUIDField[Optional[UUID], UUID]: ... + @overload + def __new__( + cls, + *, + binary: bool = ..., + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> UUIDField[UUID, UUID]: ... + @overload + def __new__( + cls, + *, + binary: bool = ..., + required: Literal[True], + default: Union[UUID, Callable[[], UUID]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> UUIDField[Optional[UUID], UUID]: ... + +_Tuple2Like = Union[Tuple[Union[float, int], Union[float, int]], List[float], List[int]] + +class GeoPointField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> GeoPointField[_Tuple2Like | None, list[float] | None]: ... + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: Union[_Tuple2Like, Callable[[], _Tuple2Like]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> GeoPointField[_Tuple2Like | None, list[float]]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> GeoPointField[_Tuple2Like, list[float]]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: Union[_Tuple2Like, Callable[[], _Tuple2Like]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> GeoPointField[_Tuple2Like | None, list[float]]: ... + def __set__(self, instance: Any, value: _ST) -> None: ... + +class MapField(DictField[_F, _ST, _GT]): + def __new__( + cls, + field: BaseField[_ST, _GT] = ..., + required: bool = ..., + default: Union[Dict[str, Any], None, Callable[[], Dict[str, Any]]] = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> MapField[BaseField[_ST, _GT], dict[str, _ST], dict[str, _GT]]: ... + def __getitem__(self, arg: Any) -> _GT: ... + +class ReferenceField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + document_type: Type[_T] | str, + dbref: bool = ..., + reverse_delete_rule: int = ..., + required: Literal[False] = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> ReferenceField[_T | None, _T | None]: ... + @overload + def __new__( + cls, + document_type: Type[_T] | str, + dbref: bool = ..., + reverse_delete_rule: int = ..., + required: Literal[True] = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> ReferenceField[_T, _T]: ... + def __getitem__(self, arg: Any) -> Any: ... + def __set__(self, instance: Any, value: _ST) -> None: ... + +_T_ENUM = TypeVar("_T_ENUM", bound=Enum) + +class EnumField(BaseField[_ST, _GT]): + # EnumField(Foo) + @overload + def __new__( + cls, + enum: Type[_T_ENUM], + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> EnumField[Optional[_T_ENUM], Optional[_T_ENUM]]: ... + # EnumField(Foo, default=Foo.Bar) + @overload + def __new__( + cls, + enum: Type[_T_ENUM], + *, + required: Literal[False] = ..., + default: Union[_T_ENUM, Callable[[], _T_ENUM]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> EnumField[Optional[_T_ENUM], _T_ENUM]: ... + # EnumField(Foo, required=True) + @overload + def __new__( + cls, + enum: Type[_T_ENUM], + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> EnumField[_T_ENUM, _T_ENUM]: ... + # EnumField(Foo, required=True, default=Foo.Bar) + @overload + def __new__( + cls, + enum: Type[_T_ENUM], + *, + required: Literal[True], + default: Union[_T_ENUM, Callable[[], _T_ENUM]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> EnumField[Optional[_T_ENUM], _T_ENUM]: ... + def __set__(self, instance: Any, value: _ST) -> None: ... + +class DateField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> DateField[Optional[date], Optional[date]]: ... + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: Union[date, Callable[[], date]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> DateField[Optional[date], date]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> DateField[date, date]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: Union[date, Callable[[], date]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> DateField[Optional[date], date]: ... + +class ComplexDateTimeField(StringField[_ST, _GT]): + @overload + def __new__( + cls, + separator: str = ..., + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> ComplexDateTimeField[Optional[datetime], Optional[datetime]]: ... + @overload + def __new__( + cls, + separator: str = ..., + *, + required: Literal[False] = ..., + default: Union[datetime, Callable[[], datetime]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> ComplexDateTimeField[Optional[datetime], datetime]: ... + @overload + def __new__( + cls, + separator: str = ..., + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> ComplexDateTimeField[datetime, datetime]: ... + @overload + def __new__( + cls, + separator: str = ..., + *, + required: Literal[True], + default: Union[datetime, Callable[[], datetime]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> ComplexDateTimeField[Optional[datetime], datetime]: ... + +class GenericEmbeddedDocumentField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> GenericEmbeddedDocumentField[Optional[Any], Optional[Any]]: ... + @overload + def __new__( + cls, + *, + required: Literal[False] = ..., + default: Union[Any, Callable[[], Any]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> GenericEmbeddedDocumentField[Optional[Any], Any]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> GenericEmbeddedDocumentField[Any, Any]: ... + @overload + def __new__( + cls, + *, + required: Literal[True], + default: Union[Any, Callable[[], Any]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> GenericEmbeddedDocumentField[Optional[Any], Any]: ... + +class SortedListField(ListField[_F, _ST, _GT]): + @overload + def __new__( + cls, + field: BaseField[_ST, _GT], + required: bool = ..., + default: Optional[Union[List[Any], Callable[[], List[Any]]]] = ..., + ) -> SortedListField[BaseField[_ST, _GT], list[_ST], list[_GT]]: ... + @overload + def __new__( + cls, + field: Any | None, + required: bool = ..., + default: Optional[Union[List[Any], Callable[[], List[Any]]]] = ..., + ) -> SortedListField[Any, Any, Any]: ... + +class CachedReferenceField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + document_type: Type[_T] | str, + required: Literal[False] = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> CachedReferenceField[_T | None, _T | None]: ... + @overload + def __new__( + cls, + document_type: Type[_T] | str, + required: Literal[True] = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> CachedReferenceField[_T, _T]: ... + +class GenericLazyReferenceField(BaseField[LazyReference[Any], LazyReference[Any]]): + def __init__( + self, *args: Any, passthrough: bool = False, **kwargs: Unpack[_BaseFieldOptions] + ): ... + +class BinaryField(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + max_bytes: int | None = ..., + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> BinaryField[Optional[bytes], Optional[bytes]]: ... + @overload + def __new__( + cls, + max_bytes: int | None = ..., + *, + required: Literal[False] = ..., + default: Union[bytes, Callable[[], bytes]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> BinaryField[Optional[bytes], bytes]: ... + @overload + def __new__( + cls, + max_bytes: int | None = ..., + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> BinaryField[bytes, bytes]: ... + @overload + def __new__( + cls, + max_bytes: int | None = ..., + *, + required: Literal[True], + default: Union[bytes, Callable[[], bytes]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> BinaryField[Optional[bytes], bytes]: ... + +class GridFSError(Exception): ... +class GridFSProxy: ... + +class FileField(BaseField[_ST, _GT]): + def __new__( + cls, + db_alias: str = ..., + collection_name: str = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> FileField[Any, Any]: ... + +class ImageGridFsProxy(GridFSProxy): + def put(self, file_obj, **kwargs): ... + def delete(self, *args, **kwargs): ... + @property + def size(self): ... + @property + def format(self): ... + @property + def thumbnail(self): ... + def write(self, *args, **kwargs) -> None: ... + def writelines(self, *args, **kwargs) -> None: ... + +class ImproperlyConfigured(Exception): ... +class PointField(GeoJsonBaseField): ... +class LineStringField(GeoJsonBaseField): ... +class PolygonField(GeoJsonBaseField): ... + +class SequenceField(BaseField[Any, Any]): + def __init__( + self, + collection_name: str | None = ..., + db_alias: str | None = ..., + sequence_name: str | None = ..., + value_decorator: Any | None = ..., + *args: Any, + **kwargs: Unpack[_BaseFieldOptions], + ): ... + +class MultiPointField(GeoJsonBaseField): ... +class MultiLineStringField(GeoJsonBaseField): ... +class MultiPolygonField(GeoJsonBaseField): ... + +class Decimal128Field(BaseField[_ST, _GT]): + @overload + def __new__( + cls, + min_value: int | None = ..., + max_value: int | None = ..., + *, + required: Literal[False] = ..., + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> Decimal128Field[Optional[Decimal], Optional[Decimal]]: ... + @overload + def __new__( + cls, + min_value: int | None = ..., + max_value: int | None = ..., + *, + required: Literal[False] = ..., + default: Union[Decimal, Callable[[], Decimal]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> Decimal128Field[Optional[Decimal], Decimal]: ... + @overload + def __new__( + cls, + min_value: int | None = ..., + max_value: int | None = ..., + *, + required: Literal[True], + default: None = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> Decimal128Field[Decimal, Decimal]: ... + @overload + def __new__( + cls, + min_value: int | None = ..., + max_value: int | None = ..., + *, + required: Literal[True], + default: Union[Decimal, Callable[[], Decimal]], + **kwargs: Unpack[_BaseFieldOptions], + ) -> Decimal128Field[Optional[Decimal], Decimal]: ... + +class ImageField(FileField[_ST, _GT]): + def __new__( + cls, + size: tuple[int, int, bool] | None = ..., + thumbnail_size: tuple[int, int, bool] | None = ..., + collection_name: str = ..., + db_alias: str = ..., + **kwargs: Unpack[_BaseFieldOptions], + ) -> ImageField[Any, Any]: ... + +class GenericReferenceField(BaseField[Any, Any]): + def __init__(self, *args: Any, **kwargs: Unpack[_BaseFieldOptions]): ... diff --git a/mongoengine/py.typed b/mongoengine/py.typed new file mode 100644 index 000000000..b648ac923 --- /dev/null +++ b/mongoengine/py.typed @@ -0,0 +1 @@ +partial diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 2db97ddb7..9bdf38345 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1,16 +1,33 @@ +# mypy: disable-error-code="attr-defined,union-attr,assignment,misc,arg-type,var-annotated,list-item,return-value,has-type" +from __future__ import annotations + import copy import itertools import re import warnings from collections.abc import Mapping +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterable, + Iterator, + List, + Tuple, + TypeVar, + Union, + overload, +) import pymongo import pymongo.errors from bson import SON, json_util from bson.code import Code -from pymongo.collection import ReturnDocument +from pymongo.collation import Collation +from pymongo.collection import Collection, ReturnDocument from pymongo.common import validate_read_preference from pymongo.read_concern import ReadConcern +from typing_extensions import Literal, Self, TypedDict from mongoengine import signals from mongoengine.base import _DocumentRegistry @@ -37,8 +54,16 @@ from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode +if TYPE_CHECKING: + from pymongo.read_preferences import _ServerMode + + from mongoengine.document import Document + __all__ = ("BaseQuerySet", "DO_NOTHING", "NULLIFY", "CASCADE", "DENY", "PULL") +_T = TypeVar("_T", bound="Document") +_U = TypeVar("_U", bound="BaseQuerySet[Any]") + # Delete rules DO_NOTHING = 0 NULLIFY = 1 @@ -47,12 +72,12 @@ PULL = 4 -class BaseQuerySet: +class BaseQuerySet(Generic[_T]): """A set of results returned from a query. Wraps a MongoDB cursor, providing :class:`~mongoengine.Document` objects as the results. """ - def __init__(self, document, collection): + def __init__(self, document: type[_T], collection: Collection[Any]) -> None: self._document = document self._collection_obj = collection self._mongo_query = None @@ -162,7 +187,13 @@ def __setstate__(self, obj_dict): # forse load cursor # self._cursor - def __getitem__(self, key): + @overload + def __getitem__(self, key: int) -> _T: ... # noqa: E704 + + @overload + def __getitem__(self, key: slice) -> Self: ... # noqa: E704 + + def __getitem__(self, key: int | slice) -> _T | Self: """Return a document instance corresponding to a given index if the key is an integer. If the key is a slice, translate its bounds into a skip and a limit, and return a cloned queryset @@ -208,7 +239,7 @@ def __getitem__(self, key): raise TypeError("Provide a slice or an integer index") - def __iter__(self): + def __iter__(self) -> Iterator[_T]: raise NotImplementedError def _has_data(self): @@ -222,11 +253,11 @@ def __bool__(self): # Core functions - def all(self): + def all(self) -> Self: """Returns a copy of the current QuerySet.""" return self.__call__() - def filter(self, *q_objs, **query): + def filter(self, *q_objs: Q, **query: Any) -> Self: """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__`""" return self.__call__(*q_objs, **query) @@ -259,7 +290,7 @@ def search_text(self, text, language=None, text_score=True): return queryset - def get(self, *q_objs, **query): + def get(self, *q_objs: Q, **query: Any) -> _T: """Retrieve the matching object raising :class:`~mongoengine.queryset.MultipleObjectsReturned` or `DocumentName.MultipleObjectsReturned` exception if multiple results @@ -287,11 +318,11 @@ def get(self, *q_objs, **query): "2 or more items returned, instead of 1" ) - def create(self, **kwargs): + def create(self, **kwargs: Any) -> _T: """Create new object. Returns the saved object instance.""" return self._document(**kwargs).save(force_insert=True) - def first(self): + def first(self) -> _T | None: """Retrieve the first object matching the query.""" queryset = self.clone() if self._none or self._empty: @@ -304,8 +335,12 @@ def first(self): return result def insert( - self, doc_or_docs, load_bulk=True, write_concern=None, signal_kwargs=None - ): + self, + doc_or_docs: Iterable[_T] | _T, + load_bulk: bool = True, + write_concern: _ReadWriteConcern | None = None, + signal_kwargs: Any | None = None, + ) -> list[_T]: """bulk insert documents :param doc_or_docs: a document or list of documents to be inserted @@ -397,7 +432,7 @@ def insert( ) return results[0] if return_one else results - def count(self, with_limit_and_skip=False): + def count(self, with_limit_and_skip: bool = False) -> int: """Count the selected elements in the query. :param with_limit_and_skip (optional): take any :meth:`limit` or @@ -696,12 +731,13 @@ def update_one( def modify( self, - upsert=False, - remove=False, - new=False, + upsert: bool = False, + full_response: bool = False, + remove: bool = False, + new: bool = False, array_filters=None, - **update, - ): + **update: Any, + ) -> Self | None: """Update and return the updated document. Returns either the document before or after modification based on `new` @@ -806,7 +842,7 @@ def in_bulk(self, object_ids): return doc_map - def none(self): + def none(self) -> Self: """Returns a queryset that never returns any objects and no query will be executed when accessing the results inspired by django none() https://docs.djangoproject.com/en/dev/ref/models/querysets/#none """ @@ -836,11 +872,11 @@ def using(self, alias): return self._clone_into(self.__class__(self._document, collection)) - def clone(self): + def clone(self) -> Self: """Create a copy of the current queryset.""" return self._clone_into(self.__class__(self._document, self._collection_obj)) - def _clone_into(self, new_qs): + def _clone_into(self, new_qs: _U) -> _U: """Copy all the relevant properties of this queryset to a new queryset (which has to be an instance of :class:`~mongoengine.queryset.base.BaseQuerySet`). @@ -899,7 +935,7 @@ def select_related(self, max_depth=1): queryset = self.clone() return queryset._dereference(queryset, max_depth=max_depth) - def limit(self, n): + def limit(self, n: int) -> Self: """Limit the number of returned documents to `n`. This may also be achieved using array-slicing syntax (e.g. ``User.objects[:5]``). @@ -916,7 +952,7 @@ def limit(self, n): return queryset - def skip(self, n): + def skip(self, n: int | None) -> Self: """Skip `n` documents before returning the results. This may also be achieved using array-slicing syntax (e.g. ``User.objects[5:]``). @@ -931,7 +967,7 @@ def skip(self, n): return queryset - def hint(self, index=None): + def hint(self, index: _Hint | None = None) -> Self: """Added 'hint' support, telling Mongo the proper index to use for the query. @@ -951,7 +987,7 @@ def hint(self, index=None): return queryset - def collation(self, collation=None): + def collation(self, collation: _Collation | None = None) -> Self: """ Collation allows users to specify language-specific rules for string comparison, such as rules for lettercase and accent marks. @@ -977,7 +1013,7 @@ def collation(self, collation=None): return queryset - def batch_size(self, size): + def batch_size(self, size: int) -> Self: """Limit the number of documents returned in a single batch (each batch requires a round trip to the server). @@ -995,7 +1031,7 @@ def batch_size(self, size): return queryset - def distinct(self, field): + def distinct(self, field: str) -> list[Any]: """Return a list of distinct values for a given field. :param field: the field to select distinct values from @@ -1053,7 +1089,7 @@ def distinct(self, field): return distinct - def only(self, *fields): + def only(self, *fields: str) -> Self: """Load only a subset of this document's fields. :: post = BlogPost.objects(...).only('title', 'author.name') @@ -1071,7 +1107,7 @@ def only(self, *fields): fields = {f: QueryFieldList.ONLY for f in fields} return self.fields(True, **fields) - def exclude(self, *fields): + def exclude(self, *fields: str) -> Self: """Opposite to .only(), exclude some document's fields. :: post = BlogPost.objects(...).exclude('comments') @@ -1089,7 +1125,7 @@ def exclude(self, *fields): fields = {f: QueryFieldList.EXCLUDE for f in fields} return self.fields(**fields) - def fields(self, _only_called=False, **kwargs): + def fields(self, _only_called: bool = False, **kwargs: Any) -> Self: """Manipulate how you load this document's fields. Used by `.only()` and `.exclude()` to manipulate which fields to retrieve. If called directly, use a set of kwargs similar to the MongoDB projection @@ -1147,7 +1183,7 @@ def _sort_key(field_tuple): return queryset - def all_fields(self): + def all_fields(self) -> Self: """Include all fields. Reset all previously calls of .only() or .exclude(). :: @@ -1159,7 +1195,7 @@ def all_fields(self): ) return queryset - def order_by(self, *keys, __raw__=None): + def order_by(self, *keys: str, __raw__=None) -> Self: """Order the :class:`~mongoengine.queryset.QuerySet` by the given keys. The order may be specified by prepending each of the keys by a "+" or @@ -1212,7 +1248,7 @@ def clear_cls_query(self): queryset._cls_query = {} return queryset - def comment(self, text): + def comment(self, text: str) -> Self: """Add a comment to the query. See https://www.mongodb.com/docs/manual/reference/method/cursor.comment/ @@ -1220,7 +1256,7 @@ def comment(self, text): """ return self._chainable_method("comment", text) - def explain(self): + def explain(self) -> _ExplainCursor: """Return an explain plan record for the :class:`~mongoengine.queryset.QuerySet` cursor. """ @@ -1248,7 +1284,7 @@ def allow_disk_use(self, enabled): queryset._allow_disk_use = enabled return queryset - def timeout(self, enabled): + def timeout(self, enabled: bool) -> Self: """Enable or disable the default mongod timeout when querying. (no_cursor_timeout option) :param enabled: whether or not the timeout is used @@ -1257,7 +1293,7 @@ def timeout(self, enabled): queryset._timeout = enabled return queryset - def read_preference(self, read_preference): + def read_preference(self, read_preference: _ServerMode) -> Self: """Change the read_preference when querying. :param read_preference: override ReplicaSetConnection-level @@ -1285,7 +1321,7 @@ def read_concern(self, read_concern): queryset._cursor_obj = None # we need to re-create the cursor object whenever we apply read_concern return queryset - def scalar(self, *fields): + def scalar(self, *fields) -> list[Any]: """Instead of returning Document instances, return either a specific value or a tuple of values in order. @@ -1308,11 +1344,11 @@ def scalar(self, *fields): return queryset - def values_list(self, *fields): + def values_list(self, *fields: str) -> list[Any]: """An alias for scalar""" return self.scalar(*fields) - def as_pymongo(self): + def as_pymongo(self) -> BaseQuerySet[dict[str, Any]]: # type: ignore """Instead of returning Document instances, return raw values from pymongo. @@ -1321,9 +1357,9 @@ def as_pymongo(self): """ queryset = self.clone() queryset._as_pymongo = True - return queryset + return queryset # type: ignore - def max_time_ms(self, ms): + def max_time_ms(self, ms: int | None) -> Self: """Wait `ms` milliseconds before killing the query on the server :param ms: the number of milliseconds before killing the query on the server @@ -1332,7 +1368,7 @@ def max_time_ms(self, ms): # JSON Helpers - def to_json(self, *args, **kwargs): + def to_json(self, *args: Any, **kwargs: Any) -> str: """Converts a queryset to JSON""" if "json_options" not in kwargs: warnings.warn( @@ -1347,7 +1383,7 @@ def to_json(self, *args, **kwargs): kwargs["json_options"] = LEGACY_JSON_OPTIONS return json_util.dumps(self.as_pymongo(), *args, **kwargs) - def from_json(self, json_data): + def from_json(self, json_data: str) -> list[_T]: """Converts json data to unsaved objects""" son_data = json_util.loads(json_data) return [self._document._from_son(data) for data in son_data] @@ -1551,7 +1587,7 @@ def map_reduce( queryset._document, queryset._collection, doc["_id"], doc["value"] ) - def exec_js(self, code, *fields, **options): + def exec_js(self, code: str, *fields: str, **options: Any) -> Self: """Execute a Javascript function on the server. A list of fields may be provided, which will be translated to their correct names and supplied as the arguments to the function. A few extra variables are added to @@ -1606,7 +1642,7 @@ def where(self, where_clause): queryset._where_clause = where_clause return queryset - def sum(self, field): + def sum(self, field: str) -> float: """Sum over the values of the specified field. :param field: the field to sum over; use dot notation to refer to @@ -1634,7 +1670,7 @@ def sum(self, field): return result[0]["total"] return 0 - def average(self, field): + def average(self, field: str) -> Self: """Average over the values of the specified field. :param field: the field to average over; use dot notation to refer to @@ -1687,7 +1723,7 @@ def item_frequencies(self, field, normalize=False, map_reduce=True): # Iterator helpers - def __next__(self): + def __next__(self) -> _T: """Wrap the result in a :class:`~mongoengine.Document` object.""" if self._none or self._empty: raise StopIteration @@ -1707,7 +1743,7 @@ def __next__(self): return doc - def rewind(self): + def rewind(self) -> None: """Rewind the cursor to its unevaluated state.""" self._iter = False self._cursor.rewind() @@ -1836,7 +1872,7 @@ def _auto_dereference(self): should_deref = not no_dereferencing_active_for_class(self._document) return should_deref and self.__auto_dereference - def no_dereference(self): + def no_dereference(self) -> Self: """Turn off any dereferencing for the results of this queryset.""" queryset = self.clone() queryset.__auto_dereference = False @@ -2086,3 +2122,41 @@ def _chainable_method(self, method_name, val): setattr(queryset, "_" + method_name, val) return queryset + + +_Hint = Union[str, List[Tuple[str, Literal[-1, 1]]]] +_ReadWriteConcern = Mapping[str, Union[str, int, bool]] +_Collation = Union[Collation, Mapping[str, Union[bool, int, str, None]]] + + +class _ExecutionStats(TypedDict): + allPlansExecution: list[Any] + executionStages: dict[str, Any] + executionSuccess: bool + executionTimeMillis: int + nReturned: int + totalDocsExamined: int + totalKeysExamined: int + + +class _QueryPlanner(TypedDict): + indexFilterSet: bool + namespace: str + parsedQuery: dict[str, Any] + plannerVersion: int + rejectedPlans: list[Any] + winningPlan: dict[str, Any] + + +class _ServerInfo(TypedDict): + gitVersion: str + host: str + port: int + version: str + + +class _ExplainCursor(TypedDict): + executionStats: _ExecutionStats + ok: float + queryPlanner: _QueryPlanner + serverInfo: _ServerInfo diff --git a/mongoengine/queryset/manager.py b/mongoengine/queryset/manager.py index 46f137a27..3c0d6e4a9 100644 --- a/mongoengine/queryset/manager.py +++ b/mongoengine/queryset/manager.py @@ -1,11 +1,13 @@ from functools import partial +from typing import Generic +from mongoengine._typing import QS from mongoengine.queryset.queryset import QuerySet __all__ = ("queryset_manager", "QuerySetManager") -class QuerySetManager: +class QuerySetManager(Generic[QS]): """ The default QuerySet Manager. @@ -25,13 +27,13 @@ def __init__(self, queryset_func=None): if queryset_func: self.get_queryset = queryset_func - def __get__(self, instance, owner): + def __get__(self, instance, owner) -> QS: """Descriptor for instantiating a new QuerySet object when Document.objects is accessed. """ if instance is not None: # Document object being used rather than a document class - return self + return self # type: ignore[return-value] # owner is the document that contains the QuerySetManager queryset_class = owner._meta.get("queryset_class", self.default) diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index e0f7765b9..be19fa64c 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -1,3 +1,8 @@ +# mypy: disable-error-code="call-overload,arg-type,return-value" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Iterator, TypeVar + from mongoengine.errors import OperationError from mongoengine.queryset.base import ( CASCADE, @@ -8,6 +13,9 @@ BaseQuerySet, ) +if TYPE_CHECKING: + from mongoengine.document import Document + __all__ = ( "QuerySet", "QuerySetNoCache", @@ -18,12 +26,15 @@ "PULL", ) +_T = TypeVar("_T", bound="Document") +_U = TypeVar("_U", bound="QuerySet[Any]") + # The maximum number of items to display in a QuerySet.__repr__ REPR_OUTPUT_SIZE = 20 ITER_CHUNK_SIZE = 100 -class QuerySet(BaseQuerySet): +class QuerySet(BaseQuerySet[_T]): """The default queryset, that builds queries and handles a set of results returned from a query. @@ -35,7 +46,7 @@ class QuerySet(BaseQuerySet): _len = None _result_cache = None - def __iter__(self): + def __iter__(self) -> Iterator[_T]: """Iteration utilises a results cache which iterates the cursor in batches of ``ITER_CHUNK_SIZE``. @@ -50,7 +61,7 @@ def __iter__(self): # iterating over the cache. return iter(self._result_cache) - def __len__(self): + def __len__(self) -> int: """Since __len__ is called quite frequently (for example, as part of list(qs)), we populate the result cache and cache the length. """ @@ -132,7 +143,7 @@ def _populate_cache(self): # information in other places. self._has_more = False - def count(self, with_limit_and_skip=False): + def count(self, with_limit_and_skip: bool = False) -> int: """Count the selected elements in the query. :param with_limit_and_skip (optional): take any :meth:`limit` or @@ -148,7 +159,7 @@ def count(self, with_limit_and_skip=False): return self._len - def no_cache(self): + def no_cache(self) -> QuerySetNoCache[_T]: """Convert to a non-caching queryset""" if self._result_cache is not None: raise OperationError("QuerySet already cached") @@ -156,7 +167,7 @@ def no_cache(self): return self._clone_into(QuerySetNoCache(self._document, self._collection)) -class QuerySetNoCache(BaseQuerySet): +class QuerySetNoCache(BaseQuerySet[_T]): """A non caching QuerySet""" def cache(self): diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 9e26d4e83..f86ccb0e6 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -168,7 +168,7 @@ class Q(QNode): query structures. """ - def __init__(self, **query): + def __init__(self, **query) -> None: self.query = query def __repr__(self): diff --git a/mongoengine/signals.py b/mongoengine/signals.py index 940209a57..71936364b 100644 --- a/mongoengine/signals.py +++ b/mongoengine/signals.py @@ -15,7 +15,7 @@ signals_available = True except ImportError: - class Namespace: + class Namespace: # type: ignore[no-redef] def signal(self, name, doc=None): return _FakeSignal(name, doc) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..0e7251bba --- /dev/null +++ b/mypy.ini @@ -0,0 +1,13 @@ +[mypy] +show_column_numbers=True +pretty=True +show_error_codes=True +show_error_context=True + +disallow_any_expr=False +disallow_any_explicit=False + +# We're using __init__ together with __new__, but this is broken in mypy +# Disable error code until fixed. +# https://github.com/python/mypy/issues/17251 +disable_error_code = var-annotated diff --git a/py.typed b/py.typed new file mode 100644 index 000000000..b648ac923 --- /dev/null +++ b/py.typed @@ -0,0 +1 @@ +partial diff --git a/pylsp-mypy.cfg b/pylsp-mypy.cfg new file mode 100644 index 000000000..94099ecb4 --- /dev/null +++ b/pylsp-mypy.cfg @@ -0,0 +1,5 @@ +{ + "enabled": True, + "live_mode": True, + "overrides": ["--no-pretty", True] +} diff --git a/setup.cfg b/setup.cfg index aa965c8f8..e93b6ce2c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,6 +7,8 @@ max-complexity=47 # Limits the discovery to tests directory # avoids that it runs for instance the benchmark testpaths = tests +addopts = + --mypy-only-local-stub [isort] known_first_party = mongoengine,tests diff --git a/setup.py b/setup.py index a629d5b12..a2859099e 100644 --- a/setup.py +++ b/setup.py @@ -42,10 +42,14 @@ def get_version(version_tuple): "Topic :: Software Development :: Libraries :: Python Modules", ] -install_require = ["pymongo>=3.4,<5.0"] +install_require = [ + "pymongo>=3.4,<5.0", + "typing_extensions>=4.1", +] tests_require = [ "pytest", "pytest-cov", + "pytest-mypy-plugins", "coverage", "blinker", "Pillow>=7.0.0", @@ -62,6 +66,9 @@ def get_version(version_tuple): download_url="https://github.com/MongoEngine/mongoengine/tarball/master", license="MIT", include_package_data=True, + package_data={ + "mongoengine": ["py.typed"], + }, description=DESCRIPTION, long_description=LONG_DESCRIPTION, platforms=["any"], diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index c5970420f..b726ef6fe 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -11,8 +11,46 @@ from bson import DBRef, ObjectId from pymongo.errors import DuplicateKeyError -from mongoengine import * -from mongoengine import signals +from mongoengine import ( + CASCADE, + DENY, + PULL, + BooleanField, + ComplexDateTimeField, + DateTimeField, + DecimalField, + DictField, + Document, + DoesNotExist, + DynamicDocument, + DynamicEmbeddedDocument, + DynamicField, + EmailField, + EmbeddedDocument, + EmbeddedDocumentField, + EmbeddedDocumentListField, + FileField, + FloatField, + GenericEmbeddedDocumentField, + GenericReferenceField, + GeoPointField, + IntField, + InvalidCollectionError, + LazyReferenceField, + ListField, + MapField, + ObjectIdField, + OperationError, + ReferenceField, + SequenceField, + SortedListField, + StringField, + URLField, + UUIDField, + ValidationError, + register_connection, + signals, +) from mongoengine.base import _DocumentRegistry from mongoengine.connection import get_db from mongoengine.context_managers import query_counter, switch_db diff --git a/tests/fields/test_date_field.py b/tests/fields/test_date_field.py index a98f222ad..f0e1fa4d1 100644 --- a/tests/fields/test_date_field.py +++ b/tests/fields/test_date_field.py @@ -5,9 +5,9 @@ try: import dateutil except ImportError: - dateutil = None + dateutil = None # type: ignore[assignment] -from mongoengine import * +from mongoengine import DateField, Document, ValidationError from tests.utils import MongoDBTestCase diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py index d04f39b04..54fca814d 100644 --- a/tests/fields/test_datetime_field.py +++ b/tests/fields/test_datetime_field.py @@ -2,14 +2,19 @@ import pytest -from mongoengine import * -from mongoengine import connection +from mongoengine import ( + DateTimeField, + Document, + ValidationError, + connect, + connection, +) from tests.utils import MongoDBTestCase, get_as_pymongo try: import dateutil except ImportError: - dateutil = None + dateutil = None # type: ignore[assignment] class TestDateTimeField(MongoDBTestCase): diff --git a/tests/fields/test_decimal128_field.py b/tests/fields/test_decimal128_field.py index 6aa2ec23e..02633c9de 100644 --- a/tests/fields/test_decimal128_field.py +++ b/tests/fields/test_decimal128_field.py @@ -1,6 +1,7 @@ import json import random from decimal import Decimal +from typing import Type import pytest from bson.decimal128 import Decimal128 @@ -15,7 +16,7 @@ class Decimal128Document(Document): dec128_max_100 = Decimal128Field(max_value=100) -def generate_test_cls() -> Document: +def generate_test_cls() -> Type[Document]: Decimal128Document.drop_collection() Decimal128Document(dec128_fld=None).save() Decimal128Document(dec128_fld=Decimal(1)).save() diff --git a/tests/fields/test_uuid_field.py b/tests/fields/test_uuid_field.py index ec81033b0..3eb3c9704 100644 --- a/tests/fields/test_uuid_field.py +++ b/tests/fields/test_uuid_field.py @@ -2,7 +2,7 @@ import pytest -from mongoengine import * +from mongoengine import Document, UUIDField, ValidationError from tests.utils import MongoDBTestCase, get_as_pymongo diff --git a/tests/fixtures.py b/tests/fixtures.py index ef82c22af..c9f44662b 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,8 +1,19 @@ import pickle from datetime import datetime -from mongoengine import * -from mongoengine import signals +from mongoengine import ( + DateTimeField, + Document, + DynamicDocument, + DynamicEmbeddedDocument, + EmbeddedDocument, + EmbeddedDocumentField, + FileField, + IntField, + ListField, + StringField, + signals, +) class PickleEmbedded(EmbeddedDocument): diff --git a/tests/queryset/test_geo.py b/tests/queryset/test_geo.py index e87d27aea..bf39b7346 100644 --- a/tests/queryset/test_geo.py +++ b/tests/queryset/test_geo.py @@ -1,7 +1,17 @@ import datetime import unittest -from mongoengine import * +from mongoengine import ( + DateTimeField, + Document, + EmbeddedDocument, + EmbeddedDocumentField, + GeoPointField, + LineStringField, + PointField, + PolygonField, + StringField, +) from mongoengine.pymongo_support import PYMONGO_VERSION from tests.utils import MongoDBTestCase diff --git a/tests/queryset/test_modify.py b/tests/queryset/test_modify.py index b96e05e63..4f9071af0 100644 --- a/tests/queryset/test_modify.py +++ b/tests/queryset/test_modify.py @@ -10,7 +10,7 @@ class Doc(Document): - id = IntField(primary_key=True) + id = IntField(primary_key=True) # type: ignore[assignment] value = IntField() diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 8386249f2..873669ba8 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -9,7 +9,44 @@ from pymongo.read_preferences import ReadPreference from pymongo.results import UpdateResult -from mongoengine import * +from mongoengine import ( + CASCADE, + DENY, + NULLIFY, + PULL, + BooleanField, + ComplexDateTimeField, + DateTimeField, + DecimalField, + DictField, + Document, + DynamicDocument, + DynamicField, + EmailField, + EmbeddedDocument, + EmbeddedDocumentField, + EmbeddedDocumentListField, + FloatField, + GenericEmbeddedDocumentField, + GenericReferenceField, + GeoPointField, + IntField, + ListField, + MapField, + NotUniqueError, + ObjectIdField, + OperationError, + QuerySetNoCache, + ReferenceField, + SequenceField, + SortedListField, + StringField, + URLField, + UUIDField, + ValidationError, + connect, + register_connection, +) from mongoengine.connection import get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError diff --git a/tests/queryset/test_transform.py b/tests/queryset/test_transform.py index 8704187b8..b3edecccf 100644 --- a/tests/queryset/test_transform.py +++ b/tests/queryset/test_transform.py @@ -1,9 +1,12 @@ +# pyright: reportOptionalMemberAccess=false import unittest import pytest from bson.son import SON -from mongoengine import * +from mongoengine import fields +from mongoengine.document import Document, EmbeddedDocument +from mongoengine.errors import InvalidQueryError from mongoengine.queryset import Q, transform from tests.utils import MongoDBTestCase @@ -38,10 +41,10 @@ def test_transform_query(self): def test_transform_update(self): class LisDoc(Document): - foo = ListField(StringField()) + foo = fields.ListField(fields.StringField()) class DicDoc(Document): - dictField = DictField() + dictField = fields.DictField() class Doc(Document): pass @@ -75,7 +78,7 @@ def test_transform_update_push(self): """Ensure the differences in behvaior between 'push' and 'push_all'""" class BlogPost(Document): - tags = ListField(StringField()) + tags = fields.ListField(fields.StringField()) update = transform.update(BlogPost, push__tags=["mongo", "db"]) assert update == {"$push": {"tags": ["mongo", "db"]}} @@ -87,7 +90,7 @@ def test_transform_update_no_operator_default_to_set(self): """Ensure the differences in behvaior between 'push' and 'push_all'""" class BlogPost(Document): - tags = ListField(StringField()) + tags = fields.ListField(fields.StringField()) update = transform.update(BlogPost, tags=["mongo", "db"]) assert update == {"$set": {"tags": ["mongo", "db"]}} @@ -96,12 +99,12 @@ def test_query_field_name(self): """Ensure that the correct field name is used when querying.""" class Comment(EmbeddedDocument): - content = StringField(db_field="commentContent") + content = fields.StringField(db_field="commentContent") class BlogPost(Document): - title = StringField(db_field="postTitle") - comments = ListField( - EmbeddedDocumentField(Comment), db_field="postComments" + title = fields.StringField(db_field="postTitle") + comments = fields.ListField( + fields.EmbeddedDocumentField(Comment), db_field="postComments" ) BlogPost.drop_collection() @@ -130,7 +133,7 @@ def test_query_pk_field_name(self): """ class BlogPost(Document): - title = StringField(primary_key=True, db_field="postTitle") + title = fields.StringField(primary_key=True, db_field="postTitle") BlogPost.drop_collection() @@ -149,7 +152,7 @@ class A(Document): pass class B(Document): - a = ReferenceField(A) + a = fields.ReferenceField(A) A.drop_collection() B.drop_collection() @@ -174,10 +177,10 @@ def test_raw_query_and_Q_objects(self): """ class Foo(Document): - name = StringField() - a = StringField() - b = StringField() - c = StringField() + name = fields.StringField() + a = fields.StringField() + b = fields.StringField() + c = fields.StringField() meta = {"allow_inheritance": False} @@ -211,7 +214,7 @@ class Doc(Document): def test_geojson_PointField(self): class Location(Document): - loc = PointField() + loc = fields.PointField() update = transform.update(Location, set__loc=[1, 2]) assert update == {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} @@ -223,7 +226,7 @@ class Location(Document): def test_geojson_LineStringField(self): class Location(Document): - line = LineStringField() + line = fields.LineStringField() update = transform.update(Location, set__line=[[1, 2], [2, 2]]) assert update == { @@ -239,7 +242,7 @@ class Location(Document): def test_geojson_PolygonField(self): class Location(Document): - poly = PolygonField() + poly = fields.PolygonField() update = transform.update( Location, set__poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]] @@ -271,7 +274,7 @@ class Location(Document): def test_type(self): class Doc(Document): - df = DynamicField() + df = fields.DynamicField() Doc(df=True).save() Doc(df=7).save() @@ -283,11 +286,11 @@ class Doc(Document): def test_embedded_field_name_like_operator(self): class EmbeddedItem(EmbeddedDocument): - type = StringField() - name = StringField() + type = fields.StringField() + name = fields.StringField() class Doc(Document): - item = EmbeddedDocumentField(EmbeddedItem) + item = fields.EmbeddedDocumentField(EmbeddedItem) Doc.drop_collection() @@ -303,8 +306,8 @@ class Doc(Document): def test_regular_field_named_like_operator(self): class SimpleDoc(Document): - size = StringField() - type = StringField() + size = fields.StringField() + type = fields.StringField() SimpleDoc.drop_collection() SimpleDoc(type="ok", size="ok").save() @@ -327,8 +330,8 @@ class SimpleDoc(Document): def test_understandable_error_raised(self): class Event(Document): - title = StringField() - location = GeoPointField() + title = fields.StringField() + location = fields.GeoPointField() box = [(35.0, -125.0), (40.0, -100.0)] # I *meant* to execute location__within_box=box @@ -343,16 +346,16 @@ def test_update_pull_for_list_fields(self): """ class Word(EmbeddedDocument): - word = StringField() - index = IntField() + word = fields.StringField() + index = fields.IntField() class SubDoc(EmbeddedDocument): - heading = ListField(StringField()) - text = EmbeddedDocumentListField(Word) + heading = fields.ListField(fields.StringField()) + text = fields.EmbeddedDocumentListField(Word) class MainDoc(Document): - title = StringField() - content = EmbeddedDocumentField(SubDoc) + title = fields.StringField() + content = fields.EmbeddedDocumentField(SubDoc) word = Word(word="abc", index=1) update = transform.update(MainDoc, pull__content__text=word) @@ -378,11 +381,11 @@ def test_transform_embedded_document_list_fields(self): """ class Drink(EmbeddedDocument): - id = StringField() + id = fields.StringField() meta = {"strict": False} class Shop(Document): - drinks = EmbeddedDocumentListField(Drink) + drinks = fields.EmbeddedDocumentListField(Drink) Shop.drop_collection() drinks = [Drink(id="drink_1"), Drink(id="drink_2")] diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 1333f5574..25e747687 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -8,7 +8,7 @@ import pytest from bson import DBRef -from mongoengine import * +from mongoengine import Document, connect, register_connection from mongoengine.connection import _get_session, get_db from mongoengine.context_managers import ( no_dereference, @@ -20,6 +20,14 @@ switch_collection, switch_db, ) +from mongoengine.errors import OperationError +from mongoengine.fields import ( + GenericReferenceField, + IntField, + ListField, + ReferenceField, + StringField, +) from mongoengine.pymongo_support import count_documents from tests.utils import ( MongoDBTestCase, diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 000000000..2d780ccf1 --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,107 @@ +# mypy: enable-error-code="var-annotated" +from datetime import date, datetime +from decimal import Decimal +from enum import Enum +from typing import Any, Optional +from uuid import UUID + +from bson import ObjectId +from typing_extensions import assert_type + +from mongoengine import Document, EmbeddedDocument, fields +from mongoengine.base.datastructures import LazyReference + + +def test_it_uses_correct_types() -> None: + + class ImageEmbedded(EmbeddedDocument): + pass + + class ImageDocument(Document): + pass + + class Color(Enum): + RED = "red" + + class Doc(Document): + stringfield = fields.StringField() + urlfield = fields.URLField() + emailfield = fields.EmailField() + intfield = fields.IntField() + floatfield = fields.FloatField() + decimalfield = fields.DecimalField() + booleanfield = fields.BooleanField() + datetimefield = fields.DateTimeField() + datefield = fields.DateField() + complexdatetimefield = fields.ComplexDateTimeField() + embeddeddocumentfield = fields.EmbeddedDocumentField(ImageEmbedded) + objectidfield = fields.ObjectIdField() + genericembeddeddocumentfield = fields.GenericEmbeddedDocumentField() + dynamicfield = fields.DynamicField() + listfield = fields.ListField(fields.StringField()) + sortedlistfield = fields.SortedListField(fields.StringField()) + embeddeddocumentlistfield = fields.EmbeddedDocumentListField(ImageEmbedded) + dictfield = fields.DictField(fields.StringField(required=True)) + mapfield = fields.MapField(fields.StringField()) + referencefield = fields.ReferenceField(ImageDocument) + cachedreferencefield = fields.CachedReferenceField(ImageDocument) + lazyreferencefield = fields.LazyReferenceField(ImageDocument) + genericlazyreferencefield = fields.GenericLazyReferenceField() + genericreferencefield = fields.GenericReferenceField() + binaryfield = fields.BinaryField() + filefield = fields.FileField() + imagefield = fields.ImageField() + geopointfield = fields.GeoPointField() + pointfield = fields.PointField() + linestringfield = fields.LineStringField() + polygonfield = fields.PolygonField() + sequencefield = fields.SequenceField() + uuidfield = fields.UUIDField() + enumfield = fields.EnumField(Color) + multipointfield = fields.MultiPointField() + multilinestringfield = fields.MultiLineStringField() + multipolygonfield = fields.MultiPolygonField() + decimal128field = fields.Decimal128Field() + + # Setting sequencefield prevents database access in tests. + doc = Doc(sequencefield=1) + + assert_type(doc.stringfield, Optional[str]) + assert_type(doc.urlfield, Optional[str]) + assert_type(doc.emailfield, Optional[str]) + assert_type(doc.intfield, Optional[int]) + assert_type(doc.longfield, Optional[int]) + assert_type(doc.floatfield, Optional[float]) + assert_type(doc.decimalfield, Optional[Decimal]) + assert_type(doc.booleanfield, Optional[bool]) + assert_type(doc.datetimefield, Optional[datetime]) + assert_type(doc.datefield, Optional[date]) + assert_type(doc.complexdatetimefield, Optional[datetime]) + assert_type(doc.embeddeddocumentfield, Optional[ImageEmbedded]) + assert_type(doc.objectidfield, Optional[ObjectId]) + assert_type(doc.genericembeddeddocumentfield, Optional[Any]) + assert_type(doc.dynamicfield, Optional[Any]) + assert_type(doc.listfield, list[Optional[str]]) + assert_type(doc.sortedlistfield, list[Optional[str]]) + assert_type(doc.embeddeddocumentlistfield, list[ImageEmbedded]) + assert_type(doc.dictfield, dict[str, str]) + assert_type(doc.mapfield, dict[str, Optional[str]]) + assert_type(doc.referencefield, Optional[ImageDocument]) + assert_type(doc.cachedreferencefield, Optional[ImageDocument]) + assert_type(doc.lazyreferencefield, ImageDocument) + assert_type(doc.genericlazyreferencefield, LazyReference[Any]) + assert_type(doc.genericreferencefield, Any) + assert_type(doc.binaryfield, Optional[bytes]) + assert_type(doc.filefield, Any) + assert_type(doc.imagefield, Any) + assert_type(doc.geopointfield, Optional[list[float]]) + assert_type(doc.pointfield, dict[str, Any]) + assert_type(doc.linestringfield, dict[str, Any]) + assert_type(doc.polygonfield, dict[str, Any]) + assert_type(doc.sequencefield, Any) + assert_type(doc.uuidfield, Optional[UUID]) + assert_type(doc.enumfield, Optional[Color]) + assert_type(doc.multipointfield, dict[str, Any]) + assert_type(doc.multilinestringfield, dict[str, Any]) + assert_type(doc.multipolygonfield, dict[str, Any]) + assert_type(doc.decimal128field, Optional[Decimal]) diff --git a/tests/test_typing.yml b/tests/test_typing.yml new file mode 100644 index 000000000..e91e670b4 --- /dev/null +++ b/tests/test_typing.yml @@ -0,0 +1,38 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/typeddjango/pytest-mypy-plugins/master/pytest_mypy_plugins/schema.json +- case: list_set_type + mypy_config: | + enable_error_code = var-annotated + skip: sys.version_info < (3, 9) + regex: true + main: | + from mongoengine import ( + Document, + ListField, + StringField, + ) + + class Book(Document): + authors = ListField(StringField()) + + book = Book() + book.authors = ["Sun Tzu"] + book.authors = [1] # E: List item 0 has incompatible type "int"; .* +- case: type_check + skip: sys.version_info < (3, 9) + parametrized: + - field: StringField + type: str + - field: IntField + type: int + main: | + from typing import Optional + from typing_extensions import assert_type + from mongoengine import Document, {{ field }} + + class Target(Document): + field = {{ field }}() + field_required = {{ field }}(required=True) + + instance = Target() + assert_type(instance.field, Optional[{{ type }}]) + assert_type(instance.field_required, {{ type }})