Skip to content

Commit

Permalink
chore: Add types from mongo-types
Browse files Browse the repository at this point in the history
  • Loading branch information
last-partizan committed Dec 1, 2024
1 parent 4d3ab60 commit 82647e2
Show file tree
Hide file tree
Showing 40 changed files with 1,941 additions and 253 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/github-actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ jobs:
- run: bash .github/workflows/install_ci_python_dep.sh
- run: pre-commit run -a

type-check:
# Can be moved to pre-commit, separate step for now.
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.9'
check-latest: true
- run: bash .github/workflows/install_ci_python_typing_deps.sh
- run: mypy mongoengine tests

test:
# Test suite run against recent python versions
# and against a few combination of MongoDB and pymongo
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/install_ci_python_typing_deps.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash
pip install --upgrade pip
pip install mypy==1.13.0 typing-extensions mongomock types-Pygments types-cffi types-colorama types-pyOpenSSL types-python-dateutil types-requests types-setuptools
pip install -e '.[test]'
12 changes: 6 additions & 6 deletions mongoengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from mongoengine.signals import * # noqa: F401

__all__ = (
list(document.__all__)
+ list(fields.__all__)
+ list(connection.__all__)
+ list(queryset.__all__)
+ list(signals.__all__)
+ list(errors.__all__)
document.__all__
+ fields.__all__
+ connection.__all__
+ queryset.__all__
+ signals.__all__
+ errors.__all__
)


Expand Down
6 changes: 6 additions & 0 deletions mongoengine/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
from mongoengine.queryset.queryset import QuerySet

QS = TypeVar("QS", bound="QuerySet")
15 changes: 10 additions & 5 deletions mongoengine/base/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

from mongoengine.errors import NotRegistered

__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry")
if TYPE_CHECKING:
from mongoengine.document import Document

__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry")

UPDATE_OPERATORS = {
"set",
Expand All @@ -24,7 +29,7 @@
}


_document_registry = {}
_document_registry: dict[str, type[Document]] = {}


class _DocumentRegistry:
Expand All @@ -33,7 +38,7 @@ class _DocumentRegistry:
"""

@staticmethod
def get(name):
def get(name: str) -> type[Document]:
doc = _document_registry.get(name, None)
if not doc:
# Possible old style name
Expand All @@ -58,7 +63,7 @@ def get(name):
return doc

@staticmethod
def register(DocCls):
def register(DocCls: type[Document]) -> None:
ExistingDocCls = _document_registry.get(DocCls._class_name)
if (
ExistingDocCls is not None
Expand All @@ -76,7 +81,7 @@ def register(DocCls):
_document_registry[DocCls._class_name] = DocCls

@staticmethod
def unregister(doc_cls_name):
def unregister(doc_cls_name: str):
_document_registry.pop(doc_cls_name)


Expand Down
24 changes: 17 additions & 7 deletions mongoengine/base/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from __future__ import annotations

import weakref
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from bson import DBRef
from bson import DBRef, ObjectId

from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned

if TYPE_CHECKING:
from mongoengine import Document

_T = TypeVar("_T", bound="Document")

__all__ = (
"BaseDict",
"StrictDict",
Expand Down Expand Up @@ -356,7 +364,7 @@ def update(self, **update):
class StrictDict:
__slots__ = ()
_special_fields = {"get", "pop", "iteritems", "items", "keys", "create"}
_classes = {}
_classes: dict[str, Any] = {}

def __init__(self, **kwargs):
for k, v in kwargs.items():
Expand Down Expand Up @@ -435,7 +443,7 @@ def __repr__(self):
return cls._classes[allowed_keys]


class LazyReference(DBRef):
class LazyReference(Generic[_T], DBRef):
__slots__ = ("_cached_doc", "passthrough", "document_type")

def fetch(self, force=False):
Expand All @@ -449,19 +457,21 @@ def fetch(self, force=False):
def pk(self):
return self.id

def __init__(self, document_type, pk, cached_doc=None, passthrough=False):
def __init__(
self, document_type: type[_T], pk: ObjectId, cached_doc=None, passthrough=False
):
self.document_type = document_type
self._cached_doc = cached_doc
self.passthrough = passthrough
super().__init__(self.document_type._get_collection_name(), pk)
super().__init__(self.document_type._get_collection_name(), pk) # type: ignore[arg-type]

def __getitem__(self, name):
def __getitem__(self, name: str) -> Any:
if not self.passthrough:
raise KeyError()
document = self.fetch()
return document[name]

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
if not object.__getattribute__(self, "passthrough"):
raise AttributeError()
document = self.fetch()
Expand Down
45 changes: 31 additions & 14 deletions mongoengine/base/document.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# mypy: disable-error-code="attr-defined,union-attr,assignment"
from __future__ import annotations

import copy
import numbers
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any

import pymongo
from bson import SON, DBRef, ObjectId, json_util
from typing_extensions import Self

from mongoengine import signals
from mongoengine.base.common import _DocumentRegistry
Expand All @@ -15,7 +20,7 @@
LazyReference,
StrictDict,
)
from mongoengine.base.fields import ComplexBaseField
from mongoengine.base.fields import BaseField, ComplexBaseField
from mongoengine.common import _import_class
from mongoengine.errors import (
FieldDoesNotExist,
Expand All @@ -26,12 +31,15 @@
)
from mongoengine.pymongo_support import LEGACY_JSON_OPTIONS

if TYPE_CHECKING:
from mongoengine.fields import DynamicField

__all__ = ("BaseDocument", "NON_FIELD_ERRORS")

NON_FIELD_ERRORS = "__all__"

try:
GEOHAYSTACK = pymongo.GEOHAYSTACK
GEOHAYSTACK = pymongo.GEOHAYSTACK # type: ignore[attr-defined]
except AttributeError:
GEOHAYSTACK = None

Expand Down Expand Up @@ -62,7 +70,12 @@ class BaseDocument:
_dynamic_lock = True
STRICT = False

def __init__(self, *args, **values):
# Fields, added by metaclass
_class_name: str
_fields: dict[str, BaseField]
_meta: dict[str, Any]

def __init__(self, *args, **values) -> None:
"""
Initialise a document or an embedded document.
Expand Down Expand Up @@ -103,7 +116,7 @@ def __init__(self, *args, **values):
else:
self._data = {}

