Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add types #2822

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/workflows/github-actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ jobs:
- run: bash .github/workflows/install_ci_python_dep.sh
- run: pre-commit run -a

type-check:
# Can be moved to pre-commit, separate step for now.
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.9'
check-latest: true
- run: bash .github/workflows/install_ci_python_typing_deps.sh
- run: mypy mongoengine tests

test:
# Test suite run against recent python versions
# and against a few combination of MongoDB and pymongo
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/install_ci_python_typing_deps.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash
pip install --upgrade pip
pip install mypy==1.13.0 typing-extensions mongomock types-Pygments types-cffi types-colorama types-pyOpenSSL types-python-dateutil types-requests types-setuptools
pip install -e '.[test]'
12 changes: 6 additions & 6 deletions mongoengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from mongoengine.signals import * # noqa: F401

__all__ = (
list(document.__all__)
+ list(fields.__all__)
+ list(connection.__all__)
+ list(queryset.__all__)
+ list(signals.__all__)
+ list(errors.__all__)
document.__all__
+ fields.__all__
+ connection.__all__
+ queryset.__all__
+ signals.__all__
+ errors.__all__
)


Expand Down
6 changes: 6 additions & 0 deletions mongoengine/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
from mongoengine.queryset.queryset import QuerySet

QS = TypeVar("QS", bound="QuerySet")
15 changes: 10 additions & 5 deletions mongoengine/base/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

from mongoengine.errors import NotRegistered

__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry")
if TYPE_CHECKING:
from mongoengine.document import Document

__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry")

UPDATE_OPERATORS = {
"set",
Expand All @@ -24,7 +29,7 @@
}


_document_registry = {}
_document_registry: dict[str, type[Document]] = {}


class _DocumentRegistry:
Expand All @@ -33,7 +38,7 @@ class _DocumentRegistry:
"""

@staticmethod
def get(name):
def get(name: str) -> type[Document]:
doc = _document_registry.get(name, None)
if not doc:
# Possible old style name
Expand All @@ -58,7 +63,7 @@ def get(name):
return doc

@staticmethod
def register(DocCls):
def register(DocCls: type[Document]) -> None:
ExistingDocCls = _document_registry.get(DocCls._class_name)
if (
ExistingDocCls is not None
Expand All @@ -76,7 +81,7 @@ def register(DocCls):
_document_registry[DocCls._class_name] = DocCls

@staticmethod
def unregister(doc_cls_name):
def unregister(doc_cls_name: str):
_document_registry.pop(doc_cls_name)


Expand Down
24 changes: 17 additions & 7 deletions mongoengine/base/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from __future__ import annotations

import weakref
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from bson import DBRef
from bson import DBRef, ObjectId

from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned

if TYPE_CHECKING:
from mongoengine import Document

_T = TypeVar("_T", bound="Document")

__all__ = (
"BaseDict",
"StrictDict",
Expand Down Expand Up @@ -356,7 +364,7 @@ def update(self, **update):
class StrictDict:
__slots__ = ()
_special_fields = {"get", "pop", "iteritems", "items", "keys", "create"}
_classes = {}
_classes: dict[str, Any] = {}

def __init__(self, **kwargs):
for k, v in kwargs.items():
Expand Down Expand Up @@ -435,7 +443,7 @@ def __repr__(self):
return cls._classes[allowed_keys]


class LazyReference(DBRef):
class LazyReference(Generic[_T], DBRef):
__slots__ = ("_cached_doc", "passthrough", "document_type")

def fetch(self, force=False):
Expand All @@ -449,19 +457,21 @@ def fetch(self, force=False):
def pk(self):
return self.id

def __init__(self, document_type, pk, cached_doc=None, passthrough=False):
def __init__(
self, document_type: type[_T], pk: ObjectId, cached_doc=None, passthrough=False
):
self.document_type = document_type
self._cached_doc = cached_doc
self.passthrough = passthrough
super().__init__(self.document_type._get_collection_name(), pk)
super().__init__(self.document_type._get_collection_name(), pk) # type: ignore[arg-type]

def __getitem__(self, name):
def __getitem__(self, name: str) -> Any:
if not self.passthrough:
raise KeyError()
document = self.fetch()
return document[name]

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
if not object.__getattribute__(self, "passthrough"):
raise AttributeError()
document = self.fetch()
Expand Down
45 changes: 31 additions & 14 deletions mongoengine/base/document.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# mypy: disable-error-code="attr-defined,union-attr,assignment"
from __future__ import annotations

import copy
import numbers
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any

import pymongo
from bson import SON, DBRef, ObjectId, json_util
from typing_extensions import Self

from mongoengine import signals
from mongoengine.base.common import _DocumentRegistry
Expand All @@ -15,7 +20,7 @@
LazyReference,
StrictDict,
)
from mongoengine.base.fields import ComplexBaseField
from mongoengine.base.fields import BaseField, ComplexBaseField
from mongoengine.common import _import_class
from mongoengine.errors import (
FieldDoesNotExist,
Expand All @@ -26,12 +31,15 @@
)
from mongoengine.pymongo_support import LEGACY_JSON_OPTIONS

if TYPE_CHECKING:
from mongoengine.fields import DynamicField

__all__ = ("BaseDocument", "NON_FIELD_ERRORS")

NON_FIELD_ERRORS = "__all__"

try:
GEOHAYSTACK = pymongo.GEOHAYSTACK
GEOHAYSTACK = pymongo.GEOHAYSTACK # type: ignore[attr-defined]
except AttributeError:
GEOHAYSTACK = None

Expand Down Expand Up @@ -62,7 +70,12 @@ class BaseDocument:
_dynamic_lock = True
STRICT = False

def __init__(self, *args, **values):
# Fields, added by metaclass
_class_name: str
_fields: dict[str, BaseField]
_meta: dict[str, Any]

def __init__(self, *args, **values) -> None:
"""
Initialise a document or an embedded document.

