Skip to content

Commit

Permalink
refactor _document_registry + log a warning when user register multip…
Browse files Browse the repository at this point in the history
…le Document classes with the same name (only flagging when this happens in different module)
  • Loading branch information
bagerard committed Oct 2, 2024
1 parent fd109d8 commit f0de61e
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 61 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Development
- make sure to read https://www.mongodb.com/docs/manual/core/transactions-in-applications/#callback-api-vs-core-api
- run_in_transaction context manager relies on the PyMongo core API; it will retry automatically in case of `UnknownTransactionCommitResult` but not `TransientTransactionError` exceptions
- Using .count() in a transaction will always use Collection.count_documents (as estimated_document_count is not supported in transactions)
- BREAKING CHANGE: wrap _document_registry (normally not used by end users) with _DocumentRegistry which acts as a singleton to access the registry
- Log a warning in case users create multiple Document classes with the same name, as it can lead to unexpected behavior #1778
- Fix use of $geoNear or $collStats in aggregate #2493
- BREAKING CHANGE: Further to the deprecation warning, remove ability to use an unpacked list to `Queryset.aggregate(*pipeline)`, a plain list must be provided instead `Queryset.aggregate(pipeline)`, as it's closer to pymongo interface
- BREAKING CHANGE: Further to the deprecation warning, remove `full_response` from `QuerySet.modify` as it wasn't supported with Pymongo 3+
Expand All @@ -21,6 +23,7 @@ Development
- BREAKING CHANGE: Remove LongField as it's equivalent to IntField since support for Python 2 was dropped a long time ago (users should simply switch to IntField) #2309
- BugFix - Calling .clear on a ListField wasn't being marked as changed (and flushed to db upon .save()) #2858


