Commit
Rewrite RequestBatch (#767)
RequestBatch rewrite.

[ committed by @AlyssaCote ]
[ reviewed by @ankona @al-rigazzi ]
AlyssaCote authored Oct 29, 2024
1 parent 2d8b902 commit f4e76a4
Showing 10 changed files with 407 additions and 182 deletions.
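
At a high level, the rewrite moves per-request state (callback channels, output keys) off the individual `InferenceRequest` and onto the batch, so only serializable descriptors travel between processes. Below is a minimal sketch of the reshaped `RequestBatch` surface as it appears in the diffs that follow — `from_requests`, `callback_descriptors`, `output_key_refs`, and `has_callbacks` are taken from the changes, while the field types and method bodies here are assumptions, not the repository's actual code:

```python
# Sketch only: reconstructs the RequestBatch surface visible in this commit.
# InferenceRequest, the key types, and the model id type are stand-ins.
import typing as t
from dataclasses import dataclass


@dataclass
class RequestBatch:
    callback_descriptors: list[str]
    input_keys: list[list[t.Any]]             # per-request feature-store keys
    output_key_refs: dict[str, list[t.Any]]   # callback descriptor -> output keys
    model_id: t.Any
    inputs: t.Optional[t.Any] = None          # filled in after input transform

    @classmethod
    def from_requests(
        cls, requests: "list[InferenceRequest]", model_id: t.Any
    ) -> "RequestBatch":
        """Collapse per-request metadata into batch-level collections."""
        return cls(
            callback_descriptors=[r.callback_desc for r in requests],
            input_keys=[r.input_keys for r in requests],
            output_key_refs={
                r.callback_desc: r.output_keys for r in requests if r.output_keys
            },
            model_id=model_id,
        )

    @property
    def has_callbacks(self) -> bool:
        """Mirrors the validation check used by the worker manager."""
        return bool(self.callback_descriptors)
```

Keying `output_key_refs` by callback descriptor lets the worker manager skip requests that registered no output keys with a plain `KeyError` check, as seen in `worker_manager.py` below.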
1 change: 1 addition & 0 deletions doc/changelog.md
@@ -13,6 +13,7 @@ Jump to:

 Description

+- RequestBatch rewrite
 - Fix regression on hostlist param to DragonRunRequest
 - Fix dragon build logging bug
 - Merge core refactor into MLI feature branch
23 changes: 10 additions & 13 deletions smartsim/_core/mli/infrastructure/control/request_dispatcher.py
@@ -172,7 +172,7 @@ def can_be_removed(self) -> bool:
         """
         return self.empty() and self._disposable

-    def flush(self) -> list[t.Any]:
+    def flush(self) -> list[InferenceRequest]:
         """Get all requests from queue.

         :returns: Requests waiting to be executed
@@ -334,7 +334,7 @@ def _check_callback(self, request: InferenceRequest) -> bool:
         :param request: The request to validate
         :returns: False if callback validation fails for the request, True otherwise
         """
-        if request.callback:
+        if request.callback_desc:
             return True

         logger.error("No callback channel provided in request")
@@ -376,9 +376,7 @@ def _on_iteration(self) -> None:
         tensor_bytes_list = bytes_list[1:]
         self._perf_timer.start_timings()

-        request = self._worker.deserialize_message(
-            request_bytes, self._callback_factory
-        )
+        request = self._worker.deserialize_message(request_bytes)
         if request.has_input_meta and tensor_bytes_list:
             request.raw_inputs = tensor_bytes_list

@@ -387,7 +385,11 @@ def _on_iteration(self) -> None:
         if not self._validate_request(request):
             exception_handler(
                 ValueError("Error validating the request"),
-                request.callback,
+                (
+                    self._callback_factory(request.callback_desc)
+                    if request.callback_desc
+                    else None
+                ),
                 None,
             )
         self._perf_timer.measure_time("validate_request")
@@ -493,10 +495,8 @@ def flush_requests(self) -> None:
             if queue.ready:
                 self._perf_timer.measure_time("find_queue")
                 try:
-                    batch = RequestBatch(
-                        requests=queue.flush(),
-                        inputs=None,
-                        model_id=queue.model_id,
+                    batch = RequestBatch.from_requests(
+                        queue.flush(), queue.model_id
                     )
                 finally:
                     self._perf_timer.measure_time("flush_requests")
@@ -528,9 +528,6 @@ def flush_requests(self) -> None:

                 self._perf_timer.measure_time("transform_input")
                 batch.inputs = transformed_inputs
-                for request in batch.requests:
-                    request.raw_inputs = []
-                    request.input_meta = []

                 try:
                     self._outgoing_queue.put(batch)
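
The dispatcher-side effect of the rewrite: `deserialize_message` no longer receives a callback factory, so a request carries only its `callback_desc` string and a live channel is built on demand when an error must be reported. A hedged sketch of that pattern — `resolve_callback` is an invented helper name, while `callback_desc` and `_callback_factory` come from the diff above:

```python
import typing as t


# Sketch of the lazy callback-resolution pattern used in the dispatcher
# above; the channel type is left as t.Any because the factory's real
# return type is not shown in this diff.
def resolve_callback(
    callback_desc: t.Optional[str],
    callback_factory: t.Callable[[str], t.Any],
) -> t.Optional[t.Any]:
    """Build a live channel from a descriptor only when one was provided."""
    if not callback_desc:
        return None
    return callback_factory(callback_desc)
```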
68 changes: 43 additions & 25 deletions smartsim/_core/mli/infrastructure/control/worker_manager.py
@@ -132,8 +132,10 @@ def _check_feature_stores(self, batch: RequestBatch) -> bool:
         fs_model: t.Set[str] = set()
         if batch.model_id.key:
             fs_model = {batch.model_id.descriptor}
-        fs_inputs = {key.descriptor for key in batch.input_keys}
-        fs_outputs = {key.descriptor for key in batch.output_keys}
+        fs_inputs = {key.descriptor for keys in batch.input_keys for key in keys}
+        fs_outputs = {
+            key.descriptor for keys in batch.output_key_refs.values() for key in keys
+        }

         # identify which feature stores are requested and unknown
         fs_desired = fs_model.union(fs_inputs).union(fs_outputs)
@@ -159,7 +161,7 @@ def _validate_batch(self, batch: RequestBatch) -> bool:
         :param batch: The batch of requests to validate
         :returns: False if the request fails any validation checks, True otherwise
         """
-        if batch is None or not batch.has_valid_requests:
+        if batch is None or not batch.has_callbacks:
             return False

         return self._check_feature_stores(batch)
@@ -188,7 +190,7 @@ def _on_iteration(self) -> None:
             return

         if not self._device_manager:
-            for request in batch.requests:
+            for callback_desc in batch.callback_descriptors:
                 msg = "No Device Manager found. WorkerManager._on_start() "
                 "must be called after initialization. If possible, "
                 "you should use `WorkerManager.execute()` instead of "
@@ -200,7 +202,7 @@ def _on_iteration(self) -> None:
                 "and will not be processed."
                 exception_handler(
                     RuntimeError(msg),
-                    request.callback,
+                    self._callback_factory(callback_desc),
                     "Error acquiring device manager",
                 )
             return
@@ -212,10 +214,10 @@ def _on_iteration(self) -> None:
                 feature_stores=self._feature_stores,
             )
         except Exception as exc:
-            for request in batch.requests:
+            for callback_desc in batch.callback_descriptors:
                 exception_handler(
                     exc,
-                    request.callback,
+                    self._callback_factory(callback_desc),
                     "Error loading model on device or getting device.",
                 )
             return
@@ -226,18 +228,20 @@ def _on_iteration(self) -> None:
         try:
             model_result = LoadModelResult(device.get_model(batch.model_id.key))
         except Exception as exc:
-            for request in batch.requests:
+            for callback_desc in batch.callback_descriptors:
                 exception_handler(
-                    exc, request.callback, "Error getting model from device."
+                    exc,
+                    self._callback_factory(callback_desc),
+                    "Error getting model from device.",
                 )
             return
         self._perf_timer.measure_time("load_model")

         if not batch.inputs:
-            for request in batch.requests:
+            for callback_desc in batch.callback_descriptors:
                 exception_handler(
                     ValueError("Error batching inputs"),
-                    request.callback,
+                    self._callback_factory(callback_desc),
                     None,
                 )
             return
@@ -248,8 +252,12 @@ def _on_iteration(self) -> None:
                 batch, model_result, transformed_input, device.name
             )
         except Exception as e:
-            for request in batch.requests:
-                exception_handler(e, request.callback, "Error while executing.")
+            for callback_desc in batch.callback_descriptors:
+                exception_handler(
+                    e,
+                    self._callback_factory(callback_desc),
+                    "Error while executing.",
+                )
             return
         self._perf_timer.measure_time("execute")

@@ -258,24 +266,35 @@ def _on_iteration(self) -> None:
                 batch, execute_result
             )
         except Exception as e:
-            for request in batch.requests:
+            for callback_desc in batch.callback_descriptors:
                 exception_handler(
-                    e, request.callback, "Error while transforming the output."
+                    e,
+                    self._callback_factory(callback_desc),
+                    "Error while transforming the output.",
                 )
             return

-        for request, transformed_output in zip(batch.requests, transformed_outputs):
+        for callback_desc, transformed_output in zip(
+            batch.callback_descriptors, transformed_outputs
+        ):
             reply = InferenceReply()
-            if request.has_output_keys:
+            if batch.output_key_refs:
                 try:
+                    output_keys = batch.output_key_refs[callback_desc]
                     reply.output_keys = self._worker.place_output(
-                        request,
+                        output_keys,
                         transformed_output,
                         self._feature_stores,
                     )
+                except KeyError:
+                    # the callback is not in the output_key_refs dict
+                    # because it doesn't have output_keys associated with it
+                    continue
                 except Exception as e:
                     exception_handler(
-                        e, request.callback, "Error while placing the output."
+                        e,
+                        self._callback_factory(callback_desc),
+                        "Error while placing the output.",
                     )
                     continue
             else:
@@ -302,12 +321,11 @@ def _on_iteration(self) -> None:

             self._perf_timer.measure_time("serialize_resp")

-            if request.callback:
-                request.callback.send(serialized_resp)
-                if reply.has_outputs:
-                    # send tensor data after response
-                    for output in reply.outputs:
-                        request.callback.send(output)
+            callback = self._callback_factory(callback_desc)
+            callback.send(serialized_resp)
+            if reply.has_outputs:
+                for output in reply.outputs:
+                    callback.send(output)
             self._perf_timer.measure_time("send")

         self._perf_timer.end_timings()
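
Because output keys now live in a dict keyed by callback descriptor, the reply loop pairs each descriptor with its transformed output and treats a missing entry as "this request registered no output keys". A simplified sketch of that control flow — `place_output`, the serialized payload, and the argument shapes are stand-ins, not the real signatures:

```python
# Simplified sketch of the reply loop above; not the repository's code.
def send_replies(batch, transformed_outputs, callback_factory, worker, feature_stores):
    for callback_desc, output in zip(
        batch.callback_descriptors, transformed_outputs
    ):
        if batch.output_key_refs:
            try:
                keys = batch.output_key_refs[callback_desc]
            except KeyError:
                # no output keys registered for this callback: skip placement
                continue
            worker.place_output(keys, output, feature_stores)
        callback = callback_factory(callback_desc)
        callback.send(b"serialized-response")  # placeholder payload
```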
33 changes: 17 additions & 16 deletions smartsim/_core/mli/infrastructure/worker/torch_worker.py
@@ -34,14 +34,14 @@

 from .....error import SmartSimError
 from .....log import get_logger
-from ...mli_schemas.tensor import tensor_capnp
 from .worker import (
     ExecuteResult,
     FetchInputResult,
     FetchModelResult,
     LoadModelResult,
     MachineLearningWorkerBase,
     RequestBatch,
+    TensorMeta,
     TransformInputResult,
     TransformOutputResult,
 )
@@ -64,7 +64,8 @@ def load_model(
         """Given a loaded MachineLearningModel, ensure it is loaded into
         device memory.

-        :param request: The request that triggered the pipeline
+        :param batch: The batch that triggered the pipeline
+        :param fetch_result: The fetched model
         :param device: The device on which the model must be placed
         :returns: LoadModelResult wrapping the model loaded for the request
         :raises ValueError: If model reference object is not found
@@ -97,13 +98,13 @@ def load_model(
     @staticmethod
     def transform_input(
         batch: RequestBatch,
-        fetch_results: list[FetchInputResult],
+        fetch_results: FetchInputResult,
         mem_pool: MemoryPool,
     ) -> TransformInputResult:
         """Given a collection of data, perform a transformation on the data and put
         the raw tensor data on a MemoryPool allocation.

-        :param request: The request that triggered the pipeline
+        :param batch: The batch that triggered the pipeline
         :param fetch_result: Raw outputs from fetching inputs out of a feature store
         :param mem_pool: The memory pool used to access batched input tensors
         :returns: The transformed inputs wrapped in a TransformInputResult
@@ -116,32 +117,32 @@ def transform_input(

         all_dims: list[list[int]] = []
         all_dtypes: list[str] = []
-        if fetch_results[0].meta is None:
+        if fetch_results.meta is None:
             raise ValueError("Cannot reconstruct tensor without meta information")
         # Traverse inputs to get total number of samples and compute slices
         # Assumption: first dimension is samples, all tensors in the same input
         # have same number of samples
         # thus we only look at the first tensor for each input
-        for res_idx, fetch_result in enumerate(fetch_results):
-            if fetch_result.meta is None or any(
-                item_meta is None for item_meta in fetch_result.meta
+        for res_idx, res_meta_list in enumerate(fetch_results.meta):
+            if res_meta_list is None or any(
+                item_meta is None for item_meta in res_meta_list
             ):
                 raise ValueError("Cannot reconstruct tensor without meta information")
-            first_tensor_desc: tensor_capnp.TensorDescriptor = fetch_result.meta[0]
+            first_tensor_desc: TensorMeta = res_meta_list[0]  # type: ignore
             num_samples = first_tensor_desc.dimensions[0]
             slices.append(slice(total_samples, total_samples + num_samples))
             total_samples = total_samples + num_samples

-            if res_idx == len(fetch_results) - 1:
+            if res_idx == len(fetch_results.meta) - 1:
                 # For each tensor in the last input, get remaining dimensions
                 # Assumptions: all inputs have the same number of tensors and
                 # last N-1 dimensions match across inputs for corresponding tensors
                 # thus: resulting array will be of size (num_samples, all_other_dims)
-                for item_meta in fetch_result.meta:
-                    tensor_desc: tensor_capnp.TensorDescriptor = item_meta
-                    tensor_dims = list(tensor_desc.dimensions)
+                for item_meta in res_meta_list:
+                    tensor_desc: TensorMeta = item_meta  # type: ignore
+                    tensor_dims = tensor_desc.dimensions
                     all_dims.append([total_samples, *tensor_dims[1:]])
-                    all_dtypes.append(str(tensor_desc.dataType))
+                    all_dtypes.append(tensor_desc.datatype)

         for result_tensor_idx, (dims, dtype) in enumerate(zip(all_dims, all_dtypes)):
             itemsize = np.empty((1), dtype=dtype).itemsize
@@ -151,8 +152,8 @@
             try:
                 mem_view[:alloc_size] = b"".join(
                     [
-                        fetch_result.inputs[result_tensor_idx]
-                        for fetch_result in fetch_results
+                        fetch_result[result_tensor_idx]
+                        for fetch_result in fetch_results.inputs
                     ]
                 )
             except IndexError as e:
