diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py
index ac514db98f..bb9f8e5915 100644
--- a/pymongo/asynchronous/bulk.py
+++ b/pymongo/asynchronous/bulk.py
@@ -26,6 +26,8 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
+    Iterable,
     Iterator,
     Mapping,
     Optional,
@@ -72,7 +74,7 @@ from pymongo.write_concern import WriteConcern
 
 if TYPE_CHECKING:
-    from pymongo.asynchronous.collection import AsyncCollection
+    from pymongo.asynchronous.collection import AsyncCollection, _WriteOp
     from pymongo.asynchronous.mongo_client import AsyncMongoClient
     from pymongo.asynchronous.pool import AsyncConnection
     from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
@@ -128,13 +130,14 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
             self.is_encrypted = False
         return _BulkWriteContext
 
-    def add_insert(self, document: _DocumentOut) -> None:
+    def add_insert(self, document: _DocumentOut) -> bool:
         """Add an insert document to the list of ops."""
         validate_is_document_type("document", document)
         # Generate ObjectId client side.
         if not (isinstance(document, RawBSONDocument) or "_id" in document):
             document["_id"] = ObjectId()
         self.ops.append((_INSERT, document))
+        return True
 
     def add_update(
         self,
@@ -146,7 +149,7 @@ def add_update(
         array_filters: Optional[list[Mapping[str, Any]]] = None,
         hint: Union[str, dict[str, Any], None] = None,
         sort: Optional[Mapping[str, Any]] = None,
-    ) -> None:
+    ) -> bool:
         """Create an update document and add it to the list of ops."""
         validate_ok_for_update(update)
         cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi}
@@ -164,10 +167,12 @@ def add_update(
         if sort is not None:
             self.uses_sort = True
             cmd["sort"] = sort
+
+        self.ops.append((_UPDATE, cmd))
         if multi:
             # A bulk_write containing an update_many is not retryable.
-            self.is_retryable = False
-        self.ops.append((_UPDATE, cmd))
+            return False
+        return True
 
     def add_replace(
         self,
@@ -177,7 +182,7 @@ def add_replace(
         collation: Optional[Mapping[str, Any]] = None,
         hint: Union[str, dict[str, Any], None] = None,
         sort: Optional[Mapping[str, Any]] = None,
-    ) -> None:
+    ) -> bool:
         """Create a replace document and add it to the list of ops."""
         validate_ok_for_replace(replacement)
         cmd: dict[str, Any] = {"q": selector, "u": replacement}
@@ -193,6 +198,7 @@ def add_replace(
             self.uses_sort = True
             cmd["sort"] = sort
         self.ops.append((_UPDATE, cmd))
+        return True
 
     def add_delete(
         self,
@@ -200,7 +206,7 @@ def add_delete(
         limit: int,
         collation: Optional[Mapping[str, Any]] = None,
         hint: Union[str, dict[str, Any], None] = None,
-    ) -> None:
+    ) -> bool:
         """Create a delete document and add it to the list of ops."""
         cmd: dict[str, Any] = {"q": selector, "limit": limit}
         if collation is not None:
@@ -209,33 +215,63 @@ def add_delete(
         if hint is not None:
             self.uses_hint_delete = True
             cmd["hint"] = hint
+
+        self.ops.append((_DELETE, cmd))
         if limit == _DELETE_ALL:
             # A bulk_write containing a delete_many is not retryable.
-            self.is_retryable = False
-        self.ops.append((_DELETE, cmd))
+            return False
+        return True
 
-    def gen_ordered(self) -> Iterator[Optional[_Run]]:
+    def gen_ordered(
+        self,
+        requests: Iterable[Any],
+        process: Union[
+            Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool]
+        ],
+    ) -> Iterator[_Run]:
         """Generate batches of operations, batched by type of
         operation, in the order **provided**.
""" run = None - for idx, (op_type, operation) in enumerate(self.ops): + ctr = 0 + for idx, request in enumerate(requests): + retryable = process(request) + (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) - elif run.op_type != op_type: + elif run.op_type != op_type or ctr >= common.MAX_WRITE_BATCH_SIZE // 200: yield run + ctr = 0 run = _Run(op_type) + ctr += 1 run.add(idx, operation) + run.is_retryable = run.is_retryable and retryable + if run is None: + raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self) -> Iterator[_Run]: + def gen_unordered( + self, + requests: Iterable[Any], + process: Union[ + Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool] + ], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. """ operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + retryable = process(request) + (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - + operations[op_type].is_retryable = operations[op_type].is_retryable and retryable + if ( + len(operations[_INSERT].ops) == 0 + and len(operations[_UPDATE].ops) == 0 + and len(operations[_DELETE].ops) == 0 + ): + raise InvalidOperation("No operations to execute") for run in operations: if run.ops: yield run @@ -472,6 +508,7 @@ async def _execute_command( op_id: int, retryable: bool, full_result: MutableMapping[str, Any], + validate: bool, final_write_concern: Optional[WriteConcern] = None, ) -> None: db_name = self.collection.database.name @@ -489,6 +526,7 @@ async def _execute_command( last_run = False while run: + self.is_retryable = run.is_retryable if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: @@ -523,10 +561,12 @@ async def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if retryable and self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to( + cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn + ) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -534,6 +574,8 @@ async def _execute_command( ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. 
+                if validate:
+                    await self.validate_batch(conn, write_concern)
                 if write_concern.acknowledged:
                     result, to_send = await self._execute_batch(bwc, cmd, ops, client)
 
@@ -565,6 +607,9 @@ async def _execute_command(
                 break
             # Reset our state
             self.current_run = run = self.next_run
+            import gc
+
+            gc.collect()
 
     async def execute_command(
         self,
@@ -598,6 +643,7 @@ async def retryable_bulk(
                 op_id,
                 retryable,
                 full_result,
+                validate=False,
             )
 
         client = self.collection.database.client
@@ -615,7 +661,7 @@ async def retryable_bulk(
         return full_result
 
     async def execute_op_msg_no_results(
-        self, conn: AsyncConnection, generator: Iterator[Any]
+        self, conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern
     ) -> None:
         """Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
         db_name = self.collection.database.name
@@ -649,6 +695,7 @@ async def execute_op_msg_no_results(
             conn.add_server_api(cmd)
             ops = islice(run.ops, run.idx_offset, None)
             # Run as many ops as possible.
+            await self.validate_batch(conn, write_concern)
             to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
             run.idx_offset += len(to_send)
             self.current_run = run = next(generator, None)
@@ -684,10 +731,14 @@ async def execute_command_no_results(
                 op_id,
                 False,
                 full_result,
+                True,
                 write_concern,
             )
-        except OperationFailure:
-            pass
+        except OperationFailure as exc:
+            if "Cannot set bypass_document_validation with unacknowledged write concern" in str(
+                exc
+            ):
+                raise exc
 
     async def execute_no_results(
         self,
@@ -696,6 +747,11 @@ async def execute_no_results(
         write_concern: WriteConcern,
     ) -> None:
         """Execute all operations, returning no results (w=0)."""
+        if self.ordered:
+            return await self.execute_command_no_results(conn, generator, write_concern)
+        return await self.execute_op_msg_no_results(conn, generator, write_concern)
+
+    async def validate_batch(self, conn: AsyncConnection, write_concern: WriteConcern) -> None:
         if self.uses_collation:
             raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
         if self.uses_array_filters:
@@ -720,19 +776,17 @@ async def execute_no_results(
                 "Cannot set bypass_document_validation with unacknowledged write concern"
             )
 
-        if self.ordered:
-            return await self.execute_command_no_results(conn, generator, write_concern)
-        return await self.execute_op_msg_no_results(conn, generator)
-
     async def execute(
         self,
+        generator: Iterable[Any],
+        process: Union[
+            Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool]
+        ],
         write_concern: WriteConcern,
         session: Optional[AsyncClientSession],
         operation: str,
     ) -> Any:
         """Execute operations."""
-        if not self.ops:
-            raise InvalidOperation("No operations to execute")
         if self.executed:
             raise InvalidOperation("Bulk operations can only be executed once.")
         self.executed = True
@@ -740,9 +794,9 @@ async def execute(
         session = _validate_session_write_concern(session, write_concern)
 
         if self.ordered:
-            generator = self.gen_ordered()
+            generator = self.gen_ordered(generator, process)
         else:
-            generator = self.gen_unordered()
+            generator = self.gen_unordered(generator, process)
 
         client = self.collection.database.client
         if not write_concern.acknowledged:
diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py
index 5f7ac013e9..d55f8351b8 100644
--- a/pymongo/asynchronous/client_bulk.py
+++ b/pymongo/asynchronous/client_bulk.py
@@ -116,6 +116,7 @@ def __init__(
         self.is_retryable = self.client.options.retry_writes
         self.retrying = False
         self.started_retryable_write = False
+        self.current_run = None
 
     @property
     def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py
index 7fb20b7ab3..32811663c9 100644
--- a/pymongo/asynchronous/collection.py
+++ b/pymongo/asynchronous/collection.py
@@ -699,7 +699,7 @@ async def _create(
     @_csot.apply
     async def bulk_write(
         self,
-        requests: Sequence[_WriteOp[_DocumentType]],
+        requests: Iterable[_WriteOp],
         ordered: bool = True,
         bypass_document_validation: Optional[bool] = None,
         session: Optional[AsyncClientSession] = None,
@@ -779,17 +779,21 @@ async def bulk_write(
 
         .. versionadded:: 3.0
         """
-        common.validate_list("requests", requests)
+        common.validate_list_or_generator("requests", requests)
         blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let)
-        for request in requests:
+
+        write_concern = self._write_concern_for(session)
+
+        def process_for_bulk(request: _WriteOp) -> bool:
             try:
-                request._add_to_bulk(blk)
+                return request._add_to_bulk(blk)
             except AttributeError:
                 raise TypeError(f"{request!r} is not a valid request") from None
 
-        write_concern = self._write_concern_for(session)
-        bulk_api_result = await blk.execute(write_concern, session, _Op.INSERT)
+        bulk_api_result = await blk.execute(
+            requests, process_for_bulk, write_concern, session, _Op.INSERT
+        )
         if bulk_api_result is not None:
             return BulkWriteResult(bulk_api_result, True)
         return BulkWriteResult({}, False)
@@ -960,20 +964,19 @@ async def insert_many(
             raise TypeError("documents must be a non-empty list")
         inserted_ids: list[ObjectId] = []
 
-        def gen() -> Iterator[tuple[int, Mapping[str, Any]]]:
-            """A generator that validates documents and handles _ids."""
-            for document in documents:
-                common.validate_is_document_type("document", document)
-                if not isinstance(document, RawBSONDocument):
-                    if "_id" not in document:
-                        document["_id"] = ObjectId()  # type: ignore[index]
-                    inserted_ids.append(document["_id"])
-                yield (message._INSERT, document)
+        def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool:
+            """Validate a document and ensure it has an _id."""
+            common.validate_is_document_type("document", document)
+            if not isinstance(document, RawBSONDocument):
+                if "_id" not in document:
+                    document["_id"] = ObjectId()  # type: ignore[index]
+                inserted_ids.append(document["_id"])
+            blk.ops.append((message._INSERT, document))
+            return True
 
         write_concern = self._write_concern_for(session)
         blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment)
-        blk.ops = list(gen())
-        await blk.execute(write_concern, session, _Op.INSERT)
+        await blk.execute(documents, process_for_bulk, write_concern, session, _Op.INSERT)
         return InsertManyResult(inserted_ids, write_concern.acknowledged)
 
     async def _update(
diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py
index a236b21348..e79e6df475 100644
--- a/pymongo/asynchronous/mongo_client.py
+++ b/pymongo/asynchronous/mongo_client.py
@@ -2723,6 +2723,11 @@ def __init__(
         self._operation_id = operation_id
         self._attempt_number = 0
 
+    def _bulk_retryable(self) -> bool:
+        if self._bulk is not None:
+            return self._bulk.is_retryable
+        return True
+
     async def run(self) -> T:
         """Runs the supplied func() and attempts a retry
 
@@ -2733,11 +2738,15 @@ async def run(self) -> T:
         # Increment the transaction id up front to ensure any retry attempt
         # will use the proper txnNumber, even if server or socket selection
         # fails before the command can be sent.
-        if self._is_session_state_retryable() and self._retryable and not self._is_read:
+        if (
+            self._is_session_state_retryable()
+            and self._retryable
+            and self._bulk_retryable()
+            and not self._is_read
+        ):
             self._session._start_retryable_write()  # type: ignore
             if self._bulk:
                 self._bulk.started_retryable_write = True
-
         while True:
             self._check_last_error(check_csot=True)
             try:
@@ -2767,10 +2776,9 @@ async def run(self) -> T:
                         self._attempt_number += 1
                     else:
                         raise
-
                 # Specialized catch on write operation
                 if not self._is_read:
-                    if not self._retryable:
+                    if not self._retryable and not self._bulk_retryable():
                         raise
                     if isinstance(exc, ClientBulkWriteException) and exc.error:
                         retryable_write_error_exc = isinstance(
@@ -2801,11 +2809,15 @@ async def run(self) -> T:
 
     def _is_not_eligible_for_retry(self) -> bool:
         """Checks if the exchange is not eligible for retry"""
-        return not self._retryable or (self._is_retrying() and not self._multiple_retries)
+        return (
+            not self._retryable
+            or not self._bulk_retryable()
+            or (self._is_retrying() and not self._multiple_retries)
+        )
 
     def _is_retrying(self) -> bool:
         """Checks if the exchange is currently undergoing a retry"""
-        return self._bulk.retrying if self._bulk else self._retrying
+        return self._bulk.retrying or self._retrying if self._bulk is not None else self._retrying
 
     def _is_session_state_retryable(self) -> bool:
         """Checks if provided session is eligible for retry
@@ -2865,6 +2877,8 @@ async def _write(self) -> T:
                     # not support sessions raise the last error.
                     self._check_last_error()
                     self._retryable = False
+                    if self._bulk:
+                        self._bulk.is_retryable = False
                 if self._retrying:
                     _debug_log(
                         _COMMAND_LOGGER,
@@ -2875,7 +2889,7 @@ async def _write(self) -> T:
                     )
                 return await self._func(self._session, conn, self._retryable)  # type: ignore
             except PyMongoError as exc:
-                if not self._retryable:
+                if not self._retryable or not self._bulk_retryable():
                     raise
                 # Add the RetryableWriteError label, if applicable.
                 _add_retryable_write_error(exc, max_wire_version, is_mongos)
@@ -2892,7 +2906,7 @@ async def _read(self) -> T:
                 conn,
                 read_pref,
             ):
-                if self._retrying and not self._retryable:
+                if self._retrying and (not self._retryable or not self._bulk_retryable()):
                     self._check_last_error()
                 if self._retrying:
                     _debug_log(
diff --git a/pymongo/bulk_shared.py b/pymongo/bulk_shared.py
index 9276419d8a..4d1b2fe7cd 100644
--- a/pymongo/bulk_shared.py
+++ b/pymongo/bulk_shared.py
@@ -50,6 +50,8 @@ def __init__(self, op_type: int) -> None:
         self.index_map: list[int] = []
         self.ops: list[Any] = []
         self.idx_offset: int = 0
+        self.is_retryable = True
+        self.retrying = False
 
     def index(self, idx: int) -> int:
         """Get the original index of an operation in this run.
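A note on the per-run retryability added above, with a minimal usage sketch (illustrative only: the client, namespace, and documents are invented, a reachable local mongod is assumed, and this snippet is not part of the patch). With _Run.is_retryable, only the run containing a multi-document operation loses retryability, where previously a single UpdateMany or DeleteMany disabled retries for the entire bulk write:

# Sketch: per-run retryability under this change.
from pymongo import DeleteMany, InsertOne, MongoClient

coll = MongoClient().test.test  # assumed local server and namespace
requests = [
    InsertOne({"x": 1}),            # run of inserts: retryable
    DeleteMany({"x": {"$lt": 0}}),  # delete_many: this run is not retryable
    InsertOne({"x": 2}),            # new insert run: retryable again
]
# With ordered=True, gen_ordered() splits these into three runs by op type;
# each run now carries its own is_retryable flag instead of one global flag.
coll.bulk_write(requests, ordered=True)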
diff --git a/pymongo/common.py b/pymongo/common.py
index 3d8095eedf..6d9bb2f37a 100644
--- a/pymongo/common.py
+++ b/pymongo/common.py
@@ -24,6 +24,7 @@
     TYPE_CHECKING,
     Any,
     Callable,
+    Generator,
     Iterator,
     Mapping,
     MutableMapping,
@@ -530,6 +531,13 @@ def validate_list(option: str, value: Any) -> list:
     return value
 
 
+def validate_list_or_generator(option: str, value: Any) -> Union[list, Generator]:
+    """Validates that 'value' is a list or generator."""
+    if isinstance(value, Generator):
+        return value
+    return validate_list(option, value)
+
+
 def validate_list_or_none(option: Any, value: Any) -> Optional[list]:
     """Validates that 'value' is a list or None."""
     if value is None:
diff --git a/pymongo/operations.py b/pymongo/operations.py
index 300f1ba123..49b41ee614 100644
--- a/pymongo/operations.py
+++ b/pymongo/operations.py
@@ -106,9 +106,9 @@ def __init__(self, document: _DocumentType, namespace: Optional[str] = None) ->
         self._doc = document
         self._namespace = namespace
 
-    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
+    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool:
         """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
-        bulkobj.add_insert(self._doc)  # type: ignore[arg-type]
+        return bulkobj.add_insert(self._doc)  # type: ignore[arg-type]
 
     def _add_to_client_bulk(self, bulkobj: _AgnosticClientBulk) -> None:
         """Add this operation to the _AsyncClientBulk/_ClientBulk instance `bulkobj`."""
@@ -230,9 +230,9 @@ def __init__(
         """
         super().__init__(filter, collation, hint, namespace)
 
-    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
+    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool:
         """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
-        bulkobj.add_delete(
+        return bulkobj.add_delete(
             self._filter,
             1,
             collation=validate_collation_or_none(self._collation),
@@ -291,9 +291,9 @@ def __init__(
         """
         super().__init__(filter, collation, hint, namespace)
 
-    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
+    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool:
         """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
-        bulkobj.add_delete(
+        return bulkobj.add_delete(
             self._filter,
             0,
             collation=validate_collation_or_none(self._collation),
@@ -384,9 +384,9 @@ def __init__(
         self._collation = collation
         self._namespace = namespace
 
-    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
+    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool:
         """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
-        bulkobj.add_replace(
+        return bulkobj.add_replace(
             self._filter,
             self._doc,
             self._upsert,
@@ -606,9 +606,9 @@ def __init__(
         """
         super().__init__(filter, update, upsert, collation, array_filters, hint, namespace, sort)
 
-    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
+    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool:
         """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
-        bulkobj.add_update(
+        return bulkobj.add_update(
             self._filter,
             self._doc,
             False,
@@ -687,9 +687,9 @@ def __init__(
         """
         super().__init__(filter, update, upsert, collation, array_filters, hint, namespace, None)
 
-    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
+    def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> bool:
         """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
-        bulkobj.add_update(
+        return bulkobj.add_update(
             self._filter,
             self._doc,
             True,
diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py
index a528b09add..edbeebbf35 100644
--- a/pymongo/synchronous/bulk.py
+++ b/pymongo/synchronous/bulk.py
@@ -26,6 +26,8 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
+    Iterable,
     Iterator,
     Mapping,
     Optional,
@@ -72,7 +74,7 @@ from pymongo.write_concern import WriteConcern
 
 if TYPE_CHECKING:
-    from pymongo.synchronous.collection import Collection
+    from pymongo.synchronous.collection import Collection, _WriteOp
     from pymongo.synchronous.mongo_client import MongoClient
     from pymongo.synchronous.pool import Connection
     from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
@@ -128,13 +130,14 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
             self.is_encrypted = False
         return _BulkWriteContext
 
-    def add_insert(self, document: _DocumentOut) -> None:
+    def add_insert(self, document: _DocumentOut) -> bool:
         """Add an insert document to the list of ops."""
         validate_is_document_type("document", document)
         # Generate ObjectId client side.
         if not (isinstance(document, RawBSONDocument) or "_id" in document):
             document["_id"] = ObjectId()
         self.ops.append((_INSERT, document))
+        return True
 
     def add_update(
         self,
@@ -146,7 +149,7 @@ def add_update(
         array_filters: Optional[list[Mapping[str, Any]]] = None,
         hint: Union[str, dict[str, Any], None] = None,
         sort: Optional[Mapping[str, Any]] = None,
-    ) -> None:
+    ) -> bool:
         """Create an update document and add it to the list of ops."""
         validate_ok_for_update(update)
         cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi}
@@ -164,10 +167,12 @@ def add_update(
         if sort is not None:
             self.uses_sort = True
             cmd["sort"] = sort
+
+        self.ops.append((_UPDATE, cmd))
         if multi:
             # A bulk_write containing an update_many is not retryable.
-            self.is_retryable = False
-        self.ops.append((_UPDATE, cmd))
+            return False
+        return True
 
     def add_replace(
         self,
@@ -177,7 +182,7 @@ def add_replace(
         collation: Optional[Mapping[str, Any]] = None,
         hint: Union[str, dict[str, Any], None] = None,
         sort: Optional[Mapping[str, Any]] = None,
-    ) -> None:
+    ) -> bool:
         """Create a replace document and add it to the list of ops."""
         validate_ok_for_replace(replacement)
         cmd: dict[str, Any] = {"q": selector, "u": replacement}
@@ -193,6 +198,7 @@ def add_replace(
             self.uses_sort = True
             cmd["sort"] = sort
         self.ops.append((_UPDATE, cmd))
+        return True
 
     def add_delete(
         self,
@@ -200,7 +206,7 @@ def add_delete(
         limit: int,
         collation: Optional[Mapping[str, Any]] = None,
         hint: Union[str, dict[str, Any], None] = None,
-    ) -> None:
+    ) -> bool:
         """Create a delete document and add it to the list of ops."""
         cmd: dict[str, Any] = {"q": selector, "limit": limit}
         if collation is not None:
@@ -209,33 +215,63 @@ def add_delete(
         if hint is not None:
             self.uses_hint_delete = True
             cmd["hint"] = hint
+
+        self.ops.append((_DELETE, cmd))
         if limit == _DELETE_ALL:
             # A bulk_write containing a delete_many is not retryable.
-            self.is_retryable = False
-        self.ops.append((_DELETE, cmd))
+            return False
+        return True
 
-    def gen_ordered(self) -> Iterator[Optional[_Run]]:
+    def gen_ordered(
+        self,
+        requests: Iterable[Any],
+        process: Union[
+            Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool]
+        ],
+    ) -> Iterator[_Run]:
         """Generate batches of operations, batched by type of
         operation, in the order **provided**.
""" run = None - for idx, (op_type, operation) in enumerate(self.ops): + ctr = 0 + for idx, request in enumerate(requests): + retryable = process(request) + (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) - elif run.op_type != op_type: + elif run.op_type != op_type or ctr >= common.MAX_WRITE_BATCH_SIZE // 200: yield run + ctr = 0 run = _Run(op_type) + ctr += 1 run.add(idx, operation) + run.is_retryable = run.is_retryable and retryable + if run is None: + raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self) -> Iterator[_Run]: + def gen_unordered( + self, + requests: Iterable[Any], + process: Union[ + Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool] + ], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. """ operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + retryable = process(request) + (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - + operations[op_type].is_retryable = operations[op_type].is_retryable and retryable + if ( + len(operations[_INSERT].ops) == 0 + and len(operations[_UPDATE].ops) == 0 + and len(operations[_DELETE].ops) == 0 + ): + raise InvalidOperation("No operations to execute") for run in operations: if run.ops: yield run @@ -472,6 +508,7 @@ def _execute_command( op_id: int, retryable: bool, full_result: MutableMapping[str, Any], + validate: bool, final_write_concern: Optional[WriteConcern] = None, ) -> None: db_name = self.collection.database.name @@ -489,6 +526,7 @@ def _execute_command( last_run = False while run: + self.is_retryable = run.is_retryable if not self.retrying: self.next_run = next(generator, None) if self.next_run is None: @@ -523,10 +561,12 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if retryable and self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to( + cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn + ) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -534,6 +574,8 @@ def _execute_command( ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. 
+                if validate:
+                    self.validate_batch(conn, write_concern)
                 if write_concern.acknowledged:
                     result, to_send = self._execute_batch(bwc, cmd, ops, client)
 
@@ -565,6 +607,9 @@ def _execute_command(
                 break
             # Reset our state
             self.current_run = run = self.next_run
+            import gc
+
+            gc.collect()
 
     def execute_command(
         self,
@@ -598,6 +643,7 @@ def retryable_bulk(
                 op_id,
                 retryable,
                 full_result,
+                validate=False,
             )
 
         client = self.collection.database.client
@@ -614,7 +660,9 @@ def retryable_bulk(
         _raise_bulk_write_error(full_result)
         return full_result
 
-    def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None:
+    def execute_op_msg_no_results(
+        self, conn: Connection, generator: Iterator[Any], write_concern: WriteConcern
+    ) -> None:
         """Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
         db_name = self.collection.database.name
         client = self.collection.database.client
@@ -647,6 +695,7 @@ def execute_op_msg_no_results(
             conn.add_server_api(cmd)
             ops = islice(run.ops, run.idx_offset, None)
             # Run as many ops as possible.
+            self.validate_batch(conn, write_concern)
             to_send = self._execute_batch_unack(bwc, cmd, ops, client)
             run.idx_offset += len(to_send)
             self.current_run = run = next(generator, None)
@@ -682,10 +731,14 @@ def execute_command_no_results(
                 op_id,
                 False,
                 full_result,
+                True,
                 write_concern,
             )
-        except OperationFailure:
-            pass
+        except OperationFailure as exc:
+            if "Cannot set bypass_document_validation with unacknowledged write concern" in str(
+                exc
+            ):
+                raise exc
 
     def execute_no_results(
         self,
@@ -694,6 +747,11 @@ def execute_no_results(
         write_concern: WriteConcern,
     ) -> None:
         """Execute all operations, returning no results (w=0)."""
+        if self.ordered:
+            return self.execute_command_no_results(conn, generator, write_concern)
+        return self.execute_op_msg_no_results(conn, generator, write_concern)
+
+    def validate_batch(self, conn: Connection, write_concern: WriteConcern) -> None:
         if self.uses_collation:
             raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
         if self.uses_array_filters:
@@ -718,19 +776,17 @@ def execute_no_results(
                 "Cannot set bypass_document_validation with unacknowledged write concern"
             )
 
-        if self.ordered:
-            return self.execute_command_no_results(conn, generator, write_concern)
-        return self.execute_op_msg_no_results(conn, generator)
-
     def execute(
         self,
+        generator: Iterable[Any],
+        process: Union[
+            Callable[[_WriteOp], bool], Callable[[Union[_DocumentType, RawBSONDocument]], bool]
+        ],
         write_concern: WriteConcern,
         session: Optional[ClientSession],
         operation: str,
     ) -> Any:
         """Execute operations."""
-        if not self.ops:
-            raise InvalidOperation("No operations to execute")
         if self.executed:
             raise InvalidOperation("Bulk operations can only be executed once.")
         self.executed = True
@@ -738,9 +794,9 @@ def execute(
         session = _validate_session_write_concern(session, write_concern)
 
         if self.ordered:
-            generator = self.gen_ordered()
+            generator = self.gen_ordered(generator, process)
         else:
-            generator = self.gen_unordered()
+            generator = self.gen_unordered(generator, process)
 
         client = self.collection.database.client
         if not write_concern.acknowledged:
diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py
index d73bfb2a2b..06d4ca8872 100644
--- a/pymongo/synchronous/client_bulk.py
+++ b/pymongo/synchronous/client_bulk.py
@@ -116,6 +116,7 @@ def __init__(
         self.is_retryable = self.client.options.retry_writes
         self.retrying = False
         self.started_retryable_write = False
+        self.current_run = None
 
     @property
     def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py
index 8a71768318..2a91412e33 100644
--- a/pymongo/synchronous/collection.py
+++ b/pymongo/synchronous/collection.py
@@ -698,7 +698,7 @@ def _create(
     @_csot.apply
     def bulk_write(
         self,
-        requests: Sequence[_WriteOp[_DocumentType]],
+        requests: Iterable[_WriteOp],
         ordered: bool = True,
         bypass_document_validation: Optional[bool] = None,
         session: Optional[ClientSession] = None,
@@ -778,17 +778,21 @@ def bulk_write(
 
         .. versionadded:: 3.0
         """
-        common.validate_list("requests", requests)
+        common.validate_list_or_generator("requests", requests)
         blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let)
-        for request in requests:
+
+        write_concern = self._write_concern_for(session)
+
+        def process_for_bulk(request: _WriteOp) -> bool:
             try:
-                request._add_to_bulk(blk)
+                return request._add_to_bulk(blk)
             except AttributeError:
                 raise TypeError(f"{request!r} is not a valid request") from None
 
-        write_concern = self._write_concern_for(session)
-        bulk_api_result = blk.execute(write_concern, session, _Op.INSERT)
+        bulk_api_result = blk.execute(
+            requests, process_for_bulk, write_concern, session, _Op.INSERT
+        )
         if bulk_api_result is not None:
             return BulkWriteResult(bulk_api_result, True)
         return BulkWriteResult({}, False)
@@ -959,20 +963,19 @@ def insert_many(
             raise TypeError("documents must be a non-empty list")
         inserted_ids: list[ObjectId] = []
 
-        def gen() -> Iterator[tuple[int, Mapping[str, Any]]]:
-            """A generator that validates documents and handles _ids."""
-            for document in documents:
-                common.validate_is_document_type("document", document)
-                if not isinstance(document, RawBSONDocument):
-                    if "_id" not in document:
-                        document["_id"] = ObjectId()  # type: ignore[index]
-                    inserted_ids.append(document["_id"])
-                yield (message._INSERT, document)
+        def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool:
+            """Validate a document and ensure it has an _id."""
+            common.validate_is_document_type("document", document)
+            if not isinstance(document, RawBSONDocument):
+                if "_id" not in document:
+                    document["_id"] = ObjectId()  # type: ignore[index]
+                inserted_ids.append(document["_id"])
+            blk.ops.append((message._INSERT, document))
+            return True
 
         write_concern = self._write_concern_for(session)
         blk = _Bulk(self, ordered, bypass_document_validation, comment=comment)
-        blk.ops = list(gen())
-        blk.execute(write_concern, session, _Op.INSERT)
+        blk.execute(documents, process_for_bulk, write_concern, session, _Op.INSERT)
         return InsertManyResult(inserted_ids, write_concern.acknowledged)
 
     def _update(
diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py
index 99a517e5c1..829e66aff6 100644
--- a/pymongo/synchronous/mongo_client.py
+++ b/pymongo/synchronous/mongo_client.py
@@ -2709,6 +2709,11 @@ def __init__(
         self._operation_id = operation_id
         self._attempt_number = 0
 
+    def _bulk_retryable(self) -> bool:
+        if self._bulk is not None:
+            return self._bulk.is_retryable
+        return True
+
     def run(self) -> T:
         """Runs the supplied func() and attempts a retry
 
@@ -2719,11 +2724,15 @@ def run(self) -> T:
         # Increment the transaction id up front to ensure any retry attempt
         # will use the proper txnNumber, even if server or socket selection
         # fails before the command can be sent.
-        if self._is_session_state_retryable() and self._retryable and not self._is_read:
+        if (
+            self._is_session_state_retryable()
+            and self._retryable
+            and self._bulk_retryable()
+            and not self._is_read
+        ):
            self._session._start_retryable_write()  # type: ignore
             if self._bulk:
                 self._bulk.started_retryable_write = True
-
         while True:
             self._check_last_error(check_csot=True)
             try:
@@ -2753,10 +2762,9 @@ def run(self) -> T:
                         self._attempt_number += 1
                     else:
                         raise
-
                 # Specialized catch on write operation
                 if not self._is_read:
-                    if not self._retryable:
+                    if not self._retryable and not self._bulk_retryable():
                         raise
                     if isinstance(exc, ClientBulkWriteException) and exc.error:
                         retryable_write_error_exc = isinstance(
@@ -2787,11 +2795,15 @@ def run(self) -> T:
 
     def _is_not_eligible_for_retry(self) -> bool:
         """Checks if the exchange is not eligible for retry"""
-        return not self._retryable or (self._is_retrying() and not self._multiple_retries)
+        return (
+            not self._retryable
+            or not self._bulk_retryable()
+            or (self._is_retrying() and not self._multiple_retries)
+        )
 
     def _is_retrying(self) -> bool:
         """Checks if the exchange is currently undergoing a retry"""
-        return self._bulk.retrying if self._bulk else self._retrying
+        return self._bulk.retrying or self._retrying if self._bulk is not None else self._retrying
 
     def _is_session_state_retryable(self) -> bool:
         """Checks if provided session is eligible for retry
@@ -2851,6 +2863,8 @@ def _write(self) -> T:
                     # not support sessions raise the last error.
                     self._check_last_error()
                     self._retryable = False
+                    if self._bulk:
+                        self._bulk.is_retryable = False
                 if self._retrying:
                     _debug_log(
                         _COMMAND_LOGGER,
@@ -2861,7 +2875,7 @@ def _write(self) -> T:
                     )
                 return self._func(self._session, conn, self._retryable)  # type: ignore
             except PyMongoError as exc:
-                if not self._retryable:
+                if not self._retryable or not self._bulk_retryable():
                     raise
                 # Add the RetryableWriteError label, if applicable.
                 _add_retryable_write_error(exc, max_wire_version, is_mongos)
@@ -2878,7 +2892,7 @@ def _read(self) -> T:
                 conn,
                 read_pref,
             ):
-                if self._retrying and not self._retryable:
+                if self._retrying and (not self._retryable or not self._bulk_retryable()):
                     self._check_last_error()
                 if self._retrying:
                     _debug_log(
diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py
index 65ed6e236a..c4c1ed2cae 100644
--- a/test/asynchronous/test_bulk.py
+++ b/test/asynchronous/test_bulk.py
@@ -299,6 +299,29 @@ async def test_numerous_inserts(self):
         self.assertEqual(n_docs, result.inserted_count)
         self.assertEqual(n_docs, await self.coll.count_documents({}))
 
+    async def test_numerous_inserts_generator(self):
+        # Ensure we don't exceed the server's maxWriteBatchSize limit.
+        n_docs = await async_client_context.max_write_batch_size + 100
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = await self.coll.bulk_write(requests, ordered=False)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, await self.coll.count_documents({}))
+
+        # Same with ordered bulk.
+        await self.coll.drop()
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = await self.coll.bulk_write(requests)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, await self.coll.count_documents({}))
+
+    async def test_huge_inserts_generator(self):
+        # Ensure we don't exceed the server's max message size limit.
+        n_docs = 1000000
+        requests = (InsertOne({"x": "large" * 1024 * 1024}) for _ in range(n_docs))
+        result = await self.coll.bulk_write(requests)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, await self.coll.count_documents({}))
+
     async def test_bulk_max_message_size(self):
         await self.coll.delete_many({})
         self.addAsyncCleanup(self.coll.delete_many, {})
@@ -338,11 +361,6 @@ async def test_bulk_write_no_results(self):
         self.assertRaises(InvalidOperation, lambda: result.upserted_ids)
 
     async def test_bulk_write_invalid_arguments(self):
-        # The requests argument must be a list.
-        generator = (InsertOne[dict]({}) for _ in range(10))
-        with self.assertRaises(TypeError):
-            await self.coll.bulk_write(generator)  # type: ignore[arg-type]
-
         # Document is not wrapped in a bulk write operation.
         with self.assertRaises(TypeError):
             await self.coll.bulk_write([{}])  # type: ignore[list-item]
diff --git a/test/test_bulk.py b/test/test_bulk.py
index 8a863cc49b..3c8fb3d5fa 100644
--- a/test/test_bulk.py
+++ b/test/test_bulk.py
@@ -299,6 +299,29 @@ def test_numerous_inserts(self):
         self.assertEqual(n_docs, result.inserted_count)
         self.assertEqual(n_docs, self.coll.count_documents({}))
 
+    def test_numerous_inserts_generator(self):
+        # Ensure we don't exceed the server's maxWriteBatchSize limit.
+        n_docs = client_context.max_write_batch_size + 100
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = self.coll.bulk_write(requests, ordered=False)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, self.coll.count_documents({}))
+
+        # Same with ordered bulk.
+        self.coll.drop()
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = self.coll.bulk_write(requests)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, self.coll.count_documents({}))
+
+    def test_huge_inserts_generator(self):
+        # Ensure we don't exceed the server's max message size limit.
+        n_docs = 1000000
+        requests = (InsertOne({"x": "large" * 1024 * 1024}) for _ in range(n_docs))
+        result = self.coll.bulk_write(requests)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, self.coll.count_documents({}))
+
     def test_bulk_max_message_size(self):
         self.coll.delete_many({})
         self.addCleanup(self.coll.delete_many, {})
@@ -338,11 +361,6 @@ def test_bulk_write_no_results(self):
         self.assertRaises(InvalidOperation, lambda: result.upserted_ids)
 
     def test_bulk_write_invalid_arguments(self):
-        # The requests argument must be a list.
-        generator = (InsertOne[dict]({}) for _ in range(10))
-        with self.assertRaises(TypeError):
-            self.coll.bulk_write(generator)  # type: ignore[arg-type]
-
         # Document is not wrapped in a bulk write operation.
         with self.assertRaises(TypeError):
             self.coll.bulk_write([{}])  # type: ignore[list-item]
diff --git a/test/test_typing.py b/test/test_typing.py
index 65937020d2..f1b6a59e49 100644
--- a/test/test_typing.py
+++ b/test/test_typing.py
@@ -428,13 +428,9 @@ def test_typeddict_document_type_insertion(self) -> None:
     def test_bulk_write_document_type_insertion(self):
         client: MongoClient[MovieWithId] = MongoClient()
         coll: Collection[MovieWithId] = client.test.test
-        coll.bulk_write(
-            [InsertOne(Movie({"name": "THX-1138", "year": 1971}))]  # type:ignore[arg-type]
-        )
+        coll.bulk_write([InsertOne(Movie({"name": "THX-1138", "year": 1971}))])
         mov_dict = {"_id": ObjectId(), "name": "THX-1138", "year": 1971}
-        coll.bulk_write(
-            [InsertOne(mov_dict)]  # type:ignore[arg-type]
-        )
+        coll.bulk_write([InsertOne(mov_dict)])
         coll.bulk_write(
             [
                 InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971})  # pyright: ignore
@@ -445,13 +441,9 @@ def test_bulk_write_document_type_insertion(self):
     def test_bulk_write_document_type_replacement(self):
         client: MongoClient[MovieWithId] = MongoClient()
         coll: Collection[MovieWithId] = client.test.test
-        coll.bulk_write(
-            [ReplaceOne({}, Movie({"name": "THX-1138", "year": 1971}))]  # type:ignore[arg-type]
-        )
+        coll.bulk_write([ReplaceOne({}, Movie({"name": "THX-1138", "year": 1971}))])
         mov_dict = {"_id": ObjectId(), "name": "THX-1138", "year": 1971}
-        coll.bulk_write(
-            [ReplaceOne({}, mov_dict)]  # type:ignore[arg-type]
-        )
+        coll.bulk_write([ReplaceOne({}, mov_dict)])
         coll.bulk_write(
             [
                 ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971})  # pyright: ignore
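For reference, a usage sketch of the generator support exercised by the new tests (illustrative only: the client, namespace, and documents are invented, a reachable local mongod is assumed, and this snippet is not part of the patch). requests may now be any iterable, including a generator; operations are pulled, validated through process_for_bulk, and dispatched batch by batch as execution proceeds, rather than requiring a list to be built up front:

# Sketch: passing a generator to bulk_write.
from pymongo import InsertOne, MongoClient

coll = MongoClient().test.test  # assumed local server and namespace
result = coll.bulk_write(
    (InsertOne({"i": i}) for i in range(100_000)),  # consumed lazily, batch by batch
    ordered=False,
)
print(result.inserted_count)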