
Commit

dragon test fixes
ankona committed Oct 31, 2024
1 parent 904b5cc commit 384bd7a
Showing 3 changed files with 71 additions and 284 deletions.
40 changes: 14 additions & 26 deletions tests/dragon_wlm/test_request_dispatcher.py
@@ -76,7 +76,7 @@
 from smartsim.log import get_logger
 
 from .utils.channel import FileSystemCommChannel
-from .utils.msg_pump import mock_messages
+from .utils.msg_pump import mock_message
 
 logger = get_logger(__name__)

@@ -90,7 +90,6 @@
     pass
 
 
-@pytest.mark.skip("TODO: Fix issue unpickling messages")
 @pytest.mark.parametrize("num_iterations", [4])
 def test_request_dispatcher(
     num_iterations: int,
@@ -139,7 +138,6 @@ def test_request_dispatcher(
 
     # put some messages into the work queue for the dispatcher to pickup
     channels = []
-    processes = []
    for i in range(num_iterations):
        batch: t.Optional[RequestBatch] = None
        mem_allocs = []
@@ -151,28 +149,22 @@
         callback_channel = DragonCommChannel.from_local()
         channels.append(callback_channel)
 
-        process = function_as_dragon_proc(
-            mock_messages,
-            [
-                worker_queue.descriptor,
-                backbone_fs.descriptor,
-                i,
-                callback_channel.descriptor,
-            ],
-            [],
-            [],
-        )
-        processes.append(process)
-        process.start()
-        assert process.returncode is None, "The message pump failed to start"
+        # assert process.returncode is None, "The message pump failed to start"
         # give dragon some time to populate the message queues
-        for i in range(15):
+        for j in range(5):
             try:
+                if j < 2:
+                    mock_message(
+                        worker_queue.descriptor,
+                        backbone_fs.descriptor,
+                        j,
+                        callback_channel.descriptor,
+                    )
                 request_dispatcher._on_iteration()
                 batch = request_dispatcher.task_queue.get(timeout=1.0)
                 break
             except Empty:
-                time.sleep(2)
+                time.sleep(1)
                 logger.warning(f"Task queue is empty on iteration {i}")
                 continue
             except Exception as exc:
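
The rewritten loop amounts to a bounded poll: trigger the dispatcher, attempt a timed get from the task queue, and back off briefly when it is empty. A minimal sketch of that pattern, using the standard-library Queue as a stand-in for the dispatcher's task queue (the helper name poll_for_batch is illustrative, not part of the commit):

    import time
    from queue import Empty, Queue

    def poll_for_batch(task_queue: Queue, attempts: int = 5, timeout: float = 1.0):
        """Poll a queue a bounded number of times, sleeping briefly when it is empty."""
        for attempt in range(attempts):
            try:
                # mirrors request_dispatcher.task_queue.get(timeout=1.0) above
                return task_queue.get(timeout=timeout)
            except Empty:
                time.sleep(1)  # give the producer time to publish more work
        return None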
@@ -213,13 +205,9 @@ def test_request_dispatcher(
             assert len(tensors) == 1
             assert tensors[0].shape == torch.Size([2, 2])
 
-            for tensor in tensors:
-                for sample_idx in range(tensor.shape[0]):
-                    tensor_in = tensor[sample_idx]
-                    tensor_out = (sample_idx + 1) * torch.ones(
-                        (2,), dtype=torch.float32
-                    )
-                    assert torch.equal(tensor_in, tensor_out)
+            exp_tensor = torch.Tensor([[1.0, 1.0], [2.0, 2.0]])
+
+            assert torch.equal(exp_tensor, tensors[0])
 
         except Exception as exc:
             raise exc
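
For reference, the new expected value follows directly from the payloads built by mock_message (added in msg_pump.py below): message j carries a (1, 2) float32 tensor filled with j + 1, and the two messages sent per iteration are batched into a single (2, 2) tensor. A standalone check in plain PyTorch, independent of the dragon machinery:

    import torch

    # payloads for msg_number 0 and 1, matching what mock_message builds
    payloads = [(j + 1) * torch.ones((1, 2), dtype=torch.float32) for j in range(2)]
    batched = torch.cat(payloads)  # the dispatcher batches both requests together
    assert torch.equal(batched, torch.Tensor([[1.0, 1.0], [2.0, 2.0]]))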
58 changes: 57 additions & 1 deletion tests/dragon_wlm/utils/msg_pump.py
@@ -110,6 +110,60 @@ def persist_model_file(model_path: pathlib.Path) -> pathlib.Path:
     return model_path
 
 
+def mock_message(
+    dispatch_fli_descriptor: str,
+    fs_descriptor: str,
+    msg_number: int,
+    callback_descriptor: str,
+) -> None:
+    """Mock event producer for triggering the inference pipeline."""
+    model_key = "mini-model"
+    # each call sends a single message; the caller supplies msg_number to keep output keys distinct
+
+    feature_store = BackboneFeatureStore.from_descriptor(fs_descriptor)
+    request_dispatcher_queue = DragonFLIChannel.from_descriptor(dispatch_fli_descriptor)
+
+    feature_store[model_key] = load_model()
+    logger.debug(f"Sending mock message {msg_number}")
+
+    output_key = f"output-{msg_number}"
+
+    tensor = ((msg_number + 1) * torch.ones((1, 2), dtype=torch.float32)).numpy()
+    fsd = feature_store.descriptor
+
+    tensor_desc = MessageHandler.build_tensor_descriptor(
+        "c", "float32", list(tensor.shape)
+    )
+
+    message_tensor_output_key = MessageHandler.build_tensor_key(output_key, fsd)
+    message_model_key = MessageHandler.build_model_key(model_key, fsd)
+
+    request = MessageHandler.build_request(
+        reply_channel=callback_descriptor,
+        model=message_model_key,
+        inputs=[tensor_desc],
+        outputs=[message_tensor_output_key],
+        output_descriptors=[],
+        custom_attributes=None,
+    )
+
+    logger.info(f"Sending request {msg_number} to request_dispatcher_queue")
+    request_bytes = MessageHandler.serialize_request(request)
+
+    logger.info("Sending msg_envelope")
+
+    # cuid = request_dispatcher_queue._channel.cuid
+    # logger.info(f"\tInternal cuid: {cuid}")
+
+    # send the header & body together so they arrive together
+    try:
+        request_dispatcher_queue.send_multiple([request_bytes, tensor.tobytes()], 1.0)
+        logger.info(f"\tenvelope 0: {request_bytes[:5]}...")
+        logger.info(f"\tenvelope 1: {tensor.tobytes()[:5]} - ({tensor})")
+    except Exception as ex:
+        logger.exception("Unable to send request envelope")
+
+
 def _mock_messages(
     dispatch_fli_descriptor: str,
     fs_descriptor: str,
@@ -163,7 +217,9 @@ def _mock_messages(

     # send the header & body together so they arrive together
     try:
-        request_dispatcher_queue.send_multiple([request_bytes, tensor.tobytes()])
+        request_dispatcher_queue.send_multiple(
+            [request_bytes, tensor.tobytes()], 1.0
+        )
         logger.info(f"\tenvelope 0: {request_bytes[:5]}...")
         logger.info(f"\tenvelope 1: {tensor.tobytes()[:5]}...")
     except Exception as ex:
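
Both call sites now pass a 1.0 timeout to send_multiple and ship the serialized request header and the raw tensor bytes as one two-part envelope. On the receiving side, the tensor can be rebuilt from the second part; a sketch of that assumed consumer-side logic (not part of this commit), using numpy:

    import numpy as np

    def unpack_tensor(tensor_bytes: bytes, shape=(1, 2)) -> np.ndarray:
        """Rebuild the float32 tensor sent as the second message of the envelope."""
        return np.frombuffer(tensor_bytes, dtype=np.float32).reshape(shape)

    # what mock_message would send for msg_number=2
    sent = 3 * np.ones((1, 2), dtype=np.float32)
    assert np.array_equal(unpack_tensor(sent.tobytes()), sent)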