Changes in 0.29.0
=================
- Fix weakref in EmbeddedDocumentListField (causing brief mem leak in certain circumstances) #2827
Expand Down
3 changes: 1 addition & 2 deletions mongoengine/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
__all__ = (
# common
"UPDATE_OPERATORS",
"_document_registry",
"get_document",
"_DocumentRegistry",
# datastructures
"BaseDict",
"BaseList",
Expand Down
77 changes: 54 additions & 23 deletions mongoengine/base/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings

from mongoengine.errors import NotRegistered

__all__ = ("UPDATE_OPERATORS", "get_document", "_document_registry")
__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry")


UPDATE_OPERATORS = {
Expand All @@ -25,28 +27,57 @@
_document_registry = {}


def get_document(name):
    """Return the Document class registered under ``name``.

    An exact registry key is tried first. If that yields nothing, the last
    dotted component of ``name`` is matched against the registered keys
    (old-style names), and the match is used only when it is unambiguous.

    :param name: class name to look up, possibly dotted
    :return: the matching entry of the document registry
    :raises NotRegistered: when no registered class can be resolved
    """
    found = _document_registry.get(name, None)
    if found:
        return found

    # Possible old style name: compare on the trailing component only.
    leaf = name.split(".")[-1]
    dotted_leaf = ".%s" % leaf
    candidates = [
        key for key in _document_registry if key.endswith(dotted_leaf) or key == leaf
    ]
    # Exactly one candidate is required; zero or several means no resolution.
    if len(candidates) == 1:
        found = _document_registry.get(candidates.pop(), None)
    if found:
        return found

    # The message lines stay flush-left so the literal's content is unchanged.
    raise NotRegistered(
        """
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip()
        % name
    )
class _DocumentRegistry:
    """Single access point to the module-level document registry.

    Internal to MongoEngine; end-users are not expected to use it directly.
    """

    @staticmethod
    def get(name):
        """Return the Document class registered under ``name``.

        Falls back to matching the last dotted component of ``name`` against
        the registered keys (old-style names); the fallback is accepted only
        when exactly one key matches.

        :raises NotRegistered: when no registered class can be resolved
        """
        found = _document_registry.get(name, None)
        if found:
            return found

        # Possible old style name: compare on the trailing component only.
        leaf = name.split(".")[-1]
        dotted_leaf = ".%s" % leaf
        candidates = []
        for key in _document_registry:
            if key.endswith(dotted_leaf) or key == leaf:
                candidates.append(key)

        if len(candidates) == 1:
            found = _document_registry.get(candidates.pop(), None)
        if found:
            return found

        # The message lines stay flush-left so the literal's content is unchanged.
        raise NotRegistered(
            """
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip()
            % name
        )

    @staticmethod
    def register(DocCls):
        """Store ``DocCls`` in the registry under its ``_class_name``.

        A warning is emitted first when a class with the same name is already
        registered and comes from a different module. The new class replaces
        the existing entry either way.
        """
        ExistingDocCls = _document_registry.get(DocCls._class_name)
        same_name_other_module = (
            ExistingDocCls is not None
            and ExistingDocCls.__module__ != DocCls.__module__
        )
        if same_name_other_module:
            # Two distinct classes probably ended up with the same name by
            # accident. MongoEngine assumes a Document class name is unique,
            # so dereferencing may resolve to the wrong class.
            warnings.warn(
                f"Multiple Document classes named `{DocCls._class_name}` were registered, "
                f"first from: `{ExistingDocCls.__module__}`, then from: `{DocCls.__module__}`. "
                "this may lead to unexpected behavior during dereferencing.",
                stacklevel=4,
            )
        _document_registry[DocCls._class_name] = DocCls

    @staticmethod
    def unregister(doc_cls_name):
        """Remove ``doc_cls_name`` from the registry (KeyError if absent)."""
        del _document_registry[doc_cls_name]


def _get_documents_by_db(connection_alias, default_connection_alias):
Expand Down
6 changes: 3 additions & 3 deletions mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bson import SON, DBRef, ObjectId, json_util

from mongoengine import signals
from mongoengine.base.common import get_document
from mongoengine.base.common import _DocumentRegistry
from mongoengine.base.datastructures import (
BaseDict,
BaseList,
Expand Down Expand Up @@ -500,7 +500,7 @@ def __expand_dynamic_values(self, name, value):
# If the value is a dict with '_cls' in it, turn it into a document
is_dict = isinstance(value, dict)
if is_dict and "_cls" in value:
cls = get_document(value["_cls"])
cls = _DocumentRegistry.get(value["_cls"])
return cls(**value)

if is_dict:
Expand Down Expand Up @@ -802,7 +802,7 @@ def _from_son(cls, son, _auto_dereference=True, created=False):

# Return correct subclass for document type
if class_name != cls._class_name:
cls = get_document(class_name)
cls = _DocumentRegistry.get(class_name)

errors_dict = {}

Expand Down
4 changes: 2 additions & 2 deletions mongoengine/base/metaclasses.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import warnings

from mongoengine.base.common import _document_registry
from mongoengine.base.common import _DocumentRegistry
from mongoengine.base.fields import (
BaseField,
ComplexBaseField,
Expand Down Expand Up @@ -169,7 +169,7 @@ def __new__(mcs, name, bases, attrs):
new_class._collection = None

# Add class to the _document_registry
_document_registry[new_class._class_name] = new_class
_DocumentRegistry.register(new_class)

# Handle delete rules
for field in new_class._fields.values():
Expand Down
20 changes: 10 additions & 10 deletions mongoengine/dereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
BaseList,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
get_document,
_DocumentRegistry,
)
from mongoengine.base.datastructures import LazyReference
from mongoengine.connection import _get_session, get_db
Expand Down Expand Up @@ -131,9 +131,9 @@ def _find_references(self, items, depth=0):
elif isinstance(v, DBRef):
reference_map.setdefault(field.document_type, set()).add(v.id)
elif isinstance(v, (dict, SON)) and "_ref" in v:
reference_map.setdefault(get_document(v["_cls"]), set()).add(
v["_ref"].id
)
reference_map.setdefault(
_DocumentRegistry.get(v["_cls"]), set()
).add(v["_ref"].id)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
field_cls = getattr(
getattr(field, "field", None), "document_type", None
Expand All @@ -151,9 +151,9 @@ def _find_references(self, items, depth=0):
elif isinstance(item, DBRef):
reference_map.setdefault(item.collection, set()).add(item.id)
elif isinstance(item, (dict, SON)) and "_ref" in item:
reference_map.setdefault(get_document(item["_cls"]), set()).add(
item["_ref"].id
)
reference_map.setdefault(
_DocumentRegistry.get(item["_cls"]), set()
).add(item["_ref"].id)
elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
references = self._find_references(item, depth - 1)
for key, refs in references.items():
Expand Down Expand Up @@ -198,9 +198,9 @@ def _fetch_objects(self, doc_type=None):
)
for ref in references:
if "_cls" in ref:
doc = get_document(ref["_cls"])._from_son(ref)
doc = _DocumentRegistry.get(ref["_cls"])._from_son(ref)
elif doc_type is None:
doc = get_document(
doc = _DocumentRegistry.get(
"".join(x.capitalize() for x in collection.split("_"))
)._from_son(ref)
else:
Expand Down Expand Up @@ -235,7 +235,7 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):
(items["_ref"].collection, items["_ref"].id), items
)
elif "_cls" in items:
doc = get_document(items["_cls"])._from_son(items)
doc = _DocumentRegistry.get(items["_cls"])._from_son(items)
_cls = doc._data.pop("_cls", None)
del items["_cls"]
doc._data = self._attach_objects(doc._data, depth, doc, None)
Expand Down
6 changes: 3 additions & 3 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
DocumentMetaclass,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
get_document,
_DocumentRegistry,
)
from mongoengine.base.utils import NonOrderedList
from mongoengine.common import _import_class
Expand Down Expand Up @@ -851,12 +851,12 @@ def register_delete_rule(cls, document_cls, field_name, rule):
object.
"""
classes = [
get_document(class_name)
_DocumentRegistry.get(class_name)
for class_name in cls._subclasses
if class_name != cls.__name__
] + [cls]
documents = [
get_document(class_name)
_DocumentRegistry.get(class_name)
for class_name in document_cls._subclasses
if class_name != document_cls.__name__
] + [document_cls]
Expand Down
20 changes: 10 additions & 10 deletions mongoengine/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
GeoJsonBaseField,
LazyReference,
ObjectIdField,
get_document,
_DocumentRegistry,
)
from mongoengine.base.utils import LazyRegexCompiler
from mongoengine.common import _import_class
Expand Down Expand Up @@ -725,7 +725,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
resolved_document_type = self.owner_document
else:
resolved_document_type = get_document(self.document_type_obj)
resolved_document_type = _DocumentRegistry.get(self.document_type_obj)

if not issubclass(resolved_document_type, EmbeddedDocument):
# Due to the late resolution of the document_type
Expand Down Expand Up @@ -801,7 +801,7 @@ def prepare_query_value(self, op, value):

def to_python(self, value):
if isinstance(value, dict):
doc_cls = get_document(value["_cls"])
doc_cls = _DocumentRegistry.get(value["_cls"])
value = doc_cls._from_son(value)

return value
Expand Down Expand Up @@ -879,7 +879,7 @@ def to_mongo(self, value, use_db_field=True, fields=None):

def to_python(self, value):
if isinstance(value, dict) and "_cls" in value:
doc_cls = get_document(value["_cls"])
doc_cls = _DocumentRegistry.get(value["_cls"])
if "_ref" in value:
value = doc_cls._get_db().dereference(
value["_ref"], session=_get_session()
Expand Down Expand Up @@ -1171,7 +1171,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
return self.document_type_obj

@staticmethod
Expand All @@ -1195,7 +1195,7 @@ def __get__(self, instance, owner):
if auto_dereference and isinstance(ref_value, DBRef):
if hasattr(ref_value, "cls"):
# Dereference using the class type specified in the reference
cls = get_document(ref_value.cls)
cls = _DocumentRegistry.get(ref_value.cls)
else:
cls = self.document_type

Expand Down Expand Up @@ -1335,7 +1335,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
return self.document_type_obj

@staticmethod
Expand Down Expand Up @@ -1498,7 +1498,7 @@ def __get__(self, instance, owner):

auto_dereference = instance._fields[self.name]._auto_dereference
if auto_dereference and isinstance(value, dict):
doc_cls = get_document(value["_cls"])
doc_cls = _DocumentRegistry.get(value["_cls"])
instance._data[self.name] = self._lazy_load_ref(doc_cls, value["_ref"])

return super().__get__(instance, owner)
Expand Down Expand Up @@ -2443,7 +2443,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
return self.document_type_obj

def build_lazyref(self, value):
Expand Down Expand Up @@ -2584,7 +2584,7 @@ def build_lazyref(self, value):
elif value is not None:
if isinstance(value, (dict, SON)):
value = LazyReference(
get_document(value["_cls"]),
_DocumentRegistry.get(value["_cls"]),
value["_ref"].id,
passthrough=self.passthrough,
)
Expand Down
6 changes: 4 additions & 2 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pymongo.read_concern import ReadConcern

from mongoengine import signals
from mongoengine.base import get_document
from mongoengine.base import _DocumentRegistry
from mongoengine.common import _import_class
from mongoengine.connection import _get_session, get_db
from mongoengine.context_managers import (
Expand Down Expand Up @@ -1956,7 +1956,9 @@ def _fields_to_dbfields(self, fields):
"""Translate fields' paths to their db equivalents."""
subclasses = []
if self._document._meta["allow_inheritance"]:
subclasses = [get_document(x) for x in self._document._subclasses][1:]
subclasses = [_DocumentRegistry.get(x) for x in self._document._subclasses][
1:
]

db_field_paths = []
for field in fields:
Expand Down
8 changes: 4 additions & 4 deletions tests/document/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from mongoengine import *
from mongoengine import signals
from mongoengine.base import _document_registry, get_document
from mongoengine.base import _DocumentRegistry
from mongoengine.connection import get_db
from mongoengine.context_managers import query_counter, switch_db
from mongoengine.errors import (
Expand Down Expand Up @@ -392,7 +392,7 @@ class NicePlace(Place):

# Mimic Place and NicePlace definitions being in a different file
# and the NicePlace model not being imported in at query time.
del _document_registry["Place.NicePlace"]
_DocumentRegistry.unregister("Place.NicePlace")

with pytest.raises(NotRegistered):
list(Place.objects.all())
Expand All @@ -407,8 +407,8 @@ class Area(Location):

Location.drop_collection()

assert Area == get_document("Area")
assert Area == get_document("Location.Area")
assert Area == _DocumentRegistry.get("Area")
assert Area == _DocumentRegistry.get("Location.Area")

def test_creation(self):
"""Ensure that document may be created using keyword arguments."""
Expand Down
4 changes: 2 additions & 2 deletions tests/fields/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from mongoengine.base import (
BaseField,
EmbeddedDocumentList,
_document_registry,
_DocumentRegistry,
)
from mongoengine.base.fields import _no_dereference_for_fields
from mongoengine.errors import DeprecatedError
Expand Down Expand Up @@ -1678,7 +1678,7 @@ class User(Document):

# Mimic User and Link definitions being in a different file
# and the Link model not being imported in the User file.
del _document_registry["Link"]
_DocumentRegistry.unregister("Link")

user = User.objects.first()
try:
Expand Down

0 comments on commit f0de61e

Please sign in to comment.