Expand Down Expand Up @@ -103,7 +116,7 @@ def __init__(self, *args, **values):
else:
self._data = {}

self._dynamic_fields = SON()
self._dynamic_fields: SON[str, DynamicField] = SON()

# Assign default values for fields
# not set in the constructor
Expand Down Expand Up @@ -329,13 +342,15 @@ def get_text_score(self):

return self._data["_text_score"]

def to_mongo(self, use_db_field=True, fields=None):
def to_mongo(
self, use_db_field: bool = True, fields: list[str] | None = None
) -> SON[Any, Any]:
"""
Return as SON data ready for use with MongoDB.
"""
fields = fields or []

data = SON()
data: SON[str, Any] = SON()
data["_id"] = None
data["_cls"] = self._class_name

Expand All @@ -354,7 +369,7 @@ def to_mongo(self, use_db_field=True, fields=None):

if value is not None:
f_inputs = field.to_mongo.__code__.co_varnames
ex_vars = {}
ex_vars: dict[str, Any] = {}
if fields and "fields" in f_inputs:
key = "%s." % field_name
embedded_fields = [
Expand All @@ -370,7 +385,7 @@ def to_mongo(self, use_db_field=True, fields=None):

# Handle self generating fields
if value is None and field._auto_gen:
value = field.generate()
value = field.generate() # type: ignore[attr-defined]
self._data[field_name] = value

if value is not None or field.null:
Expand All @@ -385,7 +400,7 @@ def to_mongo(self, use_db_field=True, fields=None):

return data

def validate(self, clean=True):
def validate(self, clean: bool = True) -> None:
"""Ensure that all fields' values are valid and that required fields
are present.

Expand Down Expand Up @@ -439,7 +454,7 @@ def validate(self, clean=True):
message = f"ValidationError ({self._class_name}:{pk}) "
raise ValidationError(message, errors=errors)

def to_json(self, *args, **kwargs):
def to_json(self, *args: Any, **kwargs: Any) -> str:
"""Convert this document to JSON.

:param use_db_field: Serialize field names as they appear in
Expand All @@ -461,7 +476,7 @@ def to_json(self, *args, **kwargs):
return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)

@classmethod
def from_json(cls, json_data, created=False, **kwargs):
def from_json(cls, json_data: str, created: bool = False, **kwargs: Any) -> Self:
"""Converts json data to a Document instance.

:param str json_data: The json data to load into the Document.
Expand Down Expand Up @@ -687,7 +702,7 @@ def _get_changed_fields(self):
self._nestable_types_changed_fields(changed_fields, key, data)
return changed_fields

def _delta(self):
def _delta(self) -> tuple[dict[str, Any], dict[str, Any]]:
"""Returns the delta (set, unset) of the changes for a document.
Gets any values that have been explicitly changed.
"""
Expand Down Expand Up @@ -771,14 +786,16 @@ def _delta(self):
return set_data, unset_data

@classmethod
def _get_collection_name(cls):
def _get_collection_name(cls) -> str | None:
"""Return the collection name for this class. None for abstract
class.
"""
return cls._meta.get("collection", None)

@classmethod
def _from_son(cls, son, _auto_dereference=True, created=False):
def _from_son(
cls, son: dict[str, Any], _auto_dereference: bool = True, created: bool = False
) -> Self:
"""Create an instance of a Document (subclass) from a PyMongo SON (dict)"""
if son and not isinstance(son, dict):
raise ValueError(
Expand Down
Loading
Loading