From b93360180ded35144ec73df49af75e8e53666de0 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 3 Apr 2025 15:30:01 -0700 Subject: [PATCH 01/10] bulk_write should be able to accept a generator --- pymongo/asynchronous/bulk.py | 37 ++++++++++++++++++++++-------- pymongo/asynchronous/collection.py | 12 ++++------ pymongo/common.py | 8 +++++++ pymongo/synchronous/bulk.py | 37 ++++++++++++++++++++++-------- pymongo/synchronous/collection.py | 12 ++++------ test/asynchronous/test_bulk.py | 15 ++++++++++++ test/test_bulk.py | 15 ++++++++++++ 7 files changed, 100 insertions(+), 36 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index ac514db98f..b4b042a632 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -26,6 +26,7 @@ from typing import ( TYPE_CHECKING, Any, + Generator, Iterator, Mapping, Optional, @@ -72,7 +73,7 @@ from pymongo.write_concern import WriteConcern if TYPE_CHECKING: - from pymongo.asynchronous.collection import AsyncCollection + from pymongo.asynchronous.collection import AsyncCollection, _WriteOp from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.asynchronous.pool import AsyncConnection from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline @@ -214,28 +215,45 @@ def add_delete( self.is_retryable = False self.ops.append((_DELETE, cmd)) - def gen_ordered(self) -> Iterator[Optional[_Run]]: + def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: """Generate batches of operations, batched by type of operation, in the order **provided**. """ run = None - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + try: + request._add_to_bulk(self) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) elif run.op_type != op_type: yield run run = _Run(op_type) run.add(idx, operation) + if run is None: + raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self) -> Iterator[_Run]: + def gen_unordered(self, requests) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. """ operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + try: + request._add_to_bulk(self) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - + if ( + len(operations[_INSERT].ops) == 0 + and len(operations[_UPDATE].ops) == 0 + and len(operations[_DELETE].ops) == 0 + ): + raise InvalidOperation("No operations to execute") for run in operations: if run.ops: yield run @@ -726,13 +744,12 @@ async def execute_no_results( async def execute( self, + generator: Generator[_WriteOp[_DocumentType]], write_concern: WriteConcern, session: Optional[AsyncClientSession], operation: str, ) -> Any: """Execute operations.""" - if not self.ops: - raise InvalidOperation("No operations to execute") if self.executed: raise InvalidOperation("Bulk operations can only be executed once.") self.executed = True @@ -740,9 +757,9 @@ async def execute( session = _validate_session_write_concern(session, write_concern) if self.ordered: - generator = self.gen_ordered() + generator = self.gen_ordered(generator) else: - generator = self.gen_unordered() + generator = self.gen_unordered(generator) client = self.collection.database.client if not write_concern.acknowledged: diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 7fb20b7ab3..3286dea4d4 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -23,6 +23,7 @@ AsyncContextManager, Callable, Coroutine, + Generator, Generic, Iterable, Iterator, @@ -699,7 +700,7 @@ async def _create( @_csot.apply async def bulk_write( self, - requests: Sequence[_WriteOp[_DocumentType]], + requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]], ordered: bool = True, bypass_document_validation: Optional[bool] = None, session: Optional[AsyncClientSession] = None, @@ -779,17 +780,12 @@ async def bulk_write( .. versionadded:: 3.0 """ - common.validate_list("requests", requests) + common.validate_list_or_generator("requests", requests) blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let) - for request in requests: - try: - request._add_to_bulk(blk) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None write_concern = self._write_concern_for(session) - bulk_api_result = await blk.execute(write_concern, session, _Op.INSERT) + bulk_api_result = await blk.execute(requests, write_concern, session, _Op.INSERT) if bulk_api_result is not None: return BulkWriteResult(bulk_api_result, True) return BulkWriteResult({}, False) diff --git a/pymongo/common.py b/pymongo/common.py index 3d8095eedf..6d9bb2f37a 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -24,6 +24,7 @@ TYPE_CHECKING, Any, Callable, + Generator, Iterator, Mapping, MutableMapping, @@ -530,6 +531,13 @@ def validate_list(option: str, value: Any) -> list: return value +def validate_list_or_generator(option: str, value: Any) -> Union[list, Generator]: + """Validates that 'value' is a list or generator.""" + if isinstance(value, Generator): + return value + return validate_list(option, value) + + def validate_list_or_none(option: Any, value: Any) -> Optional[list]: """Validates that 'value' is a list or None.""" if value is None: diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index a528b09add..b92bb1a511 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -26,6 +26,7 @@ from typing import ( TYPE_CHECKING, Any, + Generator, Iterator, Mapping, Optional, @@ -72,7 +73,7 @@ from pymongo.write_concern import WriteConcern if TYPE_CHECKING: - from pymongo.synchronous.collection import Collection + from pymongo.synchronous.collection import Collection, _WriteOp from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline @@ -214,28 +215,45 @@ def add_delete( self.is_retryable = False self.ops.append((_DELETE, cmd)) - def gen_ordered(self) -> Iterator[Optional[_Run]]: + def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: """Generate batches of operations, batched by type of operation, in the order **provided**. """ run = None - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + try: + request._add_to_bulk(self) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) elif run.op_type != op_type: yield run run = _Run(op_type) run.add(idx, operation) + if run is None: + raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self) -> Iterator[_Run]: + def gen_unordered(self, requests) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. """ operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + try: + request._add_to_bulk(self) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - + if ( + len(operations[_INSERT].ops) == 0 + and len(operations[_UPDATE].ops) == 0 + and len(operations[_DELETE].ops) == 0 + ): + raise InvalidOperation("No operations to execute") for run in operations: if run.ops: yield run @@ -724,13 +742,12 @@ def execute_no_results( def execute( self, + generator: Generator[_WriteOp[_DocumentType]], write_concern: WriteConcern, session: Optional[ClientSession], operation: str, ) -> Any: """Execute operations.""" - if not self.ops: - raise InvalidOperation("No operations to execute") if self.executed: raise InvalidOperation("Bulk operations can only be executed once.") self.executed = True @@ -738,9 +755,9 @@ def execute( session = _validate_session_write_concern(session, write_concern) if self.ordered: - generator = self.gen_ordered() + generator = self.gen_ordered(generator) else: - generator = self.gen_unordered() + generator = self.gen_unordered(generator) client = self.collection.database.client if not write_concern.acknowledged: diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 8a71768318..52c25de744 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -22,6 +22,7 @@ Any, Callable, ContextManager, + Generator, Generic, Iterable, Iterator, @@ -698,7 +699,7 @@ def _create( @_csot.apply def bulk_write( self, - requests: Sequence[_WriteOp[_DocumentType]], + requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]], ordered: bool = True, bypass_document_validation: Optional[bool] = None, session: Optional[ClientSession] = None, @@ -778,17 +779,12 @@ def bulk_write( .. versionadded:: 3.0 """ - common.validate_list("requests", requests) + common.validate_list_or_generator("requests", requests) blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let) - for request in requests: - try: - request._add_to_bulk(blk) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None write_concern = self._write_concern_for(session) - bulk_api_result = blk.execute(write_concern, session, _Op.INSERT) + bulk_api_result = blk.execute(requests, write_concern, session, _Op.INSERT) if bulk_api_result is not None: return BulkWriteResult(bulk_api_result, True) return BulkWriteResult({}, False) diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 65ed6e236a..3becea0777 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -299,6 +299,21 @@ async def test_numerous_inserts(self): self.assertEqual(n_docs, result.inserted_count) self.assertEqual(n_docs, await self.coll.count_documents({})) + async def test_numerous_inserts_generator(self): + # Ensure we don't exceed server's maxWriteBatchSize size limit. + n_docs = await async_client_context.max_write_batch_size + 100 + requests = (InsertOne[dict]({}) for _ in range(n_docs)) + result = await self.coll.bulk_write(requests, ordered=False) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, await self.coll.count_documents({})) + + # Same with ordered bulk. + await self.coll.drop() + requests = (InsertOne[dict]({}) for _ in range(n_docs)) + result = await self.coll.bulk_write(requests) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, await self.coll.count_documents({})) + async def test_bulk_max_message_size(self): await self.coll.delete_many({}) self.addAsyncCleanup(self.coll.delete_many, {}) diff --git a/test/test_bulk.py b/test/test_bulk.py index 8a863cc49b..3e631e661f 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -299,6 +299,21 @@ def test_numerous_inserts(self): self.assertEqual(n_docs, result.inserted_count) self.assertEqual(n_docs, self.coll.count_documents({})) + def test_numerous_inserts_generator(self): + # Ensure we don't exceed server's maxWriteBatchSize size limit. + n_docs = client_context.max_write_batch_size + 100 + requests = (InsertOne[dict]({}) for _ in range(n_docs)) + result = self.coll.bulk_write(requests, ordered=False) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, self.coll.count_documents({})) + + # Same with ordered bulk. + self.coll.drop() + requests = (InsertOne[dict]({}) for _ in range(n_docs)) + result = self.coll.bulk_write(requests) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, self.coll.count_documents({})) + def test_bulk_max_message_size(self): self.coll.delete_many({}) self.addCleanup(self.coll.delete_many, {}) From 0648bcfafc0c81d11af5df7d17c2d213fb733e43 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 21 Apr 2025 15:00:48 -0700 Subject: [PATCH 02/10] wip --- pymongo/asynchronous/bulk.py | 120 +++++++++++++++--------- pymongo/asynchronous/client_bulk.py | 13 +-- pymongo/asynchronous/client_session.py | 7 +- pymongo/asynchronous/collection.py | 58 +++++------- pymongo/asynchronous/mongo_client.py | 61 ++++++++----- pymongo/bulk_shared.py | 3 + pymongo/operations.py | 24 ++--- pymongo/synchronous/bulk.py | 122 ++++++++++++++++--------- pymongo/synchronous/client_bulk.py | 13 +-- pymongo/synchronous/client_session.py | 5 +- pymongo/synchronous/collection.py | 64 +++++-------- pymongo/synchronous/mongo_client.py | 59 +++++++----- test/asynchronous/test_bulk.py | 5 - test/test_bulk.py | 5 - 14 files changed, 307 insertions(+), 252 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index b4b042a632..630be7c25e 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -26,7 +26,8 @@ from typing import ( TYPE_CHECKING, Any, - Generator, + Callable, + Iterable, Iterator, Mapping, Optional, @@ -111,9 +112,6 @@ def __init__( self.uses_hint_update = False self.uses_hint_delete = False self.uses_sort = False - self.is_retryable = True - self.retrying = False - self.started_retryable_write = False # Extra state so that we know where to pick up on a retry attempt. self.current_run = None self.next_run = None @@ -129,13 +127,32 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - def add_insert(self, document: _DocumentOut) -> None: + @property + def is_retryable(self) -> bool: + if self.current_run: + return self.current_run.is_retryable + return True + + @property + def retrying(self) -> bool: + if self.current_run: + return self.current_run.retrying + return False + + @property + def started_retryable_write(self) -> bool: + if self.current_run: + return self.current_run.started_retryable_write + return False + + def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" validate_is_document_type("document", document) # Generate ObjectId client side. if not (isinstance(document, RawBSONDocument) or "_id" in document): document["_id"] = ObjectId() self.ops.append((_INSERT, document)) + return True def add_update( self, @@ -147,7 +164,7 @@ def add_update( array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Union[str, dict[str, Any], None] = None, sort: Optional[Mapping[str, Any]] = None, - ) -> None: + ) -> bool: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi} @@ -165,10 +182,12 @@ def add_update( if sort is not None: self.uses_sort = True cmd["sort"] = sort + + self.ops.append((_UPDATE, cmd)) if multi: # A bulk_write containing an update_many is not retryable. - self.is_retryable = False - self.ops.append((_UPDATE, cmd)) + return False + return True def add_replace( self, @@ -178,7 +197,7 @@ def add_replace( collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, sort: Optional[Mapping[str, Any]] = None, - ) -> None: + ) -> bool: """Create a replace document and add it to the list of ops.""" validate_ok_for_replace(replacement) cmd: dict[str, Any] = {"q": selector, "u": replacement} @@ -194,6 +213,7 @@ def add_replace( self.uses_sort = True cmd["sort"] = sort self.ops.append((_UPDATE, cmd)) + return True def add_delete( self, @@ -201,7 +221,7 @@ def add_delete( limit: int, collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, - ) -> None: + ) -> bool: """Create a delete document and add it to the list of ops.""" cmd: dict[str, Any] = {"q": selector, "limit": limit} if collation is not None: @@ -210,21 +230,24 @@ def add_delete( if hint is not None: self.uses_hint_delete = True cmd["hint"] = hint + + self.ops.append((_DELETE, cmd)) if limit == _DELETE_ALL: # A bulk_write containing a delete_many is not retryable. - self.is_retryable = False - self.ops.append((_DELETE, cmd)) + return False + return True - def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: + def gen_ordered( + self, + requests: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in the order **provided**. """ run = None for idx, request in enumerate(requests): - try: - request._add_to_bulk(self) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None + retryable = process(request) (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) @@ -232,22 +255,25 @@ def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: yield run run = _Run(op_type) run.add(idx, operation) + run.is_retryable = run.is_retryable and retryable if run is None: raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self, requests) -> Iterator[_Run]: + def gen_unordered( + self, + requests: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. """ operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] for idx, request in enumerate(requests): - try: - request._add_to_bulk(self) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None + retryable = process(request) (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) + operations[op_type].is_retryable = operations[op_type].is_retryable and retryable if ( len(operations[_INSERT].ops) == 0 and len(operations[_UPDATE].ops) == 0 @@ -488,8 +514,8 @@ async def _execute_command( session: Optional[AsyncClientSession], conn: AsyncConnection, op_id: int, - retryable: bool, full_result: MutableMapping[str, Any], + validate: bool, final_write_concern: Optional[WriteConcern] = None, ) -> None: db_name = self.collection.database.name @@ -507,7 +533,7 @@ async def _execute_command( last_run = False while run: - if not self.retrying: + if not run.retrying: self.next_run = next(generator, None) if self.next_run is None: last_run = True @@ -541,10 +567,10 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if run.is_retryable and not run.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, run.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -552,9 +578,10 @@ async def _execute_command( ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. + if validate: + await self.validate_batch(conn, write_concern) if write_concern.acknowledged: result, to_send = await self._execute_batch(bwc, cmd, ops, client) - # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -567,8 +594,8 @@ async def _execute_command( _merge_command(run, full_result, run.idx_offset, result) # We're no longer in a retry once a command succeeds. - self.retrying = False - self.started_retryable_write = False + run.retrying = False + run.started_retryable_write = False if self.ordered and "writeErrors" in result: break @@ -606,7 +633,8 @@ async def execute_command( op_id = _randint() async def retryable_bulk( - session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool + session: Optional[AsyncClientSession], + conn: AsyncConnection, ) -> None: await self._execute_command( generator, @@ -614,26 +642,24 @@ async def retryable_bulk( session, conn, op_id, - retryable, full_result, + validate=False, ) client = self.collection.database.client _ = await client._retryable_write( - self.is_retryable, retryable_bulk, session, operation, bulk=self, # type: ignore[arg-type] operation_id=op_id, ) - if full_result["writeErrors"] or full_result["writeConcernErrors"]: _raise_bulk_write_error(full_result) return full_result async def execute_op_msg_no_results( - self, conn: AsyncConnection, generator: Iterator[Any] + self, conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern ) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name @@ -667,6 +693,7 @@ async def execute_op_msg_no_results( conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. + await self.validate_batch(conn, write_concern) to_send = await self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) @@ -700,12 +727,15 @@ async def execute_command_no_results( None, conn, op_id, - False, full_result, + True, write_concern, ) - except OperationFailure: - pass + except OperationFailure as exc: + if "Cannot set bypass_document_validation with unacknowledged write concern" in str( + exc + ): + raise exc async def execute_no_results( self, @@ -714,6 +744,11 @@ async def execute_no_results( write_concern: WriteConcern, ) -> None: """Execute all operations, returning no results (w=0).""" + if self.ordered: + return await self.execute_command_no_results(conn, generator, write_concern) + return await self.execute_op_msg_no_results(conn, generator, write_concern) + + async def validate_batch(self, conn: AsyncConnection, write_concern: WriteConcern) -> None: if self.uses_collation: raise ConfigurationError("Collation is unsupported for unacknowledged writes.") if self.uses_array_filters: @@ -738,13 +773,10 @@ async def execute_no_results( "Cannot set bypass_document_validation with unacknowledged write concern" ) - if self.ordered: - return await self.execute_command_no_results(conn, generator, write_concern) - return await self.execute_op_msg_no_results(conn, generator) - async def execute( self, - generator: Generator[_WriteOp[_DocumentType]], + generator: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], write_concern: WriteConcern, session: Optional[AsyncClientSession], operation: str, @@ -757,9 +789,9 @@ async def execute( session = _validate_session_write_concern(session, write_concern) if self.ordered: - generator = self.gen_ordered(generator) + generator = self.gen_ordered(generator, process) else: - generator = self.gen_unordered(generator) + generator = self.gen_unordered(generator, process) client = self.collection.database.client if not write_concern.acknowledged: diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 5f7ac013e9..dbbad9e0e8 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -116,6 +116,7 @@ def __init__( self.is_retryable = self.client.options.retry_writes self.retrying = False self.started_retryable_write = False + self.current_run = None @property def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]: @@ -488,7 +489,6 @@ async def _execute_command( session: Optional[AsyncClientSession], conn: AsyncConnection, op_id: int, - retryable: bool, full_result: MutableMapping[str, Any], final_write_concern: Optional[WriteConcern] = None, ) -> None: @@ -534,10 +534,10 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, self.client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -564,7 +564,7 @@ async def _execute_command( # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. - if retryable and (retryable_top_level_error or retryable_network_error): + if self.is_retryable and (retryable_top_level_error or retryable_network_error): full = copy.deepcopy(full_result) _merge_command(self.ops, self.idx_offset, full, result) _throw_client_bulk_write_exception(full, self.verbose_results) @@ -583,7 +583,7 @@ async def _execute_command( _merge_command(self.ops, self.idx_offset, full_result, result) break - if retryable: + if self.is_retryable: # Retryable writeConcernErrors halt the execution of this batch. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -638,7 +638,6 @@ async def execute_command( async def retryable_bulk( session: Optional[AsyncClientSession], conn: AsyncConnection, - retryable: bool, ) -> None: if conn.max_wire_version < 25: raise InvalidOperation( @@ -649,12 +648,10 @@ async def retryable_bulk( session, conn, op_id, - retryable, full_result, ) await self.client._retryable_write( - self.is_retryable, retryable_bulk, session, operation, diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index b808684dd4..b9d8449a34 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -854,13 +854,12 @@ async def _finish_transaction_with_retry(self, command_name: str) -> dict[str, A """ async def func( - _session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable: bool + _session: Optional[AsyncClientSession], + conn: AsyncConnection, ) -> dict[str, Any]: return await self._finish_transaction(conn, command_name) - return await self._client._retry_internal( - func, self, None, retryable=True, operation=_Op.ABORT - ) + return await self._client._retry_internal(func, self, None, operation=_Op.ABORT) async def _finish_transaction(self, conn: AsyncConnection, command_name: str) -> dict[str, Any]: self._transaction.attempt += 1 diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 3286dea4d4..5ee67ddf89 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -23,7 +23,6 @@ AsyncContextManager, Callable, Coroutine, - Generator, Generic, Iterable, Iterator, @@ -700,7 +699,7 @@ async def _create( @_csot.apply async def bulk_write( self, - requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]], + requests: Iterable[_WriteOp], ordered: bool = True, bypass_document_validation: Optional[bool] = None, session: Optional[AsyncClientSession] = None, @@ -785,7 +784,16 @@ async def bulk_write( blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let) write_concern = self._write_concern_for(session) - bulk_api_result = await blk.execute(requests, write_concern, session, _Op.INSERT) + + def process_for_bulk(request: _WriteOp) -> bool: + try: + return request._add_to_bulk(blk) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + + bulk_api_result = await blk.execute( + requests, process_for_bulk, write_concern, session, _Op.INSERT + ) if bulk_api_result is not None: return BulkWriteResult(bulk_api_result, True) return BulkWriteResult({}, False) @@ -802,17 +810,15 @@ async def _insert_one( ) -> Any: """Internal helper for inserting a single document.""" write_concern = write_concern or self.write_concern - acknowledged = write_concern.acknowledged command = {"insert": self.name, "ordered": ordered, "documents": [doc]} if comment is not None: command["comment"] = comment async def _insert_command( - session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection ) -> None: if bypass_doc_val is not None: command["bypassDocumentValidation"] = bypass_doc_val - result = await conn.command( self._database.name, command, @@ -820,14 +826,11 @@ async def _insert_command( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) _check_write_command_response(result) - await self._database.client._retryable_write( - acknowledged, _insert_command, session, operation=_Op.INSERT - ) + await self._database.client._retryable_write(_insert_command, session, operation=_Op.INSERT) if not isinstance(doc, RawBSONDocument): return doc.get("_id") @@ -956,20 +959,19 @@ async def insert_many( raise TypeError("documents must be a non-empty list") inserted_ids: list[ObjectId] = [] - def gen() -> Iterator[tuple[int, Mapping[str, Any]]]: + def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool: """A generator that validates documents and handles _ids.""" - for document in documents: - common.validate_is_document_type("document", document) - if not isinstance(document, RawBSONDocument): - if "_id" not in document: - document["_id"] = ObjectId() # type: ignore[index] - inserted_ids.append(document["_id"]) - yield (message._INSERT, document) + common.validate_is_document_type("document", document) + if not isinstance(document, RawBSONDocument): + if "_id" not in document: + document["_id"] = ObjectId() # type: ignore[index] + inserted_ids.append(document["_id"]) + blk.ops.append((message._INSERT, document)) + return True write_concern = self._write_concern_for(session) blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment) - blk.ops = list(gen()) - await blk.execute(write_concern, session, _Op.INSERT) + await blk.execute(documents, process_for_bulk, write_concern, session, _Op.INSERT) return InsertManyResult(inserted_ids, write_concern.acknowledged) async def _update( @@ -987,7 +989,6 @@ async def _update( array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, - retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -1050,7 +1051,6 @@ async def _update( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) ).copy() _check_write_command_response(result) @@ -1090,7 +1090,7 @@ async def _update_retryable( """Internal update / replace helper.""" async def _update( - session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection ) -> Optional[Mapping[str, Any]]: return await self._update( conn, @@ -1106,14 +1106,12 @@ async def _update( array_filters=array_filters, hint=hint, session=session, - retryable_write=retryable_write, let=let, sort=sort, comment=comment, ) return await self._database.client._retryable_write( - (write_concern or self.write_concern).acknowledged and not multi, _update, session, operation, @@ -1503,7 +1501,6 @@ async def _delete( collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, - retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Mapping[str, Any]: @@ -1543,7 +1540,6 @@ async def _delete( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) _check_write_command_response(result) return result @@ -1564,7 +1560,7 @@ async def _delete_retryable( """Internal delete helper.""" async def _delete( - session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection ) -> Mapping[str, Any]: return await self._delete( conn, @@ -1576,13 +1572,11 @@ async def _delete( collation=collation, hint=hint, session=session, - retryable_write=retryable_write, let=let, comment=comment, ) return await self._database.client._retryable_write( - (write_concern or self.write_concern).acknowledged and not multi, _delete, session, operation=_Op.DELETE, @@ -3227,7 +3221,7 @@ async def _find_and_modify( write_concern = self._write_concern_for_cmd(cmd, session) async def _find_and_modify_helper( - session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection ) -> Any: acknowledged = write_concern.acknowledged if array_filters is not None: @@ -3253,7 +3247,6 @@ async def _find_and_modify_helper( write_concern=write_concern, collation=collation, session=session, - retryable_write=retryable_write, user_fields=_FIND_AND_MODIFY_DOC_FIELDS, ) _check_write_command_response(out) @@ -3261,7 +3254,6 @@ async def _find_and_modify_helper( return out.get("value") return await self._database.client._retryable_write( - write_concern.acknowledged, _find_and_modify_helper, session, operation=_Op.FIND_AND_MODIFY, diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 16753420c0..4c8230cea9 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -142,9 +142,7 @@ T = TypeVar("T") -_WriteCall = Callable[ - [Optional["AsyncClientSession"], "AsyncConnection", bool], Coroutine[Any, Any, T] -] +_WriteCall = Callable[[Optional["AsyncClientSession"], "AsyncConnection"], Coroutine[Any, Any, T]] _ReadCall = Callable[ [Optional["AsyncClientSession"], "Server", "AsyncConnection", _ServerMode], Coroutine[Any, Any, T], @@ -1894,7 +1892,6 @@ async def _cmd( async def _retry_with_session( self, - retryable: bool, func: _WriteCall[T], session: Optional[AsyncClientSession], bulk: Optional[Union[_AsyncBulk, _AsyncClientBulk]], @@ -1910,15 +1907,11 @@ async def _retry_with_session( """ # Ensure that the options supports retry_writes and there is a valid session not in # transaction, otherwise, we will not support retry behavior for this txn. - retryable = bool( - retryable and self.options.retry_writes and session and not session.in_transaction - ) return await self._retry_internal( func=func, session=session, bulk=bulk, operation=operation, - retryable=retryable, operation_id=operation_id, ) @@ -1932,7 +1925,6 @@ async def _retry_internal( is_read: bool = False, address: Optional[_Address] = None, read_pref: Optional[_ServerMode] = None, - retryable: bool = False, operation_id: Optional[int] = None, ) -> T: """Internal retryable helper for all client transactions. @@ -1957,7 +1949,6 @@ async def _retry_internal( session=session, read_pref=read_pref, address=address, - retryable=retryable, operation_id=operation_id, ).run() @@ -2000,13 +1991,11 @@ async def _retryable_read( is_read=True, address=address, read_pref=read_pref, - retryable=retryable, operation_id=operation_id, ) async def _retryable_write( self, - retryable: bool, func: _WriteCall[T], session: Optional[AsyncClientSession], operation: str, @@ -2027,7 +2016,7 @@ async def _retryable_write( :param bulk: bulk abstraction to execute operations in bulk, defaults to None """ async with self._tmp_session(session) as s: - return await self._retry_with_session(retryable, func, s, bulk, operation, operation_id) + return await self._retry_with_session(func, s, bulk, operation, operation_id) def _cleanup_cursor_no_lock( self, @@ -2662,7 +2651,6 @@ def __init__( session: Optional[AsyncClientSession] = None, read_pref: Optional[_ServerMode] = None, address: Optional[_Address] = None, - retryable: bool = False, operation_id: Optional[int] = None, ): self._last_error: Optional[Exception] = None @@ -2674,7 +2662,7 @@ def __init__( self._bulk = bulk self._session = session self._is_read = is_read - self._retryable = retryable + self._retryable = True self._read_pref = read_pref self._server_selector: Callable[[Selection], Selection] = ( read_pref if is_read else writable_server_selector # type: ignore @@ -2685,6 +2673,11 @@ def __init__( self._operation = operation self._operation_id = operation_id + def _bulk_retryable(self) -> bool: + if self._bulk is not None and self._bulk.current_run is not None: + return self._bulk.current_run.is_retryable + return True + async def run(self) -> T: """Runs the supplied func() and attempts a retry @@ -2695,10 +2688,15 @@ async def run(self) -> T: # Increment the transaction id up front to ensure any retry attempt # will use the proper txnNumber, even if server or socket selection # fails before the command can be sent. - if self._is_session_state_retryable() and self._retryable and not self._is_read: + if ( + self._is_session_state_retryable() + and self._retryable + and self._bulk_retryable() + and not self._is_read + ): self._session._start_retryable_write() # type: ignore - if self._bulk: - self._bulk.started_retryable_write = True + if self._bulk and self._bulk.current_run: + self._bulk.current_run.started_retryable_write = True while True: self._check_last_error(check_csot=True) @@ -2731,7 +2729,7 @@ async def run(self) -> T: # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if not self._retryable and not self._bulk_retryable(): raise if isinstance(exc, ClientBulkWriteException) and exc.error: retryable_write_error_exc = isinstance( @@ -2748,7 +2746,10 @@ async def run(self) -> T: else: raise if self._bulk: - self._bulk.retrying = True + if self._bulk.current_run: + self._bulk.current_run.retrying = True + else: + self._bulk.retrying = True else: self._retrying = True if not exc.has_error_label("NoWritesPerformed"): @@ -2761,11 +2762,19 @@ async def run(self) -> T: def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" - return not self._retryable or (self._is_retrying() and not self._multiple_retries) + return ( + not self._retryable + or not self._bulk_retryable() + or (self._is_retrying() and not self._multiple_retries) + ) def _is_retrying(self) -> bool: """Checks if the exchange is currently undergoing a retry""" - return self._bulk.retrying if self._bulk else self._retrying + return ( + self._bulk.current_run.retrying + if self._bulk is not None and self._bulk.current_run is not None + else self._retrying + ) def _is_session_state_retryable(self) -> bool: """Checks if provided session is eligible for retry @@ -2825,9 +2834,11 @@ async def _write(self) -> T: # not support sessions raise the last error. self._check_last_error() self._retryable = False - return await self._func(self._session, conn, self._retryable) # type: ignore + if self._bulk and self._bulk.current_run: + self._bulk.current_run.is_retryable = False + return await self._func(self._session, conn) # type: ignore except PyMongoError as exc: - if not self._retryable: + if not self._retryable or not self._bulk_retryable(): raise # Add the RetryableWriteError label, if applicable. _add_retryable_write_error(exc, max_wire_version, is_mongos) @@ -2844,7 +2855,7 @@ async def _read(self) -> T: conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and (not self._retryable or not self._bulk_retryable()): self._check_last_error() return await self._func(self._session, self._server, conn, read_pref) # type: ignore diff --git a/pymongo/bulk_shared.py b/pymongo/bulk_shared.py index 9276419d8a..b157edd2e2 100644 --- a/pymongo/bulk_shared.py +++ b/pymongo/bulk_shared.py @@ -50,6 +50,9 @@ def __init__(self, op_type: int) -> None: self.index_map: list[int] = [] self.ops: list[Any] = [] self.idx_offset: int = 0 + self.is_retryable = True + self.retrying = False + self.started_retryable_write = False def index(self, idx: int) -> int: """Get the original index of an operation in this run. diff --git a/pymongo/operations.py b/pymongo/operations.py index 300f1ba123..49b41ee614 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -106,9 +106,9 @@ def __init__(self, document: _DocumentType, namespace: Optional[str] = None) -> self._doc = document self._namespace = namespace - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_insert(self._doc) # type: ignore[arg-type] + return bulkobj.add_insert(self._doc) # type: ignore[arg-type] def _add_to_client_bulk(self, bulkobj: _AgnosticClientBulk) -> None: """Add this operation to the _AsyncClientBulk/_ClientBulk instance `bulkobj`.""" @@ -230,9 +230,9 @@ def __init__( """ super().__init__(filter, collation, hint, namespace) - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_delete( + return bulkobj.add_delete( self._filter, 1, collation=validate_collation_or_none(self._collation), @@ -291,9 +291,9 @@ def __init__( """ super().__init__(filter, collation, hint, namespace) - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_delete( + return bulkobj.add_delete( self._filter, 0, collation=validate_collation_or_none(self._collation), @@ -384,9 +384,9 @@ def __init__( self._collation = collation self._namespace = namespace - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_replace( + return bulkobj.add_replace( self._filter, self._doc, self._upsert, @@ -606,9 +606,9 @@ def __init__( """ super().__init__(filter, update, upsert, collation, array_filters, hint, namespace, sort) - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_update( + return bulkobj.add_update( self._filter, self._doc, False, @@ -687,9 +687,9 @@ def __init__( """ super().__init__(filter, update, upsert, collation, array_filters, hint, namespace, None) - def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_update( + return bulkobj.add_update( self._filter, self._doc, True, diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index b92bb1a511..2734d8d3fc 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -26,7 +26,8 @@ from typing import ( TYPE_CHECKING, Any, - Generator, + Callable, + Iterable, Iterator, Mapping, Optional, @@ -111,9 +112,6 @@ def __init__( self.uses_hint_update = False self.uses_hint_delete = False self.uses_sort = False - self.is_retryable = True - self.retrying = False - self.started_retryable_write = False # Extra state so that we know where to pick up on a retry attempt. self.current_run = None self.next_run = None @@ -129,13 +127,32 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - def add_insert(self, document: _DocumentOut) -> None: + @property + def is_retryable(self) -> bool: + if self.current_run: + return self.current_run.is_retryable + return True + + @property + def retrying(self) -> bool: + if self.current_run: + return self.current_run.retrying + return False + + @property + def started_retryable_write(self) -> bool: + if self.current_run: + return self.current_run.started_retryable_write + return False + + def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" validate_is_document_type("document", document) # Generate ObjectId client side. if not (isinstance(document, RawBSONDocument) or "_id" in document): document["_id"] = ObjectId() self.ops.append((_INSERT, document)) + return True def add_update( self, @@ -147,7 +164,7 @@ def add_update( array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Union[str, dict[str, Any], None] = None, sort: Optional[Mapping[str, Any]] = None, - ) -> None: + ) -> bool: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi} @@ -165,10 +182,12 @@ def add_update( if sort is not None: self.uses_sort = True cmd["sort"] = sort + + self.ops.append((_UPDATE, cmd)) if multi: # A bulk_write containing an update_many is not retryable. - self.is_retryable = False - self.ops.append((_UPDATE, cmd)) + return False + return True def add_replace( self, @@ -178,7 +197,7 @@ def add_replace( collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, sort: Optional[Mapping[str, Any]] = None, - ) -> None: + ) -> bool: """Create a replace document and add it to the list of ops.""" validate_ok_for_replace(replacement) cmd: dict[str, Any] = {"q": selector, "u": replacement} @@ -194,6 +213,7 @@ def add_replace( self.uses_sort = True cmd["sort"] = sort self.ops.append((_UPDATE, cmd)) + return True def add_delete( self, @@ -201,7 +221,7 @@ def add_delete( limit: int, collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, - ) -> None: + ) -> bool: """Create a delete document and add it to the list of ops.""" cmd: dict[str, Any] = {"q": selector, "limit": limit} if collation is not None: @@ -210,21 +230,24 @@ def add_delete( if hint is not None: self.uses_hint_delete = True cmd["hint"] = hint + + self.ops.append((_DELETE, cmd)) if limit == _DELETE_ALL: # A bulk_write containing a delete_many is not retryable. - self.is_retryable = False - self.ops.append((_DELETE, cmd)) + return False + return True - def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: + def gen_ordered( + self, + requests: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in the order **provided**. """ run = None for idx, request in enumerate(requests): - try: - request._add_to_bulk(self) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None + retryable = process(request) (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) @@ -232,22 +255,25 @@ def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: yield run run = _Run(op_type) run.add(idx, operation) + run.is_retryable = run.is_retryable and retryable if run is None: raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self, requests) -> Iterator[_Run]: + def gen_unordered( + self, + requests: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. """ operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] for idx, request in enumerate(requests): - try: - request._add_to_bulk(self) - except AttributeError: - raise TypeError(f"{request!r} is not a valid request") from None + retryable = process(request) (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) + operations[op_type].is_retryable = operations[op_type].is_retryable and retryable if ( len(operations[_INSERT].ops) == 0 and len(operations[_UPDATE].ops) == 0 @@ -488,8 +514,8 @@ def _execute_command( session: Optional[ClientSession], conn: Connection, op_id: int, - retryable: bool, full_result: MutableMapping[str, Any], + validate: bool, final_write_concern: Optional[WriteConcern] = None, ) -> None: db_name = self.collection.database.name @@ -507,7 +533,7 @@ def _execute_command( last_run = False while run: - if not self.retrying: + if not run.retrying: self.next_run = next(generator, None) if self.next_run is None: last_run = True @@ -541,10 +567,10 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if run.is_retryable and not run.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, run.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -552,9 +578,10 @@ def _execute_command( ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. + if validate: + self.validate_batch(conn, write_concern) if write_concern.acknowledged: result, to_send = self._execute_batch(bwc, cmd, ops, client) - # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -567,8 +594,8 @@ def _execute_command( _merge_command(run, full_result, run.idx_offset, result) # We're no longer in a retry once a command succeeds. - self.retrying = False - self.started_retryable_write = False + run.retrying = False + run.started_retryable_write = False if self.ordered and "writeErrors" in result: break @@ -606,7 +633,8 @@ def execute_command( op_id = _randint() def retryable_bulk( - session: Optional[ClientSession], conn: Connection, retryable: bool + session: Optional[ClientSession], + conn: Connection, ) -> None: self._execute_command( generator, @@ -614,25 +642,25 @@ def retryable_bulk( session, conn, op_id, - retryable, full_result, + validate=False, ) client = self.collection.database.client _ = client._retryable_write( - self.is_retryable, retryable_bulk, session, operation, bulk=self, # type: ignore[arg-type] operation_id=op_id, ) - if full_result["writeErrors"] or full_result["writeConcernErrors"]: _raise_bulk_write_error(full_result) return full_result - def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None: + def execute_op_msg_no_results( + self, conn: Connection, generator: Iterator[Any], write_concern: WriteConcern + ) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name client = self.collection.database.client @@ -665,6 +693,7 @@ def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. + self.validate_batch(conn, write_concern) to_send = self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) @@ -698,12 +727,15 @@ def execute_command_no_results( None, conn, op_id, - False, full_result, + True, write_concern, ) - except OperationFailure: - pass + except OperationFailure as exc: + if "Cannot set bypass_document_validation with unacknowledged write concern" in str( + exc + ): + raise exc def execute_no_results( self, @@ -712,6 +744,11 @@ def execute_no_results( write_concern: WriteConcern, ) -> None: """Execute all operations, returning no results (w=0).""" + if self.ordered: + return self.execute_command_no_results(conn, generator, write_concern) + return self.execute_op_msg_no_results(conn, generator, write_concern) + + def validate_batch(self, conn: Connection, write_concern: WriteConcern) -> None: if self.uses_collation: raise ConfigurationError("Collation is unsupported for unacknowledged writes.") if self.uses_array_filters: @@ -736,13 +773,10 @@ def execute_no_results( "Cannot set bypass_document_validation with unacknowledged write concern" ) - if self.ordered: - return self.execute_command_no_results(conn, generator, write_concern) - return self.execute_op_msg_no_results(conn, generator) - def execute( self, - generator: Generator[_WriteOp[_DocumentType]], + generator: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], write_concern: WriteConcern, session: Optional[ClientSession], operation: str, @@ -755,9 +789,9 @@ def execute( session = _validate_session_write_concern(session, write_concern) if self.ordered: - generator = self.gen_ordered(generator) + generator = self.gen_ordered(generator, process) else: - generator = self.gen_unordered(generator) + generator = self.gen_unordered(generator, process) client = self.collection.database.client if not write_concern.acknowledged: diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index d73bfb2a2b..0b0d4190f9 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -116,6 +116,7 @@ def __init__( self.is_retryable = self.client.options.retry_writes self.retrying = False self.started_retryable_write = False + self.current_run = None @property def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]: @@ -486,7 +487,6 @@ def _execute_command( session: Optional[ClientSession], conn: Connection, op_id: int, - retryable: bool, full_result: MutableMapping[str, Any], final_write_concern: Optional[WriteConcern] = None, ) -> None: @@ -532,10 +532,10 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, self.client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -562,7 +562,7 @@ def _execute_command( # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. - if retryable and (retryable_top_level_error or retryable_network_error): + if self.is_retryable and (retryable_top_level_error or retryable_network_error): full = copy.deepcopy(full_result) _merge_command(self.ops, self.idx_offset, full, result) _throw_client_bulk_write_exception(full, self.verbose_results) @@ -581,7 +581,7 @@ def _execute_command( _merge_command(self.ops, self.idx_offset, full_result, result) break - if retryable: + if self.is_retryable: # Retryable writeConcernErrors halt the execution of this batch. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -636,7 +636,6 @@ def execute_command( def retryable_bulk( session: Optional[ClientSession], conn: Connection, - retryable: bool, ) -> None: if conn.max_wire_version < 25: raise InvalidOperation( @@ -647,12 +646,10 @@ def retryable_bulk( session, conn, op_id, - retryable, full_result, ) self.client._retryable_write( - self.is_retryable, retryable_bulk, session, operation, diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index aaf2d7574f..dc52a24911 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -851,11 +851,12 @@ def _finish_transaction_with_retry(self, command_name: str) -> dict[str, Any]: """ def func( - _session: Optional[ClientSession], conn: Connection, _retryable: bool + _session: Optional[ClientSession], + conn: Connection, ) -> dict[str, Any]: return self._finish_transaction(conn, command_name) - return self._client._retry_internal(func, self, None, retryable=True, operation=_Op.ABORT) + return self._client._retry_internal(func, self, None, operation=_Op.ABORT) def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]: self._transaction.attempt += 1 diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 52c25de744..27b2a072d3 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -22,7 +22,6 @@ Any, Callable, ContextManager, - Generator, Generic, Iterable, Iterator, @@ -699,7 +698,7 @@ def _create( @_csot.apply def bulk_write( self, - requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]], + requests: Iterable[_WriteOp], ordered: bool = True, bypass_document_validation: Optional[bool] = None, session: Optional[ClientSession] = None, @@ -784,7 +783,16 @@ def bulk_write( blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let) write_concern = self._write_concern_for(session) - bulk_api_result = blk.execute(requests, write_concern, session, _Op.INSERT) + + def process_for_bulk(request: _WriteOp) -> bool: + try: + return request._add_to_bulk(blk) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + + bulk_api_result = blk.execute( + requests, process_for_bulk, write_concern, session, _Op.INSERT + ) if bulk_api_result is not None: return BulkWriteResult(bulk_api_result, True) return BulkWriteResult({}, False) @@ -801,17 +809,13 @@ def _insert_one( ) -> Any: """Internal helper for inserting a single document.""" write_concern = write_concern or self.write_concern - acknowledged = write_concern.acknowledged command = {"insert": self.name, "ordered": ordered, "documents": [doc]} if comment is not None: command["comment"] = comment - def _insert_command( - session: Optional[ClientSession], conn: Connection, retryable_write: bool - ) -> None: + def _insert_command(session: Optional[ClientSession], conn: Connection) -> None: if bypass_doc_val is not None: command["bypassDocumentValidation"] = bypass_doc_val - result = conn.command( self._database.name, command, @@ -819,14 +823,11 @@ def _insert_command( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) _check_write_command_response(result) - self._database.client._retryable_write( - acknowledged, _insert_command, session, operation=_Op.INSERT - ) + self._database.client._retryable_write(_insert_command, session, operation=_Op.INSERT) if not isinstance(doc, RawBSONDocument): return doc.get("_id") @@ -955,20 +956,19 @@ def insert_many( raise TypeError("documents must be a non-empty list") inserted_ids: list[ObjectId] = [] - def gen() -> Iterator[tuple[int, Mapping[str, Any]]]: + def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool: """A generator that validates documents and handles _ids.""" - for document in documents: - common.validate_is_document_type("document", document) - if not isinstance(document, RawBSONDocument): - if "_id" not in document: - document["_id"] = ObjectId() # type: ignore[index] - inserted_ids.append(document["_id"]) - yield (message._INSERT, document) + common.validate_is_document_type("document", document) + if not isinstance(document, RawBSONDocument): + if "_id" not in document: + document["_id"] = ObjectId() # type: ignore[index] + inserted_ids.append(document["_id"]) + blk.ops.append((message._INSERT, document)) + return True write_concern = self._write_concern_for(session) blk = _Bulk(self, ordered, bypass_document_validation, comment=comment) - blk.ops = list(gen()) - blk.execute(write_concern, session, _Op.INSERT) + blk.execute(documents, process_for_bulk, write_concern, session, _Op.INSERT) return InsertManyResult(inserted_ids, write_concern.acknowledged) def _update( @@ -986,7 +986,6 @@ def _update( array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, - retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -1049,7 +1048,6 @@ def _update( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) ).copy() _check_write_command_response(result) @@ -1089,7 +1087,7 @@ def _update_retryable( """Internal update / replace helper.""" def _update( - session: Optional[ClientSession], conn: Connection, retryable_write: bool + session: Optional[ClientSession], conn: Connection ) -> Optional[Mapping[str, Any]]: return self._update( conn, @@ -1105,14 +1103,12 @@ def _update( array_filters=array_filters, hint=hint, session=session, - retryable_write=retryable_write, let=let, sort=sort, comment=comment, ) return self._database.client._retryable_write( - (write_concern or self.write_concern).acknowledged and not multi, _update, session, operation, @@ -1502,7 +1498,6 @@ def _delete( collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, - retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Mapping[str, Any]: @@ -1542,7 +1537,6 @@ def _delete( codec_options=self._write_response_codec_options, session=session, client=self._database.client, - retryable_write=retryable_write, ) _check_write_command_response(result) return result @@ -1562,9 +1556,7 @@ def _delete_retryable( ) -> Mapping[str, Any]: """Internal delete helper.""" - def _delete( - session: Optional[ClientSession], conn: Connection, retryable_write: bool - ) -> Mapping[str, Any]: + def _delete(session: Optional[ClientSession], conn: Connection) -> Mapping[str, Any]: return self._delete( conn, criteria, @@ -1575,13 +1567,11 @@ def _delete( collation=collation, hint=hint, session=session, - retryable_write=retryable_write, let=let, comment=comment, ) return self._database.client._retryable_write( - (write_concern or self.write_concern).acknowledged and not multi, _delete, session, operation=_Op.DELETE, @@ -3219,9 +3209,7 @@ def _find_and_modify( write_concern = self._write_concern_for_cmd(cmd, session) - def _find_and_modify_helper( - session: Optional[ClientSession], conn: Connection, retryable_write: bool - ) -> Any: + def _find_and_modify_helper(session: Optional[ClientSession], conn: Connection) -> Any: acknowledged = write_concern.acknowledged if array_filters is not None: if not acknowledged: @@ -3246,7 +3234,6 @@ def _find_and_modify_helper( write_concern=write_concern, collation=collation, session=session, - retryable_write=retryable_write, user_fields=_FIND_AND_MODIFY_DOC_FIELDS, ) _check_write_command_response(out) @@ -3254,7 +3241,6 @@ def _find_and_modify_helper( return out.get("value") return self._database.client._retryable_write( - write_concern.acknowledged, _find_and_modify_helper, session, operation=_Op.FIND_AND_MODIFY, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 2d8d6d730b..3c657c214c 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -141,7 +141,7 @@ T = TypeVar("T") -_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], T] +_WriteCall = Callable[[Optional["ClientSession"], "Connection"], T] _ReadCall = Callable[ [Optional["ClientSession"], "Server", "Connection", _ServerMode], T, @@ -1888,7 +1888,6 @@ def _cmd( def _retry_with_session( self, - retryable: bool, func: _WriteCall[T], session: Optional[ClientSession], bulk: Optional[Union[_Bulk, _ClientBulk]], @@ -1904,15 +1903,11 @@ def _retry_with_session( """ # Ensure that the options supports retry_writes and there is a valid session not in # transaction, otherwise, we will not support retry behavior for this txn. - retryable = bool( - retryable and self.options.retry_writes and session and not session.in_transaction - ) return self._retry_internal( func=func, session=session, bulk=bulk, operation=operation, - retryable=retryable, operation_id=operation_id, ) @@ -1926,7 +1921,6 @@ def _retry_internal( is_read: bool = False, address: Optional[_Address] = None, read_pref: Optional[_ServerMode] = None, - retryable: bool = False, operation_id: Optional[int] = None, ) -> T: """Internal retryable helper for all client transactions. @@ -1951,7 +1945,6 @@ def _retry_internal( session=session, read_pref=read_pref, address=address, - retryable=retryable, operation_id=operation_id, ).run() @@ -1994,13 +1987,11 @@ def _retryable_read( is_read=True, address=address, read_pref=read_pref, - retryable=retryable, operation_id=operation_id, ) def _retryable_write( self, - retryable: bool, func: _WriteCall[T], session: Optional[ClientSession], operation: str, @@ -2021,7 +2012,7 @@ def _retryable_write( :param bulk: bulk abstraction to execute operations in bulk, defaults to None """ with self._tmp_session(session) as s: - return self._retry_with_session(retryable, func, s, bulk, operation, operation_id) + return self._retry_with_session(func, s, bulk, operation, operation_id) def _cleanup_cursor_no_lock( self, @@ -2648,7 +2639,6 @@ def __init__( session: Optional[ClientSession] = None, read_pref: Optional[_ServerMode] = None, address: Optional[_Address] = None, - retryable: bool = False, operation_id: Optional[int] = None, ): self._last_error: Optional[Exception] = None @@ -2660,7 +2650,7 @@ def __init__( self._bulk = bulk self._session = session self._is_read = is_read - self._retryable = retryable + self._retryable = True self._read_pref = read_pref self._server_selector: Callable[[Selection], Selection] = ( read_pref if is_read else writable_server_selector # type: ignore @@ -2671,6 +2661,11 @@ def __init__( self._operation = operation self._operation_id = operation_id + def _bulk_retryable(self) -> bool: + if self._bulk is not None and self._bulk.current_run is not None: + return self._bulk.current_run.is_retryable + return True + def run(self) -> T: """Runs the supplied func() and attempts a retry @@ -2681,10 +2676,15 @@ def run(self) -> T: # Increment the transaction id up front to ensure any retry attempt # will use the proper txnNumber, even if server or socket selection # fails before the command can be sent. - if self._is_session_state_retryable() and self._retryable and not self._is_read: + if ( + self._is_session_state_retryable() + and self._retryable + and self._bulk_retryable() + and not self._is_read + ): self._session._start_retryable_write() # type: ignore - if self._bulk: - self._bulk.started_retryable_write = True + if self._bulk and self._bulk.current_run: + self._bulk.current_run.started_retryable_write = True while True: self._check_last_error(check_csot=True) @@ -2717,7 +2717,7 @@ def run(self) -> T: # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if not self._retryable and not self._bulk_retryable(): raise if isinstance(exc, ClientBulkWriteException) and exc.error: retryable_write_error_exc = isinstance( @@ -2734,7 +2734,10 @@ def run(self) -> T: else: raise if self._bulk: - self._bulk.retrying = True + if self._bulk.current_run: + self._bulk.current_run.retrying = True + else: + self._bulk.retrying = True else: self._retrying = True if not exc.has_error_label("NoWritesPerformed"): @@ -2747,11 +2750,19 @@ def run(self) -> T: def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" - return not self._retryable or (self._is_retrying() and not self._multiple_retries) + return ( + not self._retryable + or not self._bulk_retryable() + or (self._is_retrying() and not self._multiple_retries) + ) def _is_retrying(self) -> bool: """Checks if the exchange is currently undergoing a retry""" - return self._bulk.retrying if self._bulk else self._retrying + return ( + self._bulk.current_run.retrying + if self._bulk is not None and self._bulk.current_run is not None + else self._retrying + ) def _is_session_state_retryable(self) -> bool: """Checks if provided session is eligible for retry @@ -2811,9 +2822,11 @@ def _write(self) -> T: # not support sessions raise the last error. self._check_last_error() self._retryable = False - return self._func(self._session, conn, self._retryable) # type: ignore + if self._bulk and self._bulk.current_run: + self._bulk.current_run.is_retryable = False + return self._func(self._session, conn) # type: ignore except PyMongoError as exc: - if not self._retryable: + if not self._retryable or not self._bulk_retryable(): raise # Add the RetryableWriteError label, if applicable. _add_retryable_write_error(exc, max_wire_version, is_mongos) @@ -2830,7 +2843,7 @@ def _read(self) -> T: conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and (not self._retryable or not self._bulk_retryable()): self._check_last_error() return self._func(self._session, self._server, conn, read_pref) # type: ignore diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 3becea0777..4d2338eae2 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -353,11 +353,6 @@ async def test_bulk_write_no_results(self): self.assertRaises(InvalidOperation, lambda: result.upserted_ids) async def test_bulk_write_invalid_arguments(self): - # The requests argument must be a list. - generator = (InsertOne[dict]({}) for _ in range(10)) - with self.assertRaises(TypeError): - await self.coll.bulk_write(generator) # type: ignore[arg-type] - # Document is not wrapped in a bulk write operation. with self.assertRaises(TypeError): await self.coll.bulk_write([{}]) # type: ignore[list-item] diff --git a/test/test_bulk.py b/test/test_bulk.py index 3e631e661f..9696f6da1d 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -353,11 +353,6 @@ def test_bulk_write_no_results(self): self.assertRaises(InvalidOperation, lambda: result.upserted_ids) def test_bulk_write_invalid_arguments(self): - # The requests argument must be a list. - generator = (InsertOne[dict]({}) for _ in range(10)) - with self.assertRaises(TypeError): - self.coll.bulk_write(generator) # type: ignore[arg-type] - # Document is not wrapped in a bulk write operation. with self.assertRaises(TypeError): self.coll.bulk_write([{}]) # type: ignore[list-item] From a0ef0549a9e09ce351df7fa6a34afa09ed252a16 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 21 Apr 2025 15:31:17 -0700 Subject: [PATCH 03/10] retrying vars back in bulk --- pymongo/asynchronous/bulk.py | 47 +++++++++++++++------------- pymongo/asynchronous/mongo_client.py | 23 +++++--------- pymongo/bulk_shared.py | 3 -- pymongo/synchronous/bulk.py | 47 +++++++++++++++------------- pymongo/synchronous/mongo_client.py | 23 +++++--------- 5 files changed, 66 insertions(+), 77 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 630be7c25e..a98c2b99c1 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -112,6 +112,9 @@ def __init__( self.uses_hint_update = False self.uses_hint_delete = False self.uses_sort = False + self.is_retryable = True + self.retrying = False + self.started_retryable_write = False # Extra state so that we know where to pick up on a retry attempt. self.current_run = None self.next_run = None @@ -127,23 +130,23 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - @property - def is_retryable(self) -> bool: - if self.current_run: - return self.current_run.is_retryable - return True - - @property - def retrying(self) -> bool: - if self.current_run: - return self.current_run.retrying - return False - - @property - def started_retryable_write(self) -> bool: - if self.current_run: - return self.current_run.started_retryable_write - return False + # @property + # def is_retryable(self) -> bool: + # if self.current_run: + # return self.current_run.is_retryable + # return True + # + # @property + # def retrying(self) -> bool: + # if self.current_run: + # return self.current_run.retrying + # return False + # + # @property + # def started_retryable_write(self) -> bool: + # if self.current_run: + # return self.current_run.started_retryable_write + # return False def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" @@ -255,7 +258,7 @@ def gen_ordered( yield run run = _Run(op_type) run.add(idx, operation) - run.is_retryable = run.is_retryable and retryable + self.is_retryable = self.is_retryable and retryable if run is None: raise InvalidOperation("No operations to execute") yield run @@ -273,7 +276,7 @@ def gen_unordered( retryable = process(request) (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - operations[op_type].is_retryable = operations[op_type].is_retryable and retryable + self.is_retryable = self.is_retryable and retryable if ( len(operations[_INSERT].ops) == 0 and len(operations[_UPDATE].ops) == 0 @@ -533,7 +536,7 @@ async def _execute_command( last_run = False while run: - if not run.retrying: + if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: last_run = True @@ -567,10 +570,10 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if run.is_retryable and not run.started_retryable_write: + if self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, run.is_retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 4c8230cea9..1675ce801d 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2674,8 +2674,8 @@ def __init__( self._operation_id = operation_id def _bulk_retryable(self) -> bool: - if self._bulk is not None and self._bulk.current_run is not None: - return self._bulk.current_run.is_retryable + if self._bulk is not None: + return self._bulk.is_retryable return True async def run(self) -> T: @@ -2695,8 +2695,8 @@ async def run(self) -> T: and not self._is_read ): self._session._start_retryable_write() # type: ignore - if self._bulk and self._bulk.current_run: - self._bulk.current_run.started_retryable_write = True + if self._bulk: + self._bulk.started_retryable_write = True while True: self._check_last_error(check_csot=True) @@ -2746,10 +2746,7 @@ async def run(self) -> T: else: raise if self._bulk: - if self._bulk.current_run: - self._bulk.current_run.retrying = True - else: - self._bulk.retrying = True + self._bulk.retrying = True else: self._retrying = True if not exc.has_error_label("NoWritesPerformed"): @@ -2770,11 +2767,7 @@ def _is_not_eligible_for_retry(self) -> bool: def _is_retrying(self) -> bool: """Checks if the exchange is currently undergoing a retry""" - return ( - self._bulk.current_run.retrying - if self._bulk is not None and self._bulk.current_run is not None - else self._retrying - ) + return self._bulk.retrying if self._bulk is not None else self._retrying def _is_session_state_retryable(self) -> bool: """Checks if provided session is eligible for retry @@ -2834,8 +2827,8 @@ async def _write(self) -> T: # not support sessions raise the last error. self._check_last_error() self._retryable = False - if self._bulk and self._bulk.current_run: - self._bulk.current_run.is_retryable = False + if self._bulk: + self._bulk.is_retryable = False return await self._func(self._session, conn) # type: ignore except PyMongoError as exc: if not self._retryable or not self._bulk_retryable(): diff --git a/pymongo/bulk_shared.py b/pymongo/bulk_shared.py index b157edd2e2..9276419d8a 100644 --- a/pymongo/bulk_shared.py +++ b/pymongo/bulk_shared.py @@ -50,9 +50,6 @@ def __init__(self, op_type: int) -> None: self.index_map: list[int] = [] self.ops: list[Any] = [] self.idx_offset: int = 0 - self.is_retryable = True - self.retrying = False - self.started_retryable_write = False def index(self, idx: int) -> int: """Get the original index of an operation in this run. diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 2734d8d3fc..c3323ed841 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -112,6 +112,9 @@ def __init__( self.uses_hint_update = False self.uses_hint_delete = False self.uses_sort = False + self.is_retryable = True + self.retrying = False + self.started_retryable_write = False # Extra state so that we know where to pick up on a retry attempt. self.current_run = None self.next_run = None @@ -127,23 +130,23 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - @property - def is_retryable(self) -> bool: - if self.current_run: - return self.current_run.is_retryable - return True - - @property - def retrying(self) -> bool: - if self.current_run: - return self.current_run.retrying - return False - - @property - def started_retryable_write(self) -> bool: - if self.current_run: - return self.current_run.started_retryable_write - return False + # @property + # def is_retryable(self) -> bool: + # if self.current_run: + # return self.current_run.is_retryable + # return True + # + # @property + # def retrying(self) -> bool: + # if self.current_run: + # return self.current_run.retrying + # return False + # + # @property + # def started_retryable_write(self) -> bool: + # if self.current_run: + # return self.current_run.started_retryable_write + # return False def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" @@ -255,7 +258,7 @@ def gen_ordered( yield run run = _Run(op_type) run.add(idx, operation) - run.is_retryable = run.is_retryable and retryable + self.is_retryable = self.is_retryable and retryable if run is None: raise InvalidOperation("No operations to execute") yield run @@ -273,7 +276,7 @@ def gen_unordered( retryable = process(request) (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - operations[op_type].is_retryable = operations[op_type].is_retryable and retryable + self.is_retryable = self.is_retryable and retryable if ( len(operations[_INSERT].ops) == 0 and len(operations[_UPDATE].ops) == 0 @@ -533,7 +536,7 @@ def _execute_command( last_run = False while run: - if not run.retrying: + if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: last_run = True @@ -567,10 +570,10 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if run.is_retryable and not run.started_retryable_write: + if self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, run.is_retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 3c657c214c..695e9be8b1 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2662,8 +2662,8 @@ def __init__( self._operation_id = operation_id def _bulk_retryable(self) -> bool: - if self._bulk is not None and self._bulk.current_run is not None: - return self._bulk.current_run.is_retryable + if self._bulk is not None: + return self._bulk.is_retryable return True def run(self) -> T: @@ -2683,8 +2683,8 @@ def run(self) -> T: and not self._is_read ): self._session._start_retryable_write() # type: ignore - if self._bulk and self._bulk.current_run: - self._bulk.current_run.started_retryable_write = True + if self._bulk: + self._bulk.started_retryable_write = True while True: self._check_last_error(check_csot=True) @@ -2734,10 +2734,7 @@ def run(self) -> T: else: raise if self._bulk: - if self._bulk.current_run: - self._bulk.current_run.retrying = True - else: - self._bulk.retrying = True + self._bulk.retrying = True else: self._retrying = True if not exc.has_error_label("NoWritesPerformed"): @@ -2758,11 +2755,7 @@ def _is_not_eligible_for_retry(self) -> bool: def _is_retrying(self) -> bool: """Checks if the exchange is currently undergoing a retry""" - return ( - self._bulk.current_run.retrying - if self._bulk is not None and self._bulk.current_run is not None - else self._retrying - ) + return self._bulk.retrying if self._bulk is not None else self._retrying def _is_session_state_retryable(self) -> bool: """Checks if provided session is eligible for retry @@ -2822,8 +2815,8 @@ def _write(self) -> T: # not support sessions raise the last error. self._check_last_error() self._retryable = False - if self._bulk and self._bulk.current_run: - self._bulk.current_run.is_retryable = False + if self._bulk: + self._bulk.is_retryable = False return self._func(self._session, conn) # type: ignore except PyMongoError as exc: if not self._retryable or not self._bulk_retryable(): From 19bcce4deaffe251d673c4b66acb4393747920f0 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 28 Apr 2025 16:30:45 -0700 Subject: [PATCH 04/10] bring back retryable var cuz default of false isn't great and move retryable var to bulk class to match client_bulk --- pymongo/asynchronous/bulk.py | 28 +++++++++++++-------- pymongo/asynchronous/client_bulk.py | 18 ++++++++++--- pymongo/asynchronous/client_session.py | 7 +++--- pymongo/asynchronous/collection.py | 29 +++++++++++++++------ pymongo/asynchronous/mongo_client.py | 18 ++++++++----- pymongo/bulk_shared.py | 3 +++ pymongo/synchronous/bulk.py | 28 +++++++++++++-------- pymongo/synchronous/client_bulk.py | 18 ++++++++++--- pymongo/synchronous/client_session.py | 5 ++-- pymongo/synchronous/collection.py | 35 ++++++++++++++++++++------ pymongo/synchronous/mongo_client.py | 16 +++++++----- 11 files changed, 145 insertions(+), 60 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index a98c2b99c1..8b287a8c83 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -258,7 +258,7 @@ def gen_ordered( yield run run = _Run(op_type) run.add(idx, operation) - self.is_retryable = self.is_retryable and retryable + run.is_retryable = run.is_retryable and retryable if run is None: raise InvalidOperation("No operations to execute") yield run @@ -276,7 +276,7 @@ def gen_unordered( retryable = process(request) (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - self.is_retryable = self.is_retryable and retryable + operations[op_type].is_retryable = operations[op_type].is_retryable and retryable if ( len(operations[_INSERT].ops) == 0 and len(operations[_UPDATE].ops) == 0 @@ -517,6 +517,7 @@ async def _execute_command( session: Optional[AsyncClientSession], conn: AsyncConnection, op_id: int, + retryable: bool, full_result: MutableMapping[str, Any], validate: bool, final_write_concern: Optional[WriteConcern] = None, @@ -536,6 +537,9 @@ async def _execute_command( last_run = False while run: + self.is_retryable = run.is_retryable + self.retrying = run.retrying + self.started_retryable_write = run.started_retryable_write if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: @@ -570,10 +574,13 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if self.is_retryable and not self.started_retryable_write: + if retryable and self.is_retryable and not self.started_retryable_write: + # print("starting retrayable write") session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) + session._apply_to( + cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn + ) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -593,12 +600,10 @@ async def _execute_command( full = copy.deepcopy(full_result) _merge_command(run, full, run.idx_offset, result) _raise_bulk_write_error(full) - _merge_command(run, full_result, run.idx_offset, result) - # We're no longer in a retry once a command succeeds. - run.retrying = False - run.started_retryable_write = False + self.retrying = False + self.started_retryable_write = False if self.ordered and "writeErrors" in result: break @@ -636,8 +641,7 @@ async def execute_command( op_id = _randint() async def retryable_bulk( - session: Optional[AsyncClientSession], - conn: AsyncConnection, + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool ) -> None: await self._execute_command( generator, @@ -645,18 +649,21 @@ async def retryable_bulk( session, conn, op_id, + retryable, full_result, validate=False, ) client = self.collection.database.client _ = await client._retryable_write( + self.is_retryable, retryable_bulk, session, operation, bulk=self, # type: ignore[arg-type] operation_id=op_id, ) + if full_result["writeErrors"] or full_result["writeConcernErrors"]: _raise_bulk_write_error(full_result) return full_result @@ -730,6 +737,7 @@ async def execute_command_no_results( None, conn, op_id, + False, full_result, True, write_concern, diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index dbbad9e0e8..9926a52fd0 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -489,6 +489,7 @@ async def _execute_command( session: Optional[AsyncClientSession], conn: AsyncConnection, op_id: int, + retryable: bool, full_result: MutableMapping[str, Any], final_write_concern: Optional[WriteConcern] = None, ) -> None: @@ -534,10 +535,12 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if self.is_retryable and not self.started_retryable_write: + if retryable and self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) + session._apply_to( + cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn + ) conn.send_cluster_time(cmd, session, self.client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -564,7 +567,11 @@ async def _execute_command( # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. - if self.is_retryable and (retryable_top_level_error or retryable_network_error): + if ( + retryable + and self.is_retryable + and (retryable_top_level_error or retryable_network_error) + ): full = copy.deepcopy(full_result) _merge_command(self.ops, self.idx_offset, full, result) _throw_client_bulk_write_exception(full, self.verbose_results) @@ -583,7 +590,7 @@ async def _execute_command( _merge_command(self.ops, self.idx_offset, full_result, result) break - if self.is_retryable: + if retryable and self.is_retryable: # Retryable writeConcernErrors halt the execution of this batch. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -638,6 +645,7 @@ async def execute_command( async def retryable_bulk( session: Optional[AsyncClientSession], conn: AsyncConnection, + retryable: bool, ) -> None: if conn.max_wire_version < 25: raise InvalidOperation( @@ -648,10 +656,12 @@ async def retryable_bulk( session, conn, op_id, + retryable, full_result, ) await self.client._retryable_write( + self.is_retryable, retryable_bulk, session, operation, diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index b9d8449a34..b808684dd4 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -854,12 +854,13 @@ async def _finish_transaction_with_retry(self, command_name: str) -> dict[str, A """ async def func( - _session: Optional[AsyncClientSession], - conn: AsyncConnection, + _session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable: bool ) -> dict[str, Any]: return await self._finish_transaction(conn, command_name) - return await self._client._retry_internal(func, self, None, operation=_Op.ABORT) + return await self._client._retry_internal( + func, self, None, retryable=True, operation=_Op.ABORT + ) async def _finish_transaction(self, conn: AsyncConnection, command_name: str) -> dict[str, Any]: self._transaction.attempt += 1 diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 5ee67ddf89..b7cd20bf1c 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -785,7 +785,7 @@ async def bulk_write( write_concern = self._write_concern_for(session) - def process_for_bulk(request: _WriteOp) -> bool: + def process_for_bulk(request: Union[_DocumentType, RawBSONDocument, _WriteOp]) -> bool: try: return request._add_to_bulk(blk) except AttributeError: @@ -810,15 +810,17 @@ async def _insert_one( ) -> Any: """Internal helper for inserting a single document.""" write_concern = write_concern or self.write_concern + acknowledged = write_concern.acknowledged command = {"insert": self.name, "ordered": ordered, "documents": [doc]} if comment is not None: command["comment"] = comment async def _insert_command( - session: Optional[AsyncClientSession], conn: AsyncConnection + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> None: if bypass_doc_val is not None: command["bypassDocumentValidation"] = bypass_doc_val + result = await conn.command( self._database.name, command, @@ -826,11 +828,14 @@ async def _insert_command( codec_options=self._write_response_codec_options, session=session, client=self._database.client, + retryable_write=retryable_write, ) _check_write_command_response(result) - await self._database.client._retryable_write(_insert_command, session, operation=_Op.INSERT) + await self._database.client._retryable_write( + acknowledged, _insert_command, session, operation=_Op.INSERT + ) if not isinstance(doc, RawBSONDocument): return doc.get("_id") @@ -959,7 +964,7 @@ async def insert_many( raise TypeError("documents must be a non-empty list") inserted_ids: list[ObjectId] = [] - def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool: + def process_for_bulk(document: Union[_DocumentType, RawBSONDocument, _WriteOp]) -> bool: """A generator that validates documents and handles _ids.""" common.validate_is_document_type("document", document) if not isinstance(document, RawBSONDocument): @@ -989,6 +994,7 @@ async def _update( array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, + retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -1051,6 +1057,7 @@ async def _update( codec_options=self._write_response_codec_options, session=session, client=self._database.client, + retryable_write=retryable_write, ) ).copy() _check_write_command_response(result) @@ -1090,7 +1097,7 @@ async def _update_retryable( """Internal update / replace helper.""" async def _update( - session: Optional[AsyncClientSession], conn: AsyncConnection + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> Optional[Mapping[str, Any]]: return await self._update( conn, @@ -1106,12 +1113,14 @@ async def _update( array_filters=array_filters, hint=hint, session=session, + retryable_write=retryable_write, let=let, sort=sort, comment=comment, ) return await self._database.client._retryable_write( + (write_concern or self.write_concern).acknowledged and not multi, _update, session, operation, @@ -1501,6 +1510,7 @@ async def _delete( collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, + retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Mapping[str, Any]: @@ -1540,6 +1550,7 @@ async def _delete( codec_options=self._write_response_codec_options, session=session, client=self._database.client, + retryable_write=retryable_write, ) _check_write_command_response(result) return result @@ -1560,7 +1571,7 @@ async def _delete_retryable( """Internal delete helper.""" async def _delete( - session: Optional[AsyncClientSession], conn: AsyncConnection + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> Mapping[str, Any]: return await self._delete( conn, @@ -1572,11 +1583,13 @@ async def _delete( collation=collation, hint=hint, session=session, + retryable_write=retryable_write, let=let, comment=comment, ) return await self._database.client._retryable_write( + (write_concern or self.write_concern).acknowledged and not multi, _delete, session, operation=_Op.DELETE, @@ -3221,7 +3234,7 @@ async def _find_and_modify( write_concern = self._write_concern_for_cmd(cmd, session) async def _find_and_modify_helper( - session: Optional[AsyncClientSession], conn: AsyncConnection + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> Any: acknowledged = write_concern.acknowledged if array_filters is not None: @@ -3247,6 +3260,7 @@ async def _find_and_modify_helper( write_concern=write_concern, collation=collation, session=session, + retryable_write=retryable_write, user_fields=_FIND_AND_MODIFY_DOC_FIELDS, ) _check_write_command_response(out) @@ -3254,6 +3268,7 @@ async def _find_and_modify_helper( return out.get("value") return await self._database.client._retryable_write( + write_concern.acknowledged, _find_and_modify_helper, session, operation=_Op.FIND_AND_MODIFY, diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 92918981ac..fa88cca779 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -149,7 +149,9 @@ T = TypeVar("T") -_WriteCall = Callable[[Optional["AsyncClientSession"], "AsyncConnection"], Coroutine[Any, Any, T]] +_WriteCall = Callable[ + [Optional["AsyncClientSession"], "AsyncConnection", bool], Coroutine[Any, Any, T] +] _ReadCall = Callable[ [Optional["AsyncClientSession"], "Server", "AsyncConnection", _ServerMode], Coroutine[Any, Any, T], @@ -1929,6 +1931,7 @@ async def _cmd( async def _retry_with_session( self, + retryable: bool, func: _WriteCall[T], session: Optional[AsyncClientSession], bulk: Optional[Union[_AsyncBulk, _AsyncClientBulk]], @@ -1949,6 +1952,7 @@ async def _retry_with_session( session=session, bulk=bulk, operation=operation, + retryable=retryable, operation_id=operation_id, ) @@ -1962,6 +1966,7 @@ async def _retry_internal( is_read: bool = False, address: Optional[_Address] = None, read_pref: Optional[_ServerMode] = None, + retryable: bool = False, operation_id: Optional[int] = None, ) -> T: """Internal retryable helper for all client transactions. @@ -1986,6 +1991,7 @@ async def _retry_internal( session=session, read_pref=read_pref, address=address, + retryable=retryable, operation_id=operation_id, ).run() @@ -2028,11 +2034,13 @@ async def _retryable_read( is_read=True, address=address, read_pref=read_pref, + retryable=retryable, operation_id=operation_id, ) async def _retryable_write( self, + retryable: bool, func: _WriteCall[T], session: Optional[AsyncClientSession], operation: str, @@ -2053,7 +2061,7 @@ async def _retryable_write( :param bulk: bulk abstraction to execute operations in bulk, defaults to None """ async with self._tmp_session(session) as s: - return await self._retry_with_session(func, s, bulk, operation, operation_id) + return await self._retry_with_session(retryable, func, s, bulk, operation, operation_id) def _cleanup_cursor_no_lock( self, @@ -2736,7 +2744,6 @@ async def run(self) -> T: self._session._start_retryable_write() # type: ignore if self._bulk: self._bulk.started_retryable_write = True - while True: self._check_last_error(check_csot=True) try: @@ -2766,7 +2773,6 @@ async def run(self) -> T: self._attempt_number += 1 else: raise - # Specialized catch on write operation if not self._is_read: if not self._retryable and not self._bulk_retryable(): @@ -2808,7 +2814,7 @@ def _is_not_eligible_for_retry(self) -> bool: def _is_retrying(self) -> bool: """Checks if the exchange is currently undergoing a retry""" - return self._bulk.retrying if self._bulk is not None else self._retrying + return self._bulk.retrying or self._retrying if self._bulk is not None else self._retrying def _is_session_state_retryable(self) -> bool: """Checks if provided session is eligible for retry @@ -2878,7 +2884,7 @@ async def _write(self) -> T: commandName=self._operation, operationId=self._operation_id, ) - return await self._func(self._session, conn) # type: ignore + return await self._func(self._session, conn, self._retryable) # type: ignore except PyMongoError as exc: if not self._retryable or not self._bulk_retryable(): raise diff --git a/pymongo/bulk_shared.py b/pymongo/bulk_shared.py index 9276419d8a..b157edd2e2 100644 --- a/pymongo/bulk_shared.py +++ b/pymongo/bulk_shared.py @@ -50,6 +50,9 @@ def __init__(self, op_type: int) -> None: self.index_map: list[int] = [] self.ops: list[Any] = [] self.idx_offset: int = 0 + self.is_retryable = True + self.retrying = False + self.started_retryable_write = False def index(self, idx: int) -> int: """Get the original index of an operation in this run. diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index c3323ed841..7867a6eab0 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -258,7 +258,7 @@ def gen_ordered( yield run run = _Run(op_type) run.add(idx, operation) - self.is_retryable = self.is_retryable and retryable + run.is_retryable = run.is_retryable and retryable if run is None: raise InvalidOperation("No operations to execute") yield run @@ -276,7 +276,7 @@ def gen_unordered( retryable = process(request) (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - self.is_retryable = self.is_retryable and retryable + operations[op_type].is_retryable = operations[op_type].is_retryable and retryable if ( len(operations[_INSERT].ops) == 0 and len(operations[_UPDATE].ops) == 0 @@ -517,6 +517,7 @@ def _execute_command( session: Optional[ClientSession], conn: Connection, op_id: int, + retryable: bool, full_result: MutableMapping[str, Any], validate: bool, final_write_concern: Optional[WriteConcern] = None, @@ -536,6 +537,9 @@ def _execute_command( last_run = False while run: + self.is_retryable = run.is_retryable + self.retrying = run.retrying + self.started_retryable_write = run.started_retryable_write if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: @@ -570,10 +574,13 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if self.is_retryable and not self.started_retryable_write: + if retryable and self.is_retryable and not self.started_retryable_write: + # print("starting retrayable write") session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) + session._apply_to( + cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn + ) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -593,12 +600,10 @@ def _execute_command( full = copy.deepcopy(full_result) _merge_command(run, full, run.idx_offset, result) _raise_bulk_write_error(full) - _merge_command(run, full_result, run.idx_offset, result) - # We're no longer in a retry once a command succeeds. - run.retrying = False - run.started_retryable_write = False + self.retrying = False + self.started_retryable_write = False if self.ordered and "writeErrors" in result: break @@ -636,8 +641,7 @@ def execute_command( op_id = _randint() def retryable_bulk( - session: Optional[ClientSession], - conn: Connection, + session: Optional[ClientSession], conn: Connection, retryable: bool ) -> None: self._execute_command( generator, @@ -645,18 +649,21 @@ def retryable_bulk( session, conn, op_id, + retryable, full_result, validate=False, ) client = self.collection.database.client _ = client._retryable_write( + self.is_retryable, retryable_bulk, session, operation, bulk=self, # type: ignore[arg-type] operation_id=op_id, ) + if full_result["writeErrors"] or full_result["writeConcernErrors"]: _raise_bulk_write_error(full_result) return full_result @@ -730,6 +737,7 @@ def execute_command_no_results( None, conn, op_id, + False, full_result, True, write_concern, diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 0b0d4190f9..bd25d042b4 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -487,6 +487,7 @@ def _execute_command( session: Optional[ClientSession], conn: Connection, op_id: int, + retryable: bool, full_result: MutableMapping[str, Any], final_write_concern: Optional[WriteConcern] = None, ) -> None: @@ -532,10 +533,12 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if self.is_retryable and not self.started_retryable_write: + if retryable and self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) + session._apply_to( + cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn + ) conn.send_cluster_time(cmd, session, self.client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -562,7 +565,11 @@ def _execute_command( # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. - if self.is_retryable and (retryable_top_level_error or retryable_network_error): + if ( + retryable + and self.is_retryable + and (retryable_top_level_error or retryable_network_error) + ): full = copy.deepcopy(full_result) _merge_command(self.ops, self.idx_offset, full, result) _throw_client_bulk_write_exception(full, self.verbose_results) @@ -581,7 +588,7 @@ def _execute_command( _merge_command(self.ops, self.idx_offset, full_result, result) break - if self.is_retryable: + if retryable and self.is_retryable: # Retryable writeConcernErrors halt the execution of this batch. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -636,6 +643,7 @@ def execute_command( def retryable_bulk( session: Optional[ClientSession], conn: Connection, + retryable: bool, ) -> None: if conn.max_wire_version < 25: raise InvalidOperation( @@ -646,10 +654,12 @@ def retryable_bulk( session, conn, op_id, + retryable, full_result, ) self.client._retryable_write( + self.is_retryable, retryable_bulk, session, operation, diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index dc52a24911..aaf2d7574f 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -851,12 +851,11 @@ def _finish_transaction_with_retry(self, command_name: str) -> dict[str, Any]: """ def func( - _session: Optional[ClientSession], - conn: Connection, + _session: Optional[ClientSession], conn: Connection, _retryable: bool ) -> dict[str, Any]: return self._finish_transaction(conn, command_name) - return self._client._retry_internal(func, self, None, operation=_Op.ABORT) + return self._client._retry_internal(func, self, None, retryable=True, operation=_Op.ABORT) def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]: self._transaction.attempt += 1 diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 27b2a072d3..8154633f32 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -784,7 +784,7 @@ def bulk_write( write_concern = self._write_concern_for(session) - def process_for_bulk(request: _WriteOp) -> bool: + def process_for_bulk(request: Union[_DocumentType, RawBSONDocument, _WriteOp]) -> bool: try: return request._add_to_bulk(blk) except AttributeError: @@ -809,13 +809,17 @@ def _insert_one( ) -> Any: """Internal helper for inserting a single document.""" write_concern = write_concern or self.write_concern + acknowledged = write_concern.acknowledged command = {"insert": self.name, "ordered": ordered, "documents": [doc]} if comment is not None: command["comment"] = comment - def _insert_command(session: Optional[ClientSession], conn: Connection) -> None: + def _insert_command( + session: Optional[ClientSession], conn: Connection, retryable_write: bool + ) -> None: if bypass_doc_val is not None: command["bypassDocumentValidation"] = bypass_doc_val + result = conn.command( self._database.name, command, @@ -823,11 +827,14 @@ def _insert_command(session: Optional[ClientSession], conn: Connection) -> None: codec_options=self._write_response_codec_options, session=session, client=self._database.client, + retryable_write=retryable_write, ) _check_write_command_response(result) - self._database.client._retryable_write(_insert_command, session, operation=_Op.INSERT) + self._database.client._retryable_write( + acknowledged, _insert_command, session, operation=_Op.INSERT + ) if not isinstance(doc, RawBSONDocument): return doc.get("_id") @@ -956,7 +963,7 @@ def insert_many( raise TypeError("documents must be a non-empty list") inserted_ids: list[ObjectId] = [] - def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool: + def process_for_bulk(document: Union[_DocumentType, RawBSONDocument, _WriteOp]) -> bool: """A generator that validates documents and handles _ids.""" common.validate_is_document_type("document", document) if not isinstance(document, RawBSONDocument): @@ -986,6 +993,7 @@ def _update( array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, + retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -1048,6 +1056,7 @@ def _update( codec_options=self._write_response_codec_options, session=session, client=self._database.client, + retryable_write=retryable_write, ) ).copy() _check_write_command_response(result) @@ -1087,7 +1096,7 @@ def _update_retryable( """Internal update / replace helper.""" def _update( - session: Optional[ClientSession], conn: Connection + session: Optional[ClientSession], conn: Connection, retryable_write: bool ) -> Optional[Mapping[str, Any]]: return self._update( conn, @@ -1103,12 +1112,14 @@ def _update( array_filters=array_filters, hint=hint, session=session, + retryable_write=retryable_write, let=let, sort=sort, comment=comment, ) return self._database.client._retryable_write( + (write_concern or self.write_concern).acknowledged and not multi, _update, session, operation, @@ -1498,6 +1509,7 @@ def _delete( collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, + retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Mapping[str, Any]: @@ -1537,6 +1549,7 @@ def _delete( codec_options=self._write_response_codec_options, session=session, client=self._database.client, + retryable_write=retryable_write, ) _check_write_command_response(result) return result @@ -1556,7 +1569,9 @@ def _delete_retryable( ) -> Mapping[str, Any]: """Internal delete helper.""" - def _delete(session: Optional[ClientSession], conn: Connection) -> Mapping[str, Any]: + def _delete( + session: Optional[ClientSession], conn: Connection, retryable_write: bool + ) -> Mapping[str, Any]: return self._delete( conn, criteria, @@ -1567,11 +1582,13 @@ def _delete(session: Optional[ClientSession], conn: Connection) -> Mapping[str, collation=collation, hint=hint, session=session, + retryable_write=retryable_write, let=let, comment=comment, ) return self._database.client._retryable_write( + (write_concern or self.write_concern).acknowledged and not multi, _delete, session, operation=_Op.DELETE, @@ -3209,7 +3226,9 @@ def _find_and_modify( write_concern = self._write_concern_for_cmd(cmd, session) - def _find_and_modify_helper(session: Optional[ClientSession], conn: Connection) -> Any: + def _find_and_modify_helper( + session: Optional[ClientSession], conn: Connection, retryable_write: bool + ) -> Any: acknowledged = write_concern.acknowledged if array_filters is not None: if not acknowledged: @@ -3234,6 +3253,7 @@ def _find_and_modify_helper(session: Optional[ClientSession], conn: Connection) write_concern=write_concern, collation=collation, session=session, + retryable_write=retryable_write, user_fields=_FIND_AND_MODIFY_DOC_FIELDS, ) _check_write_command_response(out) @@ -3241,6 +3261,7 @@ def _find_and_modify_helper(session: Optional[ClientSession], conn: Connection) return out.get("value") return self._database.client._retryable_write( + write_concern.acknowledged, _find_and_modify_helper, session, operation=_Op.FIND_AND_MODIFY, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index a8ab74c63a..089ab33477 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -148,7 +148,7 @@ T = TypeVar("T") -_WriteCall = Callable[[Optional["ClientSession"], "Connection"], T] +_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], T] _ReadCall = Callable[ [Optional["ClientSession"], "Server", "Connection", _ServerMode], T, @@ -1925,6 +1925,7 @@ def _cmd( def _retry_with_session( self, + retryable: bool, func: _WriteCall[T], session: Optional[ClientSession], bulk: Optional[Union[_Bulk, _ClientBulk]], @@ -1945,6 +1946,7 @@ def _retry_with_session( session=session, bulk=bulk, operation=operation, + retryable=retryable, operation_id=operation_id, ) @@ -1958,6 +1960,7 @@ def _retry_internal( is_read: bool = False, address: Optional[_Address] = None, read_pref: Optional[_ServerMode] = None, + retryable: bool = False, operation_id: Optional[int] = None, ) -> T: """Internal retryable helper for all client transactions. @@ -1982,6 +1985,7 @@ def _retry_internal( session=session, read_pref=read_pref, address=address, + retryable=retryable, operation_id=operation_id, ).run() @@ -2024,11 +2028,13 @@ def _retryable_read( is_read=True, address=address, read_pref=read_pref, + retryable=retryable, operation_id=operation_id, ) def _retryable_write( self, + retryable: bool, func: _WriteCall[T], session: Optional[ClientSession], operation: str, @@ -2049,7 +2055,7 @@ def _retryable_write( :param bulk: bulk abstraction to execute operations in bulk, defaults to None """ with self._tmp_session(session) as s: - return self._retry_with_session(func, s, bulk, operation, operation_id) + return self._retry_with_session(retryable, func, s, bulk, operation, operation_id) def _cleanup_cursor_no_lock( self, @@ -2724,7 +2730,6 @@ def run(self) -> T: self._session._start_retryable_write() # type: ignore if self._bulk: self._bulk.started_retryable_write = True - while True: self._check_last_error(check_csot=True) try: @@ -2754,7 +2759,6 @@ def run(self) -> T: self._attempt_number += 1 else: raise - # Specialized catch on write operation if not self._is_read: if not self._retryable and not self._bulk_retryable(): @@ -2796,7 +2800,7 @@ def _is_not_eligible_for_retry(self) -> bool: def _is_retrying(self) -> bool: """Checks if the exchange is currently undergoing a retry""" - return self._bulk.retrying if self._bulk is not None else self._retrying + return self._bulk.retrying or self._retrying if self._bulk is not None else self._retrying def _is_session_state_retryable(self) -> bool: """Checks if provided session is eligible for retry @@ -2866,7 +2870,7 @@ def _write(self) -> T: commandName=self._operation, operationId=self._operation_id, ) - return self._func(self._session, conn) # type: ignore + return self._func(self._session, conn, self._retryable) # type: ignore except PyMongoError as exc: if not self._retryable or not self._bulk_retryable(): raise From 9485e429f11490a49100a13c2433cf2499b1d577 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 28 Apr 2025 19:18:32 -0700 Subject: [PATCH 05/10] fix typing -- i was being so silly earlier --- pymongo/asynchronous/bulk.py | 12 +++++++++--- pymongo/asynchronous/collection.py | 4 ++-- pymongo/synchronous/bulk.py | 12 +++++++++--- pymongo/synchronous/collection.py | 4 ++-- test/test_typing.py | 16 ++++------------ 5 files changed, 26 insertions(+), 22 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 8b287a8c83..7a75297e87 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -243,7 +243,9 @@ def add_delete( def gen_ordered( self, requests: Iterable[Any], - process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + process: Union[ + Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool] + ], ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in the order **provided**. @@ -266,7 +268,9 @@ def gen_ordered( def gen_unordered( self, requests: Iterable[Any], - process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + process: Union[ + Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool] + ], ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. @@ -787,7 +791,9 @@ async def validate_batch(self, conn: AsyncConnection, write_concern: WriteConcer async def execute( self, generator: Iterable[Any], - process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + process: Union[ + Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool] + ], write_concern: WriteConcern, session: Optional[AsyncClientSession], operation: str, diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index b7cd20bf1c..32811663c9 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -785,7 +785,7 @@ async def bulk_write( write_concern = self._write_concern_for(session) - def process_for_bulk(request: Union[_DocumentType, RawBSONDocument, _WriteOp]) -> bool: + def process_for_bulk(request: _WriteOp) -> bool: try: return request._add_to_bulk(blk) except AttributeError: @@ -964,7 +964,7 @@ async def insert_many( raise TypeError("documents must be a non-empty list") inserted_ids: list[ObjectId] = [] - def process_for_bulk(document: Union[_DocumentType, RawBSONDocument, _WriteOp]) -> bool: + def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool: """A generator that validates documents and handles _ids.""" common.validate_is_document_type("document", document) if not isinstance(document, RawBSONDocument): diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 7867a6eab0..dc21773eb9 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -243,7 +243,9 @@ def add_delete( def gen_ordered( self, requests: Iterable[Any], - process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + process: Union[ + Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool] + ], ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in the order **provided**. @@ -266,7 +268,9 @@ def gen_ordered( def gen_unordered( self, requests: Iterable[Any], - process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + process: Union[ + Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool] + ], ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. @@ -787,7 +791,9 @@ def validate_batch(self, conn: Connection, write_concern: WriteConcern) -> None: def execute( self, generator: Iterable[Any], - process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + process: Union[ + Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool] + ], write_concern: WriteConcern, session: Optional[ClientSession], operation: str, diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 8154633f32..2a91412e33 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -784,7 +784,7 @@ def bulk_write( write_concern = self._write_concern_for(session) - def process_for_bulk(request: Union[_DocumentType, RawBSONDocument, _WriteOp]) -> bool: + def process_for_bulk(request: _WriteOp) -> bool: try: return request._add_to_bulk(blk) except AttributeError: @@ -963,7 +963,7 @@ def insert_many( raise TypeError("documents must be a non-empty list") inserted_ids: list[ObjectId] = [] - def process_for_bulk(document: Union[_DocumentType, RawBSONDocument, _WriteOp]) -> bool: + def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool: """A generator that validates documents and handles _ids.""" common.validate_is_document_type("document", document) if not isinstance(document, RawBSONDocument): diff --git a/test/test_typing.py b/test/test_typing.py index 65937020d2..f1b6a59e49 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -428,13 +428,9 @@ def test_typeddict_document_type_insertion(self) -> None: def test_bulk_write_document_type_insertion(self): client: MongoClient[MovieWithId] = MongoClient() coll: Collection[MovieWithId] = client.test.test - coll.bulk_write( - [InsertOne(Movie({"name": "THX-1138", "year": 1971}))] # type:ignore[arg-type] - ) + coll.bulk_write([InsertOne(Movie({"name": "THX-1138", "year": 1971}))]) mov_dict = {"_id": ObjectId(), "name": "THX-1138", "year": 1971} - coll.bulk_write( - [InsertOne(mov_dict)] # type:ignore[arg-type] - ) + coll.bulk_write([InsertOne(mov_dict)]) coll.bulk_write( [ InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore @@ -445,13 +441,9 @@ def test_bulk_write_document_type_insertion(self): def test_bulk_write_document_type_replacement(self): client: MongoClient[MovieWithId] = MongoClient() coll: Collection[MovieWithId] = client.test.test - coll.bulk_write( - [ReplaceOne({}, Movie({"name": "THX-1138", "year": 1971}))] # type:ignore[arg-type] - ) + coll.bulk_write([ReplaceOne({}, Movie({"name": "THX-1138", "year": 1971}))]) mov_dict = {"_id": ObjectId(), "name": "THX-1138", "year": 1971} - coll.bulk_write( - [ReplaceOne({}, mov_dict)] # type:ignore[arg-type] - ) + coll.bulk_write([ReplaceOne({}, mov_dict)]) coll.bulk_write( [ ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore From 2de0d88b9410c318e80dca4d9adbc4e524ef5228 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 30 Apr 2025 13:17:10 -0700 Subject: [PATCH 06/10] remove unintended changes and fix retryable? --- pymongo/asynchronous/bulk.py | 22 +++------------------- pymongo/asynchronous/client_bulk.py | 7 ++++++- pymongo/asynchronous/mongo_client.py | 3 +++ pymongo/synchronous/bulk.py | 22 +++------------------- pymongo/synchronous/client_bulk.py | 7 ++++++- pymongo/synchronous/mongo_client.py | 3 +++ 6 files changed, 24 insertions(+), 40 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 7a75297e87..ae1d492821 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -130,24 +130,6 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - # @property - # def is_retryable(self) -> bool: - # if self.current_run: - # return self.current_run.is_retryable - # return True - # - # @property - # def retrying(self) -> bool: - # if self.current_run: - # return self.current_run.retrying - # return False - # - # @property - # def started_retryable_write(self) -> bool: - # if self.current_run: - # return self.current_run.started_retryable_write - # return False - def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" validate_is_document_type("document", document) @@ -579,7 +561,6 @@ async def _execute_command( # Start a new retryable write unless one was already # started for this command. if retryable and self.is_retryable and not self.started_retryable_write: - # print("starting retrayable write") session._start_retryable_write() self.started_retryable_write = True session._apply_to( @@ -596,6 +577,7 @@ async def _execute_command( await self.validate_batch(conn, write_concern) if write_concern.acknowledged: result, to_send = await self._execute_batch(bwc, cmd, ops, client) + # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -604,7 +586,9 @@ async def _execute_command( full = copy.deepcopy(full_result) _merge_command(run, full, run.idx_offset, result) _raise_bulk_write_error(full) + _merge_command(run, full_result, run.idx_offset, result) + # We're no longer in a retry once a command succeeds. self.retrying = False self.started_retryable_write = False diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 9926a52fd0..ad0774ff65 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -535,7 +535,12 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and self.is_retryable and not self.started_retryable_write: + if ( + retryable + and self.is_retryable + and not self.started_retryable_write + and not session.in_transaction + ): session._start_retryable_write() self.started_retryable_write = True session._apply_to( diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index fa88cca779..e79e6df475 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1947,6 +1947,9 @@ async def _retry_with_session( """ # Ensure that the options supports retry_writes and there is a valid session not in # transaction, otherwise, we will not support retry behavior for this txn. + retryable = bool( + retryable and self.options.retry_writes and session and not session.in_transaction + ) return await self._retry_internal( func=func, session=session, diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index dc21773eb9..5105b86ebd 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -130,24 +130,6 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - # @property - # def is_retryable(self) -> bool: - # if self.current_run: - # return self.current_run.is_retryable - # return True - # - # @property - # def retrying(self) -> bool: - # if self.current_run: - # return self.current_run.retrying - # return False - # - # @property - # def started_retryable_write(self) -> bool: - # if self.current_run: - # return self.current_run.started_retryable_write - # return False - def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" validate_is_document_type("document", document) @@ -579,7 +561,6 @@ def _execute_command( # Start a new retryable write unless one was already # started for this command. if retryable and self.is_retryable and not self.started_retryable_write: - # print("starting retrayable write") session._start_retryable_write() self.started_retryable_write = True session._apply_to( @@ -596,6 +577,7 @@ def _execute_command( self.validate_batch(conn, write_concern) if write_concern.acknowledged: result, to_send = self._execute_batch(bwc, cmd, ops, client) + # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -604,7 +586,9 @@ def _execute_command( full = copy.deepcopy(full_result) _merge_command(run, full, run.idx_offset, result) _raise_bulk_write_error(full) + _merge_command(run, full_result, run.idx_offset, result) + # We're no longer in a retry once a command succeeds. self.retrying = False self.started_retryable_write = False diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index bd25d042b4..dfac7b5f93 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -533,7 +533,12 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and self.is_retryable and not self.started_retryable_write: + if ( + retryable + and self.is_retryable + and not self.started_retryable_write + and not session.in_transaction + ): session._start_retryable_write() self.started_retryable_write = True session._apply_to( diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 089ab33477..829e66aff6 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1941,6 +1941,9 @@ def _retry_with_session( """ # Ensure that the options supports retry_writes and there is a valid session not in # transaction, otherwise, we will not support retry behavior for this txn. + retryable = bool( + retryable and self.options.retry_writes and session and not session.in_transaction + ) return self._retry_internal( func=func, session=session, From 89e8b3dad95b48bedf29655ea681635456336881 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 30 Apr 2025 15:33:36 -0700 Subject: [PATCH 07/10] hmm undo client bulk retry changes --- pymongo/asynchronous/client_bulk.py | 19 ++++--------------- pymongo/synchronous/client_bulk.py | 19 ++++--------------- 2 files changed, 8 insertions(+), 30 deletions(-) diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index ad0774ff65..d55f8351b8 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -535,17 +535,10 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if ( - retryable - and self.is_retryable - and not self.started_retryable_write - and not session.in_transaction - ): + if retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to( - cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn - ) + session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, self.client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -572,11 +565,7 @@ async def _execute_command( # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. - if ( - retryable - and self.is_retryable - and (retryable_top_level_error or retryable_network_error) - ): + if retryable and (retryable_top_level_error or retryable_network_error): full = copy.deepcopy(full_result) _merge_command(self.ops, self.idx_offset, full, result) _throw_client_bulk_write_exception(full, self.verbose_results) @@ -595,7 +584,7 @@ async def _execute_command( _merge_command(self.ops, self.idx_offset, full_result, result) break - if retryable and self.is_retryable: + if retryable: # Retryable writeConcernErrors halt the execution of this batch. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index dfac7b5f93..06d4ca8872 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -533,17 +533,10 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if ( - retryable - and self.is_retryable - and not self.started_retryable_write - and not session.in_transaction - ): + if retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to( - cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn - ) + session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, self.client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -570,11 +563,7 @@ def _execute_command( # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. - if ( - retryable - and self.is_retryable - and (retryable_top_level_error or retryable_network_error) - ): + if retryable and (retryable_top_level_error or retryable_network_error): full = copy.deepcopy(full_result) _merge_command(self.ops, self.idx_offset, full, result) _throw_client_bulk_write_exception(full, self.verbose_results) @@ -593,7 +582,7 @@ def _execute_command( _merge_command(self.ops, self.idx_offset, full_result, result) break - if retryable and self.is_retryable: + if retryable: # Retryable writeConcernErrors halt the execution of this batch. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: From af4f6a634d8c79897b43687b2d4ac6633fba0276 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 1 May 2025 13:52:09 -0700 Subject: [PATCH 08/10] started_retryable_write should only belong to bulk class and not run --- pymongo/asynchronous/bulk.py | 1 - pymongo/bulk_shared.py | 1 - pymongo/synchronous/bulk.py | 1 - 3 files changed, 3 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index ae1d492821..c3801edb5a 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -525,7 +525,6 @@ async def _execute_command( while run: self.is_retryable = run.is_retryable self.retrying = run.retrying - self.started_retryable_write = run.started_retryable_write if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: diff --git a/pymongo/bulk_shared.py b/pymongo/bulk_shared.py index b157edd2e2..4d1b2fe7cd 100644 --- a/pymongo/bulk_shared.py +++ b/pymongo/bulk_shared.py @@ -52,7 +52,6 @@ def __init__(self, op_type: int) -> None: self.idx_offset: int = 0 self.is_retryable = True self.retrying = False - self.started_retryable_write = False def index(self, idx: int) -> int: """Get the original index of an operation in this run. diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 5105b86ebd..37325f74d8 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -525,7 +525,6 @@ def _execute_command( while run: self.is_retryable = run.is_retryable self.retrying = run.retrying - self.started_retryable_write = run.started_retryable_write if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: From 4223168d98890f95b3fc1f3208cc1c187d6170b5 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 1 May 2025 20:56:25 -0700 Subject: [PATCH 09/10] hmm retrying is also strictly an attribute on the bulk? --- pymongo/asynchronous/bulk.py | 1 - pymongo/synchronous/bulk.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index c3801edb5a..41a2c9a27c 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -524,7 +524,6 @@ async def _execute_command( while run: self.is_retryable = run.is_retryable - self.retrying = run.retrying if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 37325f74d8..53f89cfe56 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -524,7 +524,6 @@ def _execute_command( while run: self.is_retryable = run.is_retryable - self.retrying = run.retrying if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: From 623f4deff94708870e4a8a732e92d63978cac92d Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 2 May 2025 15:06:16 -0700 Subject: [PATCH 10/10] test? --- pymongo/asynchronous/bulk.py | 8 +++++++- pymongo/synchronous/bulk.py | 8 +++++++- test/asynchronous/test_bulk.py | 8 ++++++++ test/test_bulk.py | 8 ++++++++ 4 files changed, 30 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 41a2c9a27c..bb9f8e5915 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -233,14 +233,17 @@ def gen_ordered( operation, in the order **provided**. """ run = None + ctr = 0 for idx, request in enumerate(requests): retryable = process(request) (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) - elif run.op_type != op_type: + elif run.op_type != op_type or ctr >= common.MAX_WRITE_BATCH_SIZE // 200: yield run + ctr = 0 run = _Run(op_type) + ctr += 1 run.add(idx, operation) run.is_retryable = run.is_retryable and retryable if run is None: @@ -604,6 +607,9 @@ async def _execute_command( break # Reset our state self.current_run = run = self.next_run + import gc + + gc.collect() async def execute_command( self, diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 53f89cfe56..edbeebbf35 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -233,14 +233,17 @@ def gen_ordered( operation, in the order **provided**. """ run = None + ctr = 0 for idx, request in enumerate(requests): retryable = process(request) (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) - elif run.op_type != op_type: + elif run.op_type != op_type or ctr >= common.MAX_WRITE_BATCH_SIZE // 200: yield run + ctr = 0 run = _Run(op_type) + ctr += 1 run.add(idx, operation) run.is_retryable = run.is_retryable and retryable if run is None: @@ -604,6 +607,9 @@ def _execute_command( break # Reset our state self.current_run = run = self.next_run + import gc + + gc.collect() def execute_command( self, diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 4d2338eae2..c4c1ed2cae 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -314,6 +314,14 @@ async def test_numerous_inserts_generator(self): self.assertEqual(n_docs, result.inserted_count) self.assertEqual(n_docs, await self.coll.count_documents({})) + async def test_huge_inserts_generator(self): + # Ensure we don't exceed server's maxWriteBatchSize size limit. + n_docs = 1000000 + requests = (InsertOne({"x": "large" * 1024 * 1024}) for _ in range(n_docs)) + result = await self.coll.bulk_write(requests) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, await self.coll.count_documents({})) + async def test_bulk_max_message_size(self): await self.coll.delete_many({}) self.addAsyncCleanup(self.coll.delete_many, {}) diff --git a/test/test_bulk.py b/test/test_bulk.py index 9696f6da1d..3c8fb3d5fa 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -314,6 +314,14 @@ def test_numerous_inserts_generator(self): self.assertEqual(n_docs, result.inserted_count) self.assertEqual(n_docs, self.coll.count_documents({})) + def test_huge_inserts_generator(self): + # Ensure we don't exceed server's maxWriteBatchSize size limit. + n_docs = 1000000 + requests = (InsertOne({"x": "large" * 1024 * 1024}) for _ in range(n_docs)) + result = self.coll.bulk_write(requests) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, self.coll.count_documents({})) + def test_bulk_max_message_size(self): self.coll.delete_many({}) self.addCleanup(self.coll.delete_many, {})