diff --git a/doc/changelog.md b/doc/changelog.md index febf5d0b7d..98ef6c0f50 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,6 +13,7 @@ Jump to: Description +- RequestBatch rewrite - Fix regression on hostlist param to DragonRunRequest - Fix dragon build logging bug - Merge core refactor into MLI feature branch diff --git a/smartsim/_core/mli/infrastructure/control/request_dispatcher.py b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py index e22a2c8f62..6dafd9b7c9 100644 --- a/smartsim/_core/mli/infrastructure/control/request_dispatcher.py +++ b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py @@ -172,7 +172,7 @@ def can_be_removed(self) -> bool: """ return self.empty() and self._disposable - def flush(self) -> list[t.Any]: + def flush(self) -> list[InferenceRequest]: """Get all requests from queue. :returns: Requests waiting to be executed @@ -334,7 +334,7 @@ def _check_callback(self, request: InferenceRequest) -> bool: :param request: The request to validate :returns: False if callback validation fails for the request, True otherwise """ - if request.callback: + if request.callback_desc: return True logger.error("No callback channel provided in request") @@ -376,9 +376,7 @@ def _on_iteration(self) -> None: tensor_bytes_list = bytes_list[1:] self._perf_timer.start_timings() - request = self._worker.deserialize_message( - request_bytes, self._callback_factory - ) + request = self._worker.deserialize_message(request_bytes) if request.has_input_meta and tensor_bytes_list: request.raw_inputs = tensor_bytes_list @@ -387,7 +385,11 @@ def _on_iteration(self) -> None: if not self._validate_request(request): exception_handler( ValueError("Error validating the request"), - request.callback, + ( + self._callback_factory(request.callback_desc) + if request.callback_desc + else None + ), None, ) self._perf_timer.measure_time("validate_request") @@ -493,10 +495,8 @@ def flush_requests(self) -> None: if queue.ready: self._perf_timer.measure_time("find_queue") try: - batch = RequestBatch( - requests=queue.flush(), - inputs=None, - model_id=queue.model_id, + batch = RequestBatch.from_requests( + queue.flush(), queue.model_id ) finally: self._perf_timer.measure_time("flush_requests") @@ -528,9 +528,6 @@ def flush_requests(self) -> None: self._perf_timer.measure_time("transform_input") batch.inputs = transformed_inputs - for request in batch.requests: - request.raw_inputs = [] - request.input_meta = [] try: self._outgoing_queue.put(batch) diff --git a/smartsim/_core/mli/infrastructure/control/worker_manager.py b/smartsim/_core/mli/infrastructure/control/worker_manager.py index bf6fddb81d..1c93276074 100644 --- a/smartsim/_core/mli/infrastructure/control/worker_manager.py +++ b/smartsim/_core/mli/infrastructure/control/worker_manager.py @@ -132,8 +132,10 @@ def _check_feature_stores(self, batch: RequestBatch) -> bool: fs_model: t.Set[str] = set() if batch.model_id.key: fs_model = {batch.model_id.descriptor} - fs_inputs = {key.descriptor for key in batch.input_keys} - fs_outputs = {key.descriptor for key in batch.output_keys} + fs_inputs = {key.descriptor for keys in batch.input_keys for key in keys} + fs_outputs = { + key.descriptor for keys in batch.output_key_refs.values() for key in keys + } # identify which feature stores are requested and unknown fs_desired = fs_model.union(fs_inputs).union(fs_outputs) @@ -159,7 +161,7 @@ def _validate_batch(self, batch: RequestBatch) -> bool: :param batch: The batch of requests to validate :returns: False if the request fails 
any validation checks, True otherwise """ - if batch is None or not batch.has_valid_requests: + if batch is None or not batch.has_callbacks: return False return self._check_feature_stores(batch) @@ -188,7 +190,7 @@ def _on_iteration(self) -> None: return if not self._device_manager: - for request in batch.requests: + for callback_desc in batch.callback_descriptors: msg = "No Device Manager found. WorkerManager._on_start() " "must be called after initialization. If possible, " "you should use `WorkerManager.execute()` instead of " @@ -200,7 +202,7 @@ def _on_iteration(self) -> None: "and will not be processed." exception_handler( RuntimeError(msg), - request.callback, + self._callback_factory(callback_desc), "Error acquiring device manager", ) return @@ -212,10 +214,10 @@ def _on_iteration(self) -> None: feature_stores=self._feature_stores, ) except Exception as exc: - for request in batch.requests: + for callback_desc in batch.callback_descriptors: exception_handler( exc, - request.callback, + self._callback_factory(callback_desc), "Error loading model on device or getting device.", ) return @@ -226,18 +228,20 @@ def _on_iteration(self) -> None: try: model_result = LoadModelResult(device.get_model(batch.model_id.key)) except Exception as exc: - for request in batch.requests: + for callback_desc in batch.callback_descriptors: exception_handler( - exc, request.callback, "Error getting model from device." + exc, + self._callback_factory(callback_desc), + "Error getting model from device.", ) return self._perf_timer.measure_time("load_model") if not batch.inputs: - for request in batch.requests: + for callback_desc in batch.callback_descriptors: exception_handler( ValueError("Error batching inputs"), - request.callback, + self._callback_factory(callback_desc), None, ) return @@ -248,8 +252,12 @@ def _on_iteration(self) -> None: batch, model_result, transformed_input, device.name ) except Exception as e: - for request in batch.requests: - exception_handler(e, request.callback, "Error while executing.") + for callback_desc in batch.callback_descriptors: + exception_handler( + e, + self._callback_factory(callback_desc), + "Error while executing.", + ) return self._perf_timer.measure_time("execute") @@ -258,24 +266,35 @@ def _on_iteration(self) -> None: batch, execute_result ) except Exception as e: - for request in batch.requests: + for callback_desc in batch.callback_descriptors: exception_handler( - e, request.callback, "Error while transforming the output." + e, + self._callback_factory(callback_desc), + "Error while transforming the output.", ) return - for request, transformed_output in zip(batch.requests, transformed_outputs): + for callback_desc, transformed_output in zip( + batch.callback_descriptors, transformed_outputs + ): reply = InferenceReply() - if request.has_output_keys: + if batch.output_key_refs: try: + output_keys = batch.output_key_refs[callback_desc] reply.output_keys = self._worker.place_output( - request, + output_keys, transformed_output, self._feature_stores, ) + except KeyError: + # the callback is not in the output_key_refs dict + # because it doesn't have output_keys associated with it + continue except Exception as e: exception_handler( - e, request.callback, "Error while placing the output." 
+ e, + self._callback_factory(callback_desc), + "Error while placing the output.", ) continue else: @@ -302,12 +321,11 @@ def _on_iteration(self) -> None: self._perf_timer.measure_time("serialize_resp") - if request.callback: - request.callback.send(serialized_resp) - if reply.has_outputs: - # send tensor data after response - for output in reply.outputs: - request.callback.send(output) + callback = self._callback_factory(callback_desc) + callback.send(serialized_resp) + if reply.has_outputs: + for output in reply.outputs: + callback.send(output) self._perf_timer.measure_time("send") self._perf_timer.end_timings() diff --git a/smartsim/_core/mli/infrastructure/worker/torch_worker.py b/smartsim/_core/mli/infrastructure/worker/torch_worker.py index 64e94e5eb6..f8d6e7c2de 100644 --- a/smartsim/_core/mli/infrastructure/worker/torch_worker.py +++ b/smartsim/_core/mli/infrastructure/worker/torch_worker.py @@ -34,7 +34,6 @@ from .....error import SmartSimError from .....log import get_logger -from ...mli_schemas.tensor import tensor_capnp from .worker import ( ExecuteResult, FetchInputResult, @@ -42,6 +41,7 @@ LoadModelResult, MachineLearningWorkerBase, RequestBatch, + TensorMeta, TransformInputResult, TransformOutputResult, ) @@ -64,7 +64,8 @@ def load_model( """Given a loaded MachineLearningModel, ensure it is loaded into device memory. - :param request: The request that triggered the pipeline + :param batch: The batch that triggered the pipeline + :param fetch_result: The fetched model :param device: The device on which the model must be placed :returns: LoadModelResult wrapping the model loaded for the request :raises ValueError: If model reference object is not found @@ -97,13 +98,13 @@ def load_model( @staticmethod def transform_input( batch: RequestBatch, - fetch_results: list[FetchInputResult], + fetch_results: FetchInputResult, mem_pool: MemoryPool, ) -> TransformInputResult: """Given a collection of data, perform a transformation on the data and put the raw tensor data on a MemoryPool allocation. 
- :param request: The request that triggered the pipeline + :param batch: The batch that triggered the pipeline :param fetch_result: Raw outputs from fetching inputs out of a feature store :param mem_pool: The memory pool used to access batched input tensors :returns: The transformed inputs wrapped in a TransformInputResult @@ -116,32 +117,32 @@ def transform_input( all_dims: list[list[int]] = [] all_dtypes: list[str] = [] - if fetch_results[0].meta is None: + if fetch_results.meta is None: raise ValueError("Cannot reconstruct tensor without meta information") # Traverse inputs to get total number of samples and compute slices # Assumption: first dimension is samples, all tensors in the same input # have same number of samples # thus we only look at the first tensor for each input - for res_idx, fetch_result in enumerate(fetch_results): - if fetch_result.meta is None or any( - item_meta is None for item_meta in fetch_result.meta + for res_idx, res_meta_list in enumerate(fetch_results.meta): + if res_meta_list is None or any( + item_meta is None for item_meta in res_meta_list ): raise ValueError("Cannot reconstruct tensor without meta information") - first_tensor_desc: tensor_capnp.TensorDescriptor = fetch_result.meta[0] + first_tensor_desc: TensorMeta = res_meta_list[0] # type: ignore num_samples = first_tensor_desc.dimensions[0] slices.append(slice(total_samples, total_samples + num_samples)) total_samples = total_samples + num_samples - if res_idx == len(fetch_results) - 1: + if res_idx == len(fetch_results.meta) - 1: # For each tensor in the last input, get remaining dimensions # Assumptions: all inputs have the same number of tensors and # last N-1 dimensions match across inputs for corresponding tensors # thus: resulting array will be of size (num_samples, all_other_dims) - for item_meta in fetch_result.meta: - tensor_desc: tensor_capnp.TensorDescriptor = item_meta - tensor_dims = list(tensor_desc.dimensions) + for item_meta in res_meta_list: + tensor_desc: TensorMeta = item_meta # type: ignore + tensor_dims = tensor_desc.dimensions all_dims.append([total_samples, *tensor_dims[1:]]) - all_dtypes.append(str(tensor_desc.dataType)) + all_dtypes.append(tensor_desc.datatype) for result_tensor_idx, (dims, dtype) in enumerate(zip(all_dims, all_dtypes)): itemsize = np.empty((1), dtype=dtype).itemsize @@ -151,8 +152,8 @@ def transform_input( try: mem_view[:alloc_size] = b"".join( [ - fetch_result.inputs[result_tensor_idx] - for fetch_result in fetch_results + fetch_result[result_tensor_idx] + for fetch_result in fetch_results.inputs ] ) except IndexError as e: diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index 9556b8e438..b122a1d9ba 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -36,14 +36,13 @@ from .....error import SmartSimError from .....log import get_logger -from ...comm.channel.channel import CommChannelBase from ...message_handler import MessageHandler from ...mli_schemas.model.model_capnp import Model +from ...mli_schemas.tensor.tensor_capnp import TensorDescriptor from ..storage.feature_store import FeatureStore, ModelKey, TensorKey if t.TYPE_CHECKING: from smartsim._core.mli.mli_schemas.response.response_capnp import Status - from smartsim._core.mli.mli_schemas.tensor.tensor_capnp import TensorDescriptor logger = get_logger(__name__) @@ -57,10 +56,10 @@ class InferenceRequest: def __init__( self, model_key: t.Optional[ModelKey] = None, 
- callback: t.Optional[CommChannelBase] = None, + callback_desc: t.Optional[str] = None, raw_inputs: t.Optional[t.List[bytes]] = None, input_keys: t.Optional[t.List[TensorKey]] = None, - input_meta: t.Optional[t.List[t.Any]] = None, + input_meta: t.Optional[t.List[TensorDescriptor]] = None, output_keys: t.Optional[t.List[TensorKey]] = None, raw_model: t.Optional[Model] = None, batch_size: int = 0, @@ -68,7 +67,8 @@ def __init__( """Initialize the InferenceRequest. :param model_key: A tuple containing a (key, descriptor) pair - :param callback: The channel used for notification of inference completion + :param callback_desc: The channel descriptor used for notification + of inference completion :param raw_inputs: Raw bytes of tensor inputs :param input_keys: A list of tuples containing a (key, descriptor) pair :param input_meta: Metadata about the input data @@ -80,7 +80,7 @@ """A tuple containing a (key, descriptor) pair""" self.raw_model = raw_model """Raw bytes of an ML model""" - self.callback = callback + self.callback_desc = callback_desc """The channel used for notification of inference completion""" self.raw_inputs = raw_inputs or [] """Raw bytes of tensor inputs""" @@ -191,6 +191,20 @@ def has_output_keys(self) -> bool: return self.output_keys is not None and bool(self.output_keys) +@dataclass +class TensorMeta: + """Metadata about a tensor, built from TensorDescriptors.""" + + dimensions: t.List[int] + """Dimensions of the tensor""" + order: str + """Order of the tensor in row major ("c"), or + column major ("f") format""" + datatype: str + """Datatype of the tensor as specified by the TensorDescriptor + NumericalType enums. Examples include "float32", "int8", etc.""" + + class LoadModelResult: """A wrapper around a loaded model.""" @@ -253,7 +267,11 @@ def __init__(self, result: t.Any, slices: list[slice]) -> None: class FetchInputResult: """A wrapper around fetched inputs.""" - def __init__(self, result: t.List[bytes], meta: t.Optional[t.List[t.Any]]) -> None: + def __init__( + self, + result: t.List[t.List[bytes]], + meta: t.List[t.List[t.Optional[TensorMeta]]], + ) -> None: """Initialize the FetchInputResult. :param result: List of input tensor bytes @@ -316,20 +334,31 @@ def __init__(self, result: bytes) -> None: class RequestBatch: """A batch of aggregated inference requests.""" - requests: list[InferenceRequest] - """List of InferenceRequests in the batch""" + raw_model: t.Optional[Model] + """Raw bytes of the model""" + callback_descriptors: t.List[str] + """The descriptors for channels used for notification of inference completion""" + raw_inputs: t.List[t.List[bytes]] + """Raw bytes of tensor inputs""" + input_meta: t.List[t.List[TensorMeta]] + """Metadata about the input data""" + input_keys: t.List[t.List[TensorKey]] + """Lists of input (key, descriptor) pairs, one list per request""" + output_key_refs: t.Dict[str, t.List[TensorKey]] + """A dictionary mapping callback descriptors to output keys""" inputs: t.Optional[TransformInputResult] """Transformed batch of input tensors""" model_id: "ModelIdentifier" """Model (key, descriptor) tuple""" @property - def has_valid_requests(self) -> bool: - """Returns whether the batch contains at least one request. + def has_callbacks(self) -> bool: + """Determines if the batch has at least one callback channel + available for sending results.
- :returns: True if at least one request is available + :returns: True if at least one callback is present """ - return len(self.requests) > 0 + return len(self.callback_descriptors) > 0 @property def has_raw_model(self) -> bool: @@ -339,37 +368,49 @@ def has_raw_model(self) -> bool: """ return self.raw_model is not None - @property - def raw_model(self) -> t.Optional[t.Any]: - """Returns the raw model to use to execute for this batch - if it is available. - - :returns: A model if available, otherwise None""" - if self.has_valid_requests: - return self.requests[0].raw_model - return None - - @property - def input_keys(self) -> t.List[TensorKey]: - """All input keys available in this batch's requests. - - :returns: All input keys belonging to requests in this batch""" - keys = [] - for request in self.requests: - keys.extend(request.input_keys) - - return keys - - @property - def output_keys(self) -> t.List[TensorKey]: - """All output keys available in this batch's requests. - - :returns: All output keys belonging to requests in this batch""" - keys = [] - for request in self.requests: - keys.extend(request.output_keys) - - return keys + @classmethod + def from_requests( + cls, + requests: t.List[InferenceRequest], + model_id: ModelIdentifier, + ) -> "RequestBatch": + """Create a RequestBatch from a list of requests. + + :param requests: The requests to batch + :param model_id: The model identifier + :returns: A RequestBatch instance + """ + return cls( + raw_model=requests[0].raw_model, + callback_descriptors=[ + request.callback_desc for request in requests if request.callback_desc + ], + raw_inputs=[ + request.raw_inputs for request in requests if request.raw_inputs + ], + input_meta=[ + [ + TensorMeta( + dimensions=list(meta.dimensions), + order=str(meta.order), + datatype=str(meta.dataType), + ) + for meta in request.input_meta + ] + for request in requests + if request.input_meta + ], + input_keys=[ + request.input_keys for request in requests if request.input_keys + ], + output_key_refs={ + request.callback_desc: request.output_keys + for request in requests + if request.callback_desc and request.output_keys + }, + inputs=None, + model_id=model_id, + ) class MachineLearningWorkerCore: @@ -378,13 +419,10 @@ class MachineLearningWorkerCore: @staticmethod def deserialize_message( data_blob: bytes, - callback_factory: t.Callable[[str], CommChannelBase], ) -> InferenceRequest: """Deserialize a message from a byte stream into an InferenceRequest. 
:param data_blob: The byte stream to deserialize - :param callback_factory: A factory method that can create an instance - of the desired concrete comm channel type :returns: The raw input message deserialized into an InferenceRequest """ request = MessageHandler.deserialize_request(data_blob) @@ -400,7 +438,6 @@ def deserialize_message( model_bytes = request.model.data callback_key = request.replyChannel.descriptor - comm_channel = callback_factory(callback_key) input_keys: t.Optional[t.List[TensorKey]] = None input_bytes: t.Optional[t.List[bytes]] = None output_keys: t.Optional[t.List[TensorKey]] = None @@ -422,7 +459,7 @@ def deserialize_message( inference_request = InferenceRequest( model_key=model_key, - callback=comm_channel, + callback_desc=callback_key, raw_inputs=input_bytes, input_meta=input_meta, input_keys=input_keys, @@ -497,7 +534,7 @@ def fetch_model( @staticmethod def fetch_inputs( batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] - ) -> t.List[FetchInputResult]: + ) -> FetchInputResult: """Given a collection of ResourceKeys, identify the physical location and input metadata. @@ -507,49 +544,52 @@ def fetch_inputs( :raises ValueError: If neither an input key or an input tensor are provided :raises SmartSimError: If a tensor for a given key cannot be retrieved """ - fetch_results = [] - for request in batch.requests: - if request.raw_inputs: - fetch_results.append( - FetchInputResult(request.raw_inputs, request.input_meta) - ) - continue - - if not feature_stores: - raise ValueError("No input and no feature store provided") - - if request.has_input_keys: - data: t.List[bytes] = [] + if not batch.raw_inputs and not batch.input_keys: + raise ValueError("No input source") - for fs_key in request.input_keys: + if not feature_stores: + raise ValueError("No feature stores provided") + + data_list: t.List[t.List[bytes]] = [] + meta_list: t.List[t.List[t.Optional[TensorMeta]]] = [] + # meta_list will be t.List[t.List[TensorMeta]] once input_key metadata + # is available to be retrieved from the feature store + + if batch.raw_inputs: + for raw_inputs, input_meta in zip(batch.raw_inputs, batch.input_meta): + data_list.append(raw_inputs) + meta_list.append(input_meta) # type: ignore + + if batch.input_keys: + for batch_keys in batch.input_keys: + batch_data: t.List[bytes] = [] + for fs_key in batch_keys: try: feature_store = feature_stores[fs_key.descriptor] tensor_bytes = t.cast(bytes, feature_store[fs_key.key]) - data.append(tensor_bytes) + batch_data.append(tensor_bytes) except KeyError as ex: logger.exception(ex) raise SmartSimError( f"Tensor could not be retrieved with key {fs_key.key}" ) from ex - fetch_results.append( - FetchInputResult(data, meta=None) - ) # fixme: need to get both tensor and descriptor - continue + data_list.append(batch_data) + meta_list.append([None] * len(batch_data)) + # fixme: need to get both tensor and descriptor + # this will eventually append meta info retrieved from the feature store - raise ValueError("No input source") - - return fetch_results + return FetchInputResult(result=data_list, meta=meta_list) @staticmethod def place_output( - request: InferenceRequest, + output_keys: t.List[TensorKey], transform_result: TransformOutputResult, feature_stores: t.Dict[str, FeatureStore], ) -> t.Collection[t.Optional[TensorKey]]: """Given a collection of data, make it available as a shared resource in the feature store. 
- :param request: The request that triggered the pipeline + :param output_keys: The output_keys that will be placed in the feature store :param transform_result: Transformed version of the inference result :param feature_stores: Available feature stores used for persistence :returns: A collection of keys that were placed in the feature store @@ -563,7 +603,7 @@ def place_output( # accurately placed, datum might need to include this. # Consider parallelizing all PUT feature_store operations - for fs_key, v in zip(request.output_keys, transform_result.outputs): + for fs_key, v in zip(output_keys, transform_result.outputs): feature_store = feature_stores[fs_key.descriptor] feature_store[fs_key.key] = v keys.append(fs_key) @@ -596,7 +636,7 @@ def load_model( @abstractmethod def transform_input( batch: RequestBatch, - fetch_results: list[FetchInputResult], + fetch_results: FetchInputResult, mem_pool: MemoryPool, ) -> TransformInputResult: """Given a collection of data, perform a transformation on the data and put diff --git a/tests/dragon_wlm/test_core_machine_learning_worker.py b/tests/dragon_wlm/test_core_machine_learning_worker.py index f9295d9e86..febdb00c8d 100644 --- a/tests/dragon_wlm/test_core_machine_learning_worker.py +++ b/tests/dragon_wlm/test_core_machine_learning_worker.py @@ -42,6 +42,7 @@ TransformOutputResult, ) +from .channel import FileSystemCommChannel from .feature_store import FileSystemFeatureStore, MemoryFeatureStore # The tests in this file belong to the dragon group @@ -100,7 +101,7 @@ def test_fetch_model_disk(persist_torch_model: pathlib.Path, test_dir: str) -> N model_key = ModelKey(key=key, descriptor=fsd) request = InferenceRequest(model_key=model_key) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_model(batch, {fsd: feature_store}) assert fetch_result.model_bytes @@ -118,7 +119,7 @@ def test_fetch_model_disk_missing() -> None: model_key = ModelKey(key=key, descriptor=fsd) request = InferenceRequest(model_key=model_key) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) with pytest.raises(sse.SmartSimError) as ex: worker.fetch_model(batch, {fsd: feature_store}) @@ -143,7 +144,7 @@ def test_fetch_model_feature_store(persist_torch_model: pathlib.Path) -> None: model_key = ModelKey(key=key, descriptor=feature_store.descriptor) request = InferenceRequest(model_key=model_key) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_model(batch, {fsd: feature_store}) assert fetch_result.model_bytes @@ -161,7 +162,7 @@ def test_fetch_model_feature_store_missing() -> None: model_key = ModelKey(key=key, descriptor=feature_store.descriptor) request = InferenceRequest(model_key=model_key) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) # todo: consider that raising this exception shows impl. replace... 
with pytest.raises(sse.SmartSimError) as ex: @@ -184,7 +185,7 @@ def test_fetch_model_memory(persist_torch_model: pathlib.Path) -> None: model_key = ModelKey(key=key, descriptor=feature_store.descriptor) request = InferenceRequest(model_key=model_key) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_model(batch, {fsd: feature_store}) assert fetch_result.model_bytes @@ -202,14 +203,14 @@ def test_fetch_input_disk(persist_torch_tensor: pathlib.Path) -> None: request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)]) model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) worker = MachineLearningWorkerCore feature_store[tensor_name] = persist_torch_tensor.read_bytes() fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) - assert fetch_result[0].inputs is not None + assert fetch_result.inputs[0] is not None def test_fetch_input_disk_missing() -> None: @@ -224,7 +225,7 @@ def test_fetch_input_disk_missing() -> None: request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) with pytest.raises(sse.SmartSimError) as ex: worker.fetch_inputs(batch, {fsd: feature_store}) @@ -249,12 +250,12 @@ def test_fetch_input_feature_store(persist_torch_tensor: pathlib.Path) -> None: feature_store[tensor_name] = persist_torch_tensor.read_bytes() model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) - assert fetch_result[0].inputs + assert fetch_result.inputs[0] assert ( - list(fetch_result[0].inputs)[0][:10] == persist_torch_tensor.read_bytes()[:10] + list(fetch_result.inputs[0])[0][:10] == persist_torch_tensor.read_bytes()[:10] ) @@ -287,12 +288,11 @@ def test_fetch_multi_input_feature_store(persist_torch_tensor: pathlib.Path) -> ) model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) - raw_bytes = list(fetch_result[0].inputs) - assert raw_bytes + raw_bytes = list(fetch_result.inputs[0]) assert raw_bytes[0][:10] == persist_torch_tensor.read_bytes()[:10] assert raw_bytes[1][:10] == body2[:10] assert raw_bytes[2][:10] == body3[:10] @@ -309,7 +309,7 @@ def test_fetch_input_feature_store_missing() -> None: request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) with pytest.raises(sse.SmartSimError) as ex: worker.fetch_inputs(batch, {fsd: feature_store}) @@ -331,13 +331,43 @@ def test_fetch_input_memory(persist_torch_tensor: pathlib.Path) -> None: request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) - assert fetch_result[0].inputs is 
not None + assert fetch_result.inputs[0] is not None -def test_place_outputs() -> None: +def test_fetch_inputs_no_input_source(): + """Verify that the ML worker fails when there are no inputs provided""" + worker = MachineLearningWorkerCore + request1 = InferenceRequest() + request2 = InferenceRequest() + + batch = RequestBatch.from_requests( + [request1, request2], ModelKey("test-model", "desc") + ) + + with pytest.raises(ValueError) as exc: + fetch_result = worker.fetch_inputs(batch, {}) + assert str(exc.value) == "No input source" + + +def test_fetch_inputs_no_feature_store(): + """Verify that the ML worker fails when there is no feature store""" + worker = MachineLearningWorkerCore + request1 = InferenceRequest(raw_inputs=[b"abcdef"]) + request2 = InferenceRequest(raw_inputs=[b"ghijkl"]) + + batch = RequestBatch.from_requests( + [request1, request2], ModelKey("test-model", "desc") + ) + + with pytest.raises(ValueError) as exc: + fetch_result = worker.fetch_inputs(batch, {}) + assert str(exc.value) == "No feature stores provided" + + +def test_place_outputs(test_dir: str) -> None: """Verify outputs are shared using the feature store""" worker = MachineLearningWorkerCore @@ -351,18 +381,42 @@ def test_place_outputs() -> None: TensorKey(key=key_name + "2", descriptor=fsd), TensorKey(key=key_name + "3", descriptor=fsd), ] + + keys2 = [ + TensorKey(key=key_name + "4", descriptor=fsd), + TensorKey(key=key_name + "5", descriptor=fsd), + TensorKey(key=key_name + "6", descriptor=fsd), + ] data = [b"abcdef", b"ghijkl", b"mnopqr"] + data2 = [b"stuvwx", b"yzabcd", b"efghij"] - for fsk, v in zip(keys, data): - feature_store[fsk.key] = v + callback1 = FileSystemCommChannel(pathlib.Path(test_dir) / "callback1").descriptor + callback2 = FileSystemCommChannel(pathlib.Path(test_dir) / "callback2").descriptor - request = InferenceRequest(output_keys=keys) + model_id = ModelKey(key="test-model", descriptor=fsd) + request = InferenceRequest(callback_desc=callback1, output_keys=keys) + request2 = InferenceRequest(callback_desc=callback2, output_keys=keys2) transform_result = TransformOutputResult(data, [1], "c", "float32") + transform_result2 = TransformOutputResult(data2, [1], "c", "float32") - worker.place_output(request, transform_result, {fsd: feature_store}) + request_batch = RequestBatch.from_requests([request, request2], model_id) + + worker.place_output( + request_batch.output_key_refs[callback1], + transform_result, + {fsd: feature_store}, + ) + + worker.place_output( + request_batch.output_key_refs[callback2], + transform_result2, + {fsd: feature_store}, + ) - for i in range(3): - assert feature_store[keys[i].key] == data[i] + all_keys = keys + keys2 + all_data = data + data2 + for i in range(6): + assert feature_store[all_keys[i].key] == all_data[i] @pytest.mark.parametrize( diff --git a/tests/dragon_wlm/test_device_manager.py b/tests/dragon_wlm/test_device_manager.py index d270e921cb..8a82c3d32b 100644 --- a/tests/dragon_wlm/test_device_manager.py +++ b/tests/dragon_wlm/test_device_manager.py @@ -123,7 +123,7 @@ def test_device_manager_model_in_request(): request = InferenceRequest( model_key=model_key, - callback=None, + callback_desc=None, raw_inputs=None, input_keys=[tensor_key], input_meta=None, @@ -132,10 +132,13 @@ def test_device_manager_model_in_request(): batch_size=0, ) - request_batch = RequestBatch( + request_batch = RequestBatch.from_requests( [request], - TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), - model_id=model_key, + model_key, + ) + + 
request_batch.inputs = TransformInputResult( + b"transformed", [slice(0, 1)], [[1, 2]], ["float32"] ) with device_manager.get_device( @@ -161,7 +164,7 @@ def test_device_manager_model_key(): request = InferenceRequest( model_key=model_key, - callback=None, + callback_desc=None, raw_inputs=None, input_keys=[tensor_key], input_meta=None, @@ -170,10 +173,13 @@ def test_device_manager_model_key(): batch_size=0, ) - request_batch = RequestBatch( + request_batch = RequestBatch.from_requests( [request], - TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), - model_id=model_key, + model_key, + ) + + request_batch.inputs = TransformInputResult( + b"transformed", [slice(0, 1)], [[1, 2]], ["float32"] ) with device_manager.get_device( diff --git a/tests/dragon_wlm/test_error_handling.py b/tests/dragon_wlm/test_error_handling.py index aacd47b556..a0bfbd7e9f 100644 --- a/tests/dragon_wlm/test_error_handling.py +++ b/tests/dragon_wlm/test_error_handling.py @@ -24,11 +24,14 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pathlib import typing as t from unittest.mock import MagicMock import pytest +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel + dragon = pytest.importorskip("dragon") import multiprocessing as mp @@ -123,9 +126,13 @@ def setup_worker_manager_model_bytes( tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + callback_descriptor = FileSystemCommChannel( + pathlib.Path(test_dir) / "callback1" + ).descriptor + inf_request = InferenceRequest( model_key=None, - callback=None, + callback_desc=callback_descriptor, raw_inputs=None, input_keys=[tensor_key], input_meta=None, @@ -136,10 +143,13 @@ def setup_worker_manager_model_bytes( model_id = ModelKey(key="key", descriptor=app_feature_store.descriptor) - request_batch = RequestBatch( + request_batch = RequestBatch.from_requests( [inf_request], - TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), - model_id=model_id, + model_id, + ) + + request_batch.inputs = TransformInputResult( + b"transformed", [slice(0, 1)], [[1, 2]], ["float32"] ) dispatcher_task_queue.put(request_batch) @@ -182,9 +192,13 @@ def setup_worker_manager_model_key( output_key = TensorKey(key="key", descriptor=app_feature_store.descriptor) model_id = ModelKey(key="model key", descriptor=app_feature_store.descriptor) + callback_descriptor = FileSystemCommChannel( + pathlib.Path(test_dir) / "callback2" + ).descriptor + request = InferenceRequest( model_key=model_id, - callback=None, + callback_desc=callback_descriptor, raw_inputs=None, input_keys=[tensor_key], input_meta=None, @@ -192,10 +206,13 @@ def setup_worker_manager_model_key( raw_model=b"model", batch_size=0, ) - request_batch = RequestBatch( + request_batch = RequestBatch.from_requests( [request], - TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), - model_id=model_id, + model_id, + ) + + request_batch.inputs = TransformInputResult( + b"transformed", [slice(0, 1)], [[1, 2]], ["float32"] ) dispatcher_task_queue.put(request_batch) diff --git a/tests/dragon_wlm/test_request_dispatcher.py b/tests/dragon_wlm/test_request_dispatcher.py index 8dc0f67a31..1b1fb1b013 100644 --- a/tests/dragon_wlm/test_request_dispatcher.py +++ b/tests/dragon_wlm/test_request_dispatcher.py @@ -26,6 +26,7 @@ import gc import os 
+import pathlib import time import typing as t from queue import Empty @@ -49,7 +50,6 @@ # isort: on - from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel from smartsim._core.mli.comm.channel.dragon_util import create_local @@ -69,9 +69,13 @@ from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( DragonFeatureStore, ) +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey, TensorKey from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.infrastructure.worker.worker import InferenceRequest, TensorMeta +from smartsim._core.mli.message_handler import MessageHandler from smartsim.log import get_logger +from .utils.channel import FileSystemCommChannel from .utils.msg_pump import mock_messages logger = get_logger(__name__) @@ -161,7 +165,6 @@ def test_request_dispatcher( processes.append(process) process.start() assert process.returncode is None, "The message pump failed to start" - # give dragon some time to populate the message queues for i in range(15): try: @@ -177,7 +180,7 @@ def test_request_dispatcher( raise exc assert batch is not None - assert batch.has_valid_requests + assert batch.has_callbacks model_key = batch.model_id.key @@ -200,7 +203,7 @@ def test_request_dispatcher( ) ) - assert len(batch.requests) == 2 + assert len(batch.callback_descriptors) == 2 assert batch.model_id.key == model_key assert model_key in request_dispatcher._queues assert model_key in request_dispatcher._active_queues @@ -235,3 +238,88 @@ def test_request_dispatcher( # Try to remove the dispatcher and free the memory del request_dispatcher gc.collect() + + +def test_request_batch(test_dir: str) -> None: + """Test the RequestBatch.from_requests instantiates properly""" + tensor_key = TensorKey(key="key", descriptor="desc1") + tensor_key2 = TensorKey(key="key2", descriptor="desc1") + output_key = TensorKey(key="key", descriptor="desc2") + output_key2 = TensorKey(key="key2", descriptor="desc2") + model_id1 = ModelKey(key="model key", descriptor="model desc") + model_id2 = ModelKey(key="model key2", descriptor="model desc") + tensor_desc = MessageHandler.build_tensor_descriptor("c", "float32", [1, 2]) + req_batch_model_id = ModelKey(key="req key", descriptor="desc") + + callback1 = FileSystemCommChannel(pathlib.Path(test_dir) / "callback1").descriptor + callback2 = FileSystemCommChannel(pathlib.Path(test_dir) / "callback2").descriptor + callback3 = FileSystemCommChannel(pathlib.Path(test_dir) / "callback3").descriptor + + request1 = InferenceRequest( + model_key=model_id1, + callback_desc=callback1, + raw_inputs=[b"input data"], + input_keys=[tensor_key], + input_meta=[tensor_desc], + output_keys=[output_key], + raw_model=b"model", + batch_size=0, + ) + + request2 = InferenceRequest( + model_key=model_id2, + callback_desc=callback2, + raw_inputs=None, + input_keys=None, + input_meta=None, + output_keys=[output_key, output_key2], + raw_model=b"model", + batch_size=0, + ) + + request3 = InferenceRequest( + model_key=model_id2, + callback_desc=callback3, + raw_inputs=None, + input_keys=[tensor_key, tensor_key2], + input_meta=[tensor_desc], + output_keys=None, + raw_model=b"model", + batch_size=0, + ) + + request_batch = RequestBatch.from_requests( + [request1, request2, request3], req_batch_model_id + ) + + assert len(request_batch.callback_descriptors) == 3 + for callback in request_batch.callback_descriptors: + assert 
isinstance(callback, str) + assert len(request_batch.output_key_refs.keys()) == 2 + assert request_batch.has_callbacks + assert request_batch.model_id == req_batch_model_id + assert request_batch.inputs is None + assert request_batch.raw_model == b"model" + assert request_batch.raw_inputs == [[b"input data"]] + assert request_batch.input_meta == [ + [TensorMeta([1, 2], "c", "float32")], + [TensorMeta([1, 2], "c", "float32")], + ] + assert request_batch.input_keys == [ + [tensor_key], + [tensor_key, tensor_key2], + ] + assert request_batch.output_key_refs == { + callback1: [output_key], + callback2: [output_key, output_key2], + } + + +def test_request_batch_has_no_callbacks(): + """Verify that a request batch with no callbacks is correctly identified""" + request1 = InferenceRequest() + request2 = InferenceRequest() + + batch = RequestBatch.from_requests([request1, request2], ModelKey("model", "desc")) + + assert not batch.has_callbacks diff --git a/tests/dragon_wlm/test_torch_worker.py b/tests/dragon_wlm/test_torch_worker.py index 2a9e7d01bd..5be0177c5a 100644 --- a/tests/dragon_wlm/test_torch_worker.py +++ b/tests/dragon_wlm/test_torch_worker.py @@ -110,7 +110,7 @@ def get_request() -> InferenceRequest: return InferenceRequest( model_key=ModelKey(key="model", descriptor="xyz"), - callback=None, + callback_desc=None, raw_inputs=tensor_numpy, input_keys=None, input_meta=serialized_tensors_descriptors, @@ -121,10 +121,13 @@ def get_request_batch_from_request( - request: InferenceRequest, inputs: t.Optional[TransformInputResult] = None + request: InferenceRequest, ) -> RequestBatch: - return RequestBatch([request], inputs, request.model_key) + return RequestBatch.from_requests( + [request], + request.model_key, + ) sample_request: InferenceRequest = get_request() @@ -146,13 +149,13 @@ def test_load_model(mlutils) -> None: def test_transform_input(mlutils) -> None: fetch_input_result = FetchInputResult( - sample_request.raw_inputs, sample_request.input_meta + sample_request_batch.raw_inputs, sample_request_batch.input_meta ) mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) transform_input_result = worker.transform_input( - sample_request_batch, [fetch_input_result], mem_pool + sample_request_batch, fetch_input_result, mem_pool ) batch = get_batch().numpy() @@ -184,15 +187,15 @@ def test_execute(mlutils) -> None: Net().to(torch_device[mlutils.get_test_device().lower()]) ) fetch_input_result = FetchInputResult( - sample_request.raw_inputs, sample_request.input_meta + sample_request_batch.raw_inputs, sample_request_batch.input_meta ) - request_batch = get_request_batch_from_request(sample_request, fetch_input_result) + request_batch = get_request_batch_from_request(sample_request) mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) transform_result = worker.transform_input( - request_batch, [fetch_input_result], mem_pool + request_batch, fetch_input_result, mem_pool ) execute_result = worker.execute(