Skip to content

Commit 88c5599

Browse files
author
Jesse McLaughlin
committed
[MongoEngine#2685] implement a thread-safe switch db context manager
1 parent 51afeca commit 88c5599

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

mongoengine/connection.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import warnings
2+
from threading import local
23

34
from pymongo import MongoClient, ReadPreference, uri_parser
45
from pymongo.database import _check_name
56

7+
from mongoengine.errors import DatabaseAliasError
68
from mongoengine.pymongo_support import PYMONGO_VERSION
79

810
__all__ = [
@@ -15,6 +17,8 @@
1517
"get_connection",
1618
"get_db",
1719
"register_connection",
20+
"set_local_db_alias",
21+
"del_local_db_alias"
1822
]
1923

2024

@@ -26,6 +30,8 @@
2630
_connection_settings = {}
2731
_connections = {}
2832
_dbs = {}
33+
_local = local()
34+
_local.db_alias = {}
2935

3036
READ_PREFERENCE = ReadPreference.PRIMARY
3137

@@ -372,7 +378,29 @@ def _clean_settings(settings_dict):
372378
return _connections[db_alias]
373379

374380

381+
def set_local_db_alias(local_alias, alias=DEFAULT_CONNECTION_NAME):
382+
if not alias or not local_alias:
383+
raise DatabaseAliasError(f"db alias and local_alias cannot be empty")
384+
385+
if alias not in _local.db_alias:
386+
_local.db_alias[alias] = local_alias
387+
else:
388+
raise DatabaseAliasError(f"local db alias already set: {alias}")
389+
390+
391+
def del_local_db_alias(alias):
392+
if not alias:
393+
raise DatabaseAliasError(f"db alias cannot be empty")
394+
if alias in _local.db_alias:
395+
del _local.db_alias[alias]
396+
else:
397+
raise DatabaseAliasError(f"local db alias not set: {alias}")
398+
399+
375400
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
401+
if alias in _local.db_alias:
402+
alias = _local.db_alias[alias]
403+
376404
if reconnect:
377405
disconnect(alias)
378406

mongoengine/context_managers.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
from pymongo.write_concern import WriteConcern
55

66
from mongoengine.common import _import_class
7-
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
7+
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db, set_local_db_alias, del_local_db_alias
88
from mongoengine.pymongo_support import count_documents
99

1010
__all__ = (
11+
"switch_db_local",
1112
"switch_db",
1213
"switch_collection",
1314
"no_dereference",
@@ -18,6 +19,36 @@
1819
)
1920

2021

22+
class switch_db_local:
23+
"""switch_db_local alias context manager.
24+
25+
Switches a db alias in a thread-safe way.
26+
27+
Example ::
28+
register_connection('testdb-1', 'mongoenginetest1')
29+
register_connection('testdb-2', 'mongoenginetest2')
30+
31+
class Group(Document):
32+
name = StringField()
33+
34+
# The following two calls to save() could be run concurrently
35+
with switch_db_local('testdb-1'):
36+
Group(name='test').save()
37+
with switch_db_local('testdb-2'):
38+
Group(name='test').save()
39+
"""
40+
41+
def __init__(self, local_alias, alias=DEFAULT_CONNECTION_NAME):
42+
self.local_alias = local_alias
43+
self.alias = alias
44+
45+
def __enter__(self):
46+
set_local_db_alias(self.local_alias, self.alias)
47+
48+
def __exit__(self, t, value, traceback):
49+
del_local_db_alias(self.alias)
50+
51+
2152
class switch_db:
2253
"""switch_db alias context manager.
2354

mongoengine/errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import defaultdict
22

33
__all__ = (
4+
"DatabaseAliasError",
45
"NotRegistered",
56
"InvalidDocumentError",
67
"LookUpError",
@@ -21,6 +22,10 @@ class MongoEngineException(Exception):
2122
pass
2223

2324

25+
class DatabaseAliasError(MongoEngineException):
26+
pass
27+
28+
2429
class NotRegistered(MongoEngineException):
2530
pass
2631

0 commit comments

Comments
 (0)