From ebb1d62c2cf16d252511710bd336828d5bf09707 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Tue, 6 May 2025 20:29:24 -0400 Subject: [PATCH 1/9] PYTHON-4542 Improved sessions API - Via context variable. --- pymongo/asynchronous/client_session.py | 6 ++++++ pymongo/asynchronous/cursor.py | 7 ++++++- pymongo/asynchronous/mongo_client.py | 10 ++++++++-- pymongo/synchronous/client_session.py | 6 ++++++ pymongo/synchronous/cursor.py | 7 ++++++- pymongo/synchronous/mongo_client.py | 10 ++++++++-- test/asynchronous/test_session.py | 7 +++++++ test/test_session.py | 7 +++++++ 8 files changed, 54 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index b808684dd4..339a699f68 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -139,6 +139,7 @@ import time import uuid from collections.abc import Mapping as _Mapping +from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, @@ -204,6 +205,7 @@ def __init__( causal_consistency: Optional[bool] = None, default_transaction_options: Optional[TransactionOptions] = None, snapshot: Optional[bool] = False, + bind: Optional[bool] = False, ) -> None: if snapshot: if causal_consistency: @@ -222,6 +224,7 @@ def __init__( ) self._default_transaction_options = default_transaction_options self._snapshot = snapshot + self._bind = bind @property def causal_consistency(self) -> bool: @@ -1065,6 +1068,9 @@ def __copy__(self) -> NoReturn: raise TypeError("A AsyncClientSession cannot be copied, create a new session instead") +SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None) + + class _EmptyServerSession: __slots__ = "dirty", "started_retryable_write" diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 1b25bf4ee8..1e683c269e 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -63,7 +63,7 @@ from _typeshed import SupportsItems from bson.codec_options import CodecOptions - from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.client_session import SESSION, AsyncClientSession from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.pool import AsyncConnection from pymongo.read_preferences import _ServerMode @@ -136,9 +136,14 @@ def __init__( self._killed = False self._session: Optional[AsyncClientSession] + _SESSION = SESSION.get() + if session: self._session = session self._explicit_session = True + elif _SESSION: + self._session = _SESSION + self._explicit_session = True else: self._session = None self._explicit_session = False diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index a236b21348..384f013e1e 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -65,7 +65,7 @@ from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk -from pymongo.asynchronous.client_session import _EmptyServerSession +from pymongo.asynchronous.client_session import SESSION, _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext @@ -1355,13 +1355,18 @@ def _close_cursor_soon( def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession: server_session = _EmptyServerSession() opts = client_session.SessionOptions(**kwargs) - return client_session.AsyncClientSession(self, server_session, opts, implicit) + bind = opts._bind + session = client_session.AsyncClientSession(self, server_session, opts, implicit) + if bind: + SESSION.set(session) + return session def start_session( self, causal_consistency: Optional[bool] = None, default_transaction_options: Optional[client_session.TransactionOptions] = None, snapshot: Optional[bool] = False, + bind: Optional[bool] = False, ) -> client_session.AsyncClientSession: """Start a logical session. @@ -1384,6 +1389,7 @@ def start_session( causal_consistency=causal_consistency, default_transaction_options=default_transaction_options, snapshot=snapshot, + bind=bind, ) def _ensure_session( diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index aaf2d7574f..30dfed9462 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -139,6 +139,7 @@ import time import uuid from collections.abc import Mapping as _Mapping +from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, @@ -203,6 +204,7 @@ def __init__( causal_consistency: Optional[bool] = None, default_transaction_options: Optional[TransactionOptions] = None, snapshot: Optional[bool] = False, + bind: Optional[bool] = False, ) -> None: if snapshot: if causal_consistency: @@ -221,6 +223,7 @@ def __init__( ) self._default_transaction_options = default_transaction_options self._snapshot = snapshot + self._bind = bind @property def causal_consistency(self) -> bool: @@ -1060,6 +1063,9 @@ def __copy__(self) -> NoReturn: raise TypeError("A ClientSession cannot be copied, create a new session instead") +SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None) + + class _EmptyServerSession: __slots__ = "dirty", "started_retryable_write" diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 31c4604f89..c4ccb4b994 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -64,7 +64,7 @@ from bson.codec_options import CodecOptions from pymongo.read_preferences import _ServerMode - from pymongo.synchronous.client_session import ClientSession + from pymongo.synchronous.client_session import SESSION, ClientSession from pymongo.synchronous.collection import Collection from pymongo.synchronous.pool import Connection @@ -136,9 +136,14 @@ def __init__( self._killed = False self._session: Optional[ClientSession] + _SESSION = SESSION.get() + if session: self._session = session self._explicit_session = True + elif _SESSION: + self._session = _SESSION + self._explicit_session = True else: self._session = None self._explicit_session = False diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 99a517e5c1..db881b1b5c 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -107,7 +107,7 @@ from pymongo.synchronous import client_session, database, uri_parser from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk -from pymongo.synchronous.client_session import _EmptyServerSession +from pymongo.synchronous.client_session import SESSION, _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext @@ -1353,13 +1353,18 @@ def _close_cursor_soon( def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: server_session = _EmptyServerSession() opts = client_session.SessionOptions(**kwargs) - return client_session.ClientSession(self, server_session, opts, implicit) + bind = opts._bind + session = client_session.ClientSession(self, server_session, opts, implicit) + if bind: + SESSION.set(session) + return session def start_session( self, causal_consistency: Optional[bool] = None, default_transaction_options: Optional[client_session.TransactionOptions] = None, snapshot: Optional[bool] = False, + bind: Optional[bool] = False, ) -> client_session.ClientSession: """Start a logical session. @@ -1382,6 +1387,7 @@ def start_session( causal_consistency=causal_consistency, default_transaction_options=default_transaction_options, snapshot=snapshot, + bind=bind, ) def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]: diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 3655f49aab..243844c254 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -380,6 +380,13 @@ async def test_cursor_clone(self): clone = cursor.clone() self.assertTrue(clone.session is s) + # Explicit session via context variable. + async with self.client.start_session(bind=True) as s: + cursor = coll.find() + self.assertTrue(cursor.session is s) + clone = cursor.clone() + self.assertTrue(clone.session is s) + # No explicit session. cursor = coll.find(batch_size=2) await anext(cursor) diff --git a/test/test_session.py b/test/test_session.py index a6266884aa..8332920616 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -380,6 +380,13 @@ def test_cursor_clone(self): clone = cursor.clone() self.assertTrue(clone.session is s) + # Explicit session via context variable. + with self.client.start_session(bind=True) as s: + cursor = coll.find() + self.assertTrue(cursor.session is s) + clone = cursor.clone() + self.assertTrue(clone.session is s) + # No explicit session. cursor = coll.find(batch_size=2) next(cursor) From e09003555860488fe9d9a7c1c60cd22d2109a905 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Wed, 7 May 2025 20:17:32 -0400 Subject: [PATCH 2/9] Make session var private --- pymongo/asynchronous/client_session.py | 2 +- pymongo/asynchronous/cursor.py | 10 ++++++---- pymongo/asynchronous/mongo_client.py | 4 ++-- pymongo/synchronous/client_session.py | 2 +- pymongo/synchronous/cursor.py | 10 ++++++---- pymongo/synchronous/mongo_client.py | 4 ++-- test/asynchronous/test_session.py | 14 +++++++------- test/test_session.py | 14 +++++++------- 8 files changed, 32 insertions(+), 28 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 339a699f68..bd81873d93 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -1068,7 +1068,7 @@ def __copy__(self) -> NoReturn: raise TypeError("A AsyncClientSession cannot be copied, create a new session instead") -SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None) +_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None) class _EmptyServerSession: diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 1e683c269e..08b1895c1f 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -63,7 +63,7 @@ from _typeshed import SupportsItems from bson.codec_options import CodecOptions - from pymongo.asynchronous.client_session import SESSION, AsyncClientSession + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.pool import AsyncConnection from pymongo.read_preferences import _ServerMode @@ -136,13 +136,15 @@ def __init__( self._killed = False self._session: Optional[AsyncClientSession] - _SESSION = SESSION.get() + from .client_session import _SESSION + + bound_session = _SESSION.get() if session: self._session = session self._explicit_session = True - elif _SESSION: - self._session = _SESSION + elif bound_session: + self._session = bound_session self._explicit_session = True else: self._session = None diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 384f013e1e..4b2664ad7a 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -65,7 +65,7 @@ from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk -from pymongo.asynchronous.client_session import SESSION, _EmptyServerSession +from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext @@ -1358,7 +1358,7 @@ def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession: bind = opts._bind session = client_session.AsyncClientSession(self, server_session, opts, implicit) if bind: - SESSION.set(session) + _SESSION.set(session) return session def start_session( diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 30dfed9462..26636a9bc1 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -1063,7 +1063,7 @@ def __copy__(self) -> NoReturn: raise TypeError("A ClientSession cannot be copied, create a new session instead") -SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None) +_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None) class _EmptyServerSession: diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index c4ccb4b994..11f1327d53 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -64,7 +64,7 @@ from bson.codec_options import CodecOptions from pymongo.read_preferences import _ServerMode - from pymongo.synchronous.client_session import SESSION, ClientSession + from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.collection import Collection from pymongo.synchronous.pool import Connection @@ -136,13 +136,15 @@ def __init__( self._killed = False self._session: Optional[ClientSession] - _SESSION = SESSION.get() + from .client_session import _SESSION + + bound_session = _SESSION.get() if session: self._session = session self._explicit_session = True - elif _SESSION: - self._session = _SESSION + elif bound_session: + self._session = bound_session self._explicit_session = True else: self._session = None diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index db881b1b5c..c2fa1b01f9 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -107,7 +107,7 @@ from pymongo.synchronous import client_session, database, uri_parser from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk -from pymongo.synchronous.client_session import SESSION, _EmptyServerSession +from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext @@ -1356,7 +1356,7 @@ def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: bind = opts._bind session = client_session.ClientSession(self, server_session, opts, implicit) if bind: - SESSION.set(session) + _SESSION.set(session) return session def start_session( diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 243844c254..4c1c5fa44f 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -380,13 +380,6 @@ async def test_cursor_clone(self): clone = cursor.clone() self.assertTrue(clone.session is s) - # Explicit session via context variable. - async with self.client.start_session(bind=True) as s: - cursor = coll.find() - self.assertTrue(cursor.session is s) - clone = cursor.clone() - self.assertTrue(clone.session is s) - # No explicit session. cursor = coll.find(batch_size=2) await anext(cursor) @@ -401,6 +394,13 @@ async def test_cursor_clone(self): await cursor.close() await clone.close() + # Explicit session via context variable. + async with self.client.start_session(bind=True) as s: + cursor = coll.find() + self.assertTrue(cursor.session is s) + clone = cursor.clone() + self.assertTrue(clone.session is s) + async def test_cursor(self): listener = self.listener client = self.client diff --git a/test/test_session.py b/test/test_session.py index 8332920616..15342d033d 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -380,13 +380,6 @@ def test_cursor_clone(self): clone = cursor.clone() self.assertTrue(clone.session is s) - # Explicit session via context variable. - with self.client.start_session(bind=True) as s: - cursor = coll.find() - self.assertTrue(cursor.session is s) - clone = cursor.clone() - self.assertTrue(clone.session is s) - # No explicit session. cursor = coll.find(batch_size=2) next(cursor) @@ -401,6 +394,13 @@ def test_cursor_clone(self): cursor.close() clone.close() + # Explicit session via context variable. + with self.client.start_session(bind=True) as s: + cursor = coll.find() + self.assertTrue(cursor.session is s) + clone = cursor.clone() + self.assertTrue(clone.session is s) + def test_cursor(self): listener = self.listener client = self.client From de89b23058e24a5cb685126c1bb6b1f7abbac10f Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Thu, 8 May 2025 09:43:01 -0400 Subject: [PATCH 3/9] Add test for nested sessions. --- test/asynchronous/test_session.py | 14 ++++++++++++-- test/test_session.py | 14 ++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 4c1c5fa44f..7a1272112c 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -394,12 +394,22 @@ async def test_cursor_clone(self): await cursor.close() await clone.close() + async def test_bind_session(self): + coll = self.client.pymongo_test.collection + # Explicit session via context variable. async with self.client.start_session(bind=True) as s: cursor = coll.find() self.assertTrue(cursor.session is s) - clone = cursor.clone() - self.assertTrue(clone.session is s) + + # Nested sessions. + session1 = self.client.start_session(bind=True) + with session1: + session2 = self.client.start_session(bind=True) + with session2: + coll.find_one() # uses session2 + coll.find_one() # uses session1 + coll.find_one() # uses implicit session async def test_cursor(self): listener = self.listener diff --git a/test/test_session.py b/test/test_session.py index 15342d033d..cbd78df1aa 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -394,12 +394,22 @@ def test_cursor_clone(self): cursor.close() clone.close() + def test_bind_session(self): + coll = self.client.pymongo_test.collection + # Explicit session via context variable. with self.client.start_session(bind=True) as s: cursor = coll.find() self.assertTrue(cursor.session is s) - clone = cursor.clone() - self.assertTrue(clone.session is s) + + # Nested sessions. + session1 = self.client.start_session(bind=True) + with session1: + session2 = self.client.start_session(bind=True) + with session2: + coll.find_one() # uses session2 + coll.find_one() # uses session1 + coll.find_one() # uses implicit session def test_cursor(self): listener = self.listener From 85add85da0bf388106e42af2d173cb640888ad79 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Thu, 8 May 2025 10:37:06 -0400 Subject: [PATCH 4/9] Add context manager --- pymongo/asynchronous/client_session.py | 21 ++++++++++++++++++++- pymongo/asynchronous/mongo_client.py | 4 ++-- pymongo/synchronous/client_session.py | 21 ++++++++++++++++++++- pymongo/synchronous/mongo_client.py | 4 ++-- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index bd81873d93..1d99b799f9 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -139,7 +139,8 @@ import time import uuid from collections.abc import Mapping as _Mapping -from contextvars import ContextVar +from contextlib import AbstractAsyncContextManager +from contextvars import ContextVar, Token from typing import ( TYPE_CHECKING, Any, @@ -1071,6 +1072,24 @@ def __copy__(self) -> NoReturn: _SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None) +class _BindSession(AbstractAsyncContextManager): + def __init__(self, session: AsyncClientSession) -> None: + self.session = session + self.token: Optional[Token[Optional[AsyncClientSession]]] = None + + async def __aenter__(self) -> None: + self.token = _SESSION.set(self.session) + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + if self.token is not None: + _SESSION.reset(self.token) + + class _EmptyServerSession: __slots__ = "dirty", "started_retryable_write" diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 4b2664ad7a..c83b25d0c3 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -65,7 +65,7 @@ from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk -from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession +from pymongo.asynchronous.client_session import _BindSession, _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext @@ -1358,7 +1358,7 @@ def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession: bind = opts._bind session = client_session.AsyncClientSession(self, server_session, opts, implicit) if bind: - _SESSION.set(session) + session = _BindSession(session) return session def start_session( diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 26636a9bc1..71e3d4236b 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -139,7 +139,8 @@ import time import uuid from collections.abc import Mapping as _Mapping -from contextvars import ContextVar +from contextlib import AbstractContextManager +from contextvars import ContextVar, Token from typing import ( TYPE_CHECKING, Any, @@ -1066,6 +1067,24 @@ def __copy__(self) -> NoReturn: _SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None) +class _BindSession(AbstractContextManager): + def __init__(self, session: ClientSession) -> None: + self.session = session + self.token: Optional[Token[Optional[ClientSession]]] = None + + def __enter__(self) -> None: + self.token = _SESSION.set(self.session) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + if self.token is not None: + _SESSION.reset(self.token) + + class _EmptyServerSession: __slots__ = "dirty", "started_retryable_write" diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index c2fa1b01f9..205597ec27 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -107,7 +107,7 @@ from pymongo.synchronous import client_session, database, uri_parser from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk -from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession +from pymongo.synchronous.client_session import _BindSession, _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext @@ -1356,7 +1356,7 @@ def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: bind = opts._bind session = client_session.ClientSession(self, server_session, opts, implicit) if bind: - _SESSION.set(session) + session = _BindSession(session) return session def start_session( From be283e858c1f79882ec852465e1e411c8b11c392 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Thu, 8 May 2025 14:17:04 -0400 Subject: [PATCH 5/9] Revert "Add context manager" This reverts commit 85add85da0bf388106e42af2d173cb640888ad79. --- pymongo/asynchronous/client_session.py | 21 +-------------------- pymongo/asynchronous/mongo_client.py | 4 ++-- pymongo/synchronous/client_session.py | 21 +-------------------- pymongo/synchronous/mongo_client.py | 4 ++-- 4 files changed, 6 insertions(+), 44 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 1d99b799f9..bd81873d93 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -139,8 +139,7 @@ import time import uuid from collections.abc import Mapping as _Mapping -from contextlib import AbstractAsyncContextManager -from contextvars import ContextVar, Token +from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, @@ -1072,24 +1071,6 @@ def __copy__(self) -> NoReturn: _SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None) -class _BindSession(AbstractAsyncContextManager): - def __init__(self, session: AsyncClientSession) -> None: - self.session = session - self.token: Optional[Token[Optional[AsyncClientSession]]] = None - - async def __aenter__(self) -> None: - self.token = _SESSION.set(self.session) - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - if self.token is not None: - _SESSION.reset(self.token) - - class _EmptyServerSession: __slots__ = "dirty", "started_retryable_write" diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index c83b25d0c3..4b2664ad7a 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -65,7 +65,7 @@ from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk -from pymongo.asynchronous.client_session import _BindSession, _EmptyServerSession +from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext @@ -1358,7 +1358,7 @@ def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession: bind = opts._bind session = client_session.AsyncClientSession(self, server_session, opts, implicit) if bind: - session = _BindSession(session) + _SESSION.set(session) return session def start_session( diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 71e3d4236b..26636a9bc1 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -139,8 +139,7 @@ import time import uuid from collections.abc import Mapping as _Mapping -from contextlib import AbstractContextManager -from contextvars import ContextVar, Token +from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, @@ -1067,24 +1066,6 @@ def __copy__(self) -> NoReturn: _SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None) -class _BindSession(AbstractContextManager): - def __init__(self, session: ClientSession) -> None: - self.session = session - self.token: Optional[Token[Optional[ClientSession]]] = None - - def __enter__(self) -> None: - self.token = _SESSION.set(self.session) - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - if self.token is not None: - _SESSION.reset(self.token) - - class _EmptyServerSession: __slots__ = "dirty", "started_retryable_write" diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 205597ec27..c2fa1b01f9 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -107,7 +107,7 @@ from pymongo.synchronous import client_session, database, uri_parser from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk -from pymongo.synchronous.client_session import _BindSession, _EmptyServerSession +from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext @@ -1356,7 +1356,7 @@ def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: bind = opts._bind session = client_session.ClientSession(self, server_session, opts, implicit) if bind: - session = _BindSession(session) + _SESSION.set(session) return session def start_session( From 89641f52564756b778babf7fa12dbd352decd1eb Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Thu, 8 May 2025 14:23:33 -0400 Subject: [PATCH 6/9] Fix session test --- pymongo/asynchronous/client_session.py | 1 + pymongo/synchronous/client_session.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index bd81873d93..fdd17ad5db 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -548,6 +548,7 @@ def _check_ended(self) -> None: raise InvalidOperation("Cannot use ended session") async def __aenter__(self) -> AsyncClientSession: + _SESSION.set(self) return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 26636a9bc1..e9d892ce63 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -547,6 +547,7 @@ def _check_ended(self) -> None: raise InvalidOperation("Cannot use ended session") def __enter__(self) -> ClientSession: + _SESSION.set(self) return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: From 657fcc208ca66acf0d4c4d96b5d155b7201c99fe Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Thu, 8 May 2025 15:45:31 -0400 Subject: [PATCH 7/9] Fix nested session test --- test/asynchronous/test_session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 7a1272112c..bffd3e4002 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -404,9 +404,9 @@ async def test_bind_session(self): # Nested sessions. session1 = self.client.start_session(bind=True) - with session1: + async with session1: session2 = self.client.start_session(bind=True) - with session2: + async with session2: coll.find_one() # uses session2 coll.find_one() # uses session1 coll.find_one() # uses implicit session From 8e008ed109a0dec7add5807469edbed00abb1d31 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Thu, 8 May 2025 15:49:50 -0400 Subject: [PATCH 8/9] Fix nested session test --- test/asynchronous/test_session.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index bffd3e4002..a12a29353d 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -407,9 +407,9 @@ async def test_bind_session(self): async with session1: session2 = self.client.start_session(bind=True) async with session2: - coll.find_one() # uses session2 - coll.find_one() # uses session1 - coll.find_one() # uses implicit session + await coll.find_one() # uses session2 + await coll.find_one() # uses session1 + await coll.find_one() # uses implicit session async def test_cursor(self): listener = self.listener From 0f9e93a992d13c7985c83d8934cb0d70f34fff5c Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Thu, 8 May 2025 18:10:48 -0400 Subject: [PATCH 9/9] Set token --- pymongo/asynchronous/client_session.py | 4 +++- pymongo/synchronous/client_session.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index fdd17ad5db..f72df467b6 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -548,10 +548,12 @@ def _check_ended(self) -> None: raise InvalidOperation("Cannot use ended session") async def __aenter__(self) -> AsyncClientSession: - _SESSION.set(self) + self._token = _SESSION.set(self) return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._token: + _SESSION.reset(self._token) await self._end_session(lock=True) @property diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index e9d892ce63..d17bcc0868 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -547,10 +547,12 @@ def _check_ended(self) -> None: raise InvalidOperation("Cannot use ended session") def __enter__(self) -> ClientSession: - _SESSION.set(self) + self._token = _SESSION.set(self) return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._token: + _SESSION.reset(self._token) self._end_session(lock=True) @property