self._dynamic_fields = SON()
self._dynamic_fields: SON[str, DynamicField] = SON()

# Assign default values for fields
# not set in the constructor
Expand Down Expand Up @@ -329,13 +342,15 @@ def get_text_score(self):

return self._data["_text_score"]

def to_mongo(self, use_db_field=True, fields=None):
def to_mongo(
self, use_db_field: bool = True, fields: list[str] | None = None
) -> SON[Any, Any]:
"""
Return as SON data ready for use with MongoDB.
"""
fields = fields or []

data = SON()
data: SON[str, Any] = SON()
data["_id"] = None
data["_cls"] = self._class_name

Expand All @@ -354,7 +369,7 @@ def to_mongo(self, use_db_field=True, fields=None):

if value is not None:
f_inputs = field.to_mongo.__code__.co_varnames
ex_vars = {}
ex_vars: dict[str, Any] = {}
if fields and "fields" in f_inputs:
key = "%s." % field_name
embedded_fields = [
Expand All @@ -370,7 +385,7 @@ def to_mongo(self, use_db_field=True, fields=None):

# Handle self generating fields
if value is None and field._auto_gen:
value = field.generate()
value = field.generate() # type: ignore[attr-defined]
self._data[field_name] = value

if value is not None or field.null:
Expand All @@ -385,7 +400,7 @@ def to_mongo(self, use_db_field=True, fields=None):

return data

def validate(self, clean=True):
def validate(self, clean: bool = True) -> None:
"""Ensure that all fields' values are valid and that required fields
are present.
Expand Down Expand Up @@ -439,7 +454,7 @@ def validate(self, clean=True):
message = f"ValidationError ({self._class_name}:{pk}) "
raise ValidationError(message, errors=errors)

def to_json(self, *args, **kwargs):
def to_json(self, *args: Any, **kwargs: Any) -> str:
"""Convert this document to JSON.
:param use_db_field: Serialize field names as they appear in
Expand All @@ -461,7 +476,7 @@ def to_json(self, *args, **kwargs):
return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)

@classmethod
def from_json(cls, json_data, created=False, **kwargs):
def from_json(cls, json_data: str, created: bool = False, **kwargs: Any) -> Self:
"""Converts json data to a Document instance.
:param str json_data: The json data to load into the Document.
Expand Down Expand Up @@ -687,7 +702,7 @@ def _get_changed_fields(self):
self._nestable_types_changed_fields(changed_fields, key, data)
return changed_fields

def _delta(self):
def _delta(self) -> tuple[dict[str, Any], dict[str, Any]]:
"""Returns the delta (set, unset) of the changes for a document.
Gets any values that have been explicitly changed.
"""
Expand Down Expand Up @@ -771,14 +786,16 @@ def _delta(self):
return set_data, unset_data

@classmethod
def _get_collection_name(cls):
def _get_collection_name(cls) -> str | None:
"""Return the collection name for this class. None for abstract
class.
"""
return cls._meta.get("collection", None)

@classmethod
def _from_son(cls, son, _auto_dereference=True, created=False):
def _from_son(
cls, son: dict[str, Any], _auto_dereference: bool = True, created: bool = False
) -> Self:
"""Create an instance of a Document (subclass) from a PyMongo SON (dict)"""
if son and not isinstance(son, dict):
raise ValueError(
Expand Down
Loading

0 comments on commit 82647e2

Please sign in to comment.