From 8a19dee1fd37be9e04c14a9d486c3314a676dbe3 Mon Sep 17 00:00:00 2001 From: Chris McBride Date: Sat, 19 Oct 2024 14:08:06 -0400 Subject: [PATCH] Merge MLI feature branch into v1.0 branch (#754) Create a v1.0 branch to combine ongoing efforts in `mli-feature` and `smartsim-refactor` feature branches --------- Co-authored-by: Alyssa Cote <46540273+AlyssaCote@users.noreply.github.com> Co-authored-by: Al Rigazzi --- .github/workflows/run_tests.yml | 27 +- Makefile | 13 +- conftest.py | 2 +- doc/changelog.md | 33 + doc/installation_instructions/basic.rst | 14 + ex/high_throughput_inference/mli_driver.py | 77 ++ ex/high_throughput_inference/mock_app.py | 142 +++ .../mock_app_redis.py | 90 ++ ex/high_throughput_inference/redis_driver.py | 66 ++ .../standalone_worker_manager.py | 218 +++++ pyproject.toml | 1 + setup.py | 2 + smartsim/_core/_cli/build.py | 54 +- smartsim/_core/_cli/scripts/dragon_install.py | 368 ++++++-- smartsim/_core/_install/builder.py | 8 +- smartsim/_core/config/config.py | 35 +- smartsim/_core/entrypoints/service.py | 185 ++++ .../_core/launcher/dragon/dragonBackend.py | 400 +++++++-- .../_core/launcher/dragon/dragonConnector.py | 85 +- .../_core/launcher/dragon/dragonLauncher.py | 2 + smartsim/_core/launcher/dragon/pqueue.py | 461 ++++++++++ smartsim/_core/launcher/step/dragonStep.py | 2 + smartsim/_core/mli/__init__.py | 0 smartsim/_core/mli/client/__init__.py | 0 smartsim/_core/mli/client/protoclient.py | 348 ++++++++ smartsim/_core/mli/comm/channel/__init__.py | 0 smartsim/_core/mli/comm/channel/channel.py | 82 ++ .../_core/mli/comm/channel/dragon_channel.py | 127 +++ smartsim/_core/mli/comm/channel/dragon_fli.py | 158 ++++ .../_core/mli/comm/channel/dragon_util.py | 131 +++ smartsim/_core/mli/infrastructure/__init__.py | 0 .../_core/mli/infrastructure/comm/__init__.py | 0 .../mli/infrastructure/comm/broadcaster.py | 239 +++++ .../_core/mli/infrastructure/comm/consumer.py | 281 ++++++ .../_core/mli/infrastructure/comm/event.py | 162 ++++ .../_core/mli/infrastructure/comm/producer.py | 44 + .../mli/infrastructure/control/__init__.py | 0 .../infrastructure/control/device_manager.py | 166 ++++ .../infrastructure/control/error_handling.py | 78 ++ .../mli/infrastructure/control/listener.py | 352 ++++++++ .../control/request_dispatcher.py | 559 ++++++++++++ .../infrastructure/control/worker_manager.py | 330 +++++++ .../mli/infrastructure/environment_loader.py | 116 +++ .../mli/infrastructure/storage/__init__.py | 0 .../storage/backbone_feature_store.py | 259 ++++++ .../storage/dragon_feature_store.py | 126 +++ .../mli/infrastructure/storage/dragon_util.py | 101 +++ .../infrastructure/storage/feature_store.py | 224 +++++ .../mli/infrastructure/worker/__init__.py | 0 .../mli/infrastructure/worker/torch_worker.py | 276 ++++++ .../_core/mli/infrastructure/worker/worker.py | 646 ++++++++++++++ smartsim/_core/mli/message_handler.py | 602 +++++++++++++ .../mli_schemas/data/data_references.capnp | 37 + .../mli_schemas/data/data_references_capnp.py | 41 + .../data/data_references_capnp.pyi | 107 +++ .../_core/mli/mli_schemas/model/__init__.py | 0 .../_core/mli/mli_schemas/model/model.capnp | 33 + .../mli/mli_schemas/model/model_capnp.py | 38 + .../mli/mli_schemas/model/model_capnp.pyi | 72 ++ .../mli/mli_schemas/request/request.capnp | 55 ++ .../request_attributes.capnp | 49 + .../request_attributes_capnp.py | 41 + .../request_attributes_capnp.pyi | 109 +++ .../mli/mli_schemas/request/request_capnp.py | 41 + .../mli/mli_schemas/request/request_capnp.pyi | 319 +++++++ 
.../mli/mli_schemas/response/response.capnp | 52 ++ .../response_attributes.capnp | 33 + .../response_attributes_capnp.py | 41 + .../response_attributes_capnp.pyi | 103 +++ .../mli_schemas/response/response_capnp.py | 38 + .../mli_schemas/response/response_capnp.pyi | 212 +++++ .../_core/mli/mli_schemas/tensor/tensor.capnp | 75 ++ .../mli/mli_schemas/tensor/tensor_capnp.py | 41 + .../mli/mli_schemas/tensor/tensor_capnp.pyi | 142 +++ smartsim/_core/schemas/utils.py | 4 +- smartsim/_core/utils/helpers.py | 2 +- smartsim/_core/utils/timings.py | 175 ++++ smartsim/database/orchestrator.py | 4 +- smartsim/experiment.py | 2 +- smartsim/log.py | 13 +- smartsim/settings/dragonRunSettings.py | 20 + tests/dragon/__init__.py | 0 tests/dragon/channel.py | 125 +++ tests/dragon/conftest.py | 129 +++ tests/dragon/feature_store.py | 152 ++++ .../test_core_machine_learning_worker.py | 377 ++++++++ tests/dragon/test_device_manager.py | 186 ++++ tests/dragon/test_dragon_backend.py | 308 +++++++ tests/dragon/test_dragon_ddict_utils.py | 117 +++ tests/dragon/test_environment_loader.py | 147 +++ tests/dragon/test_error_handling.py | 511 +++++++++++ tests/dragon/test_event_consumer.py | 386 ++++++++ tests/dragon/test_featurestore.py | 327 +++++++ tests/dragon/test_featurestore_base.py | 844 ++++++++++++++++++ tests/dragon/test_featurestore_integration.py | 213 +++++ tests/dragon/test_inference_reply.py | 76 ++ tests/dragon/test_inference_request.py | 118 +++ tests/dragon/test_protoclient.py | 313 +++++++ tests/dragon/test_reply_building.py | 64 ++ tests/dragon/test_request_dispatcher.py | 233 +++++ tests/dragon/test_torch_worker.py | 221 +++++ tests/dragon/test_worker_manager.py | 314 +++++++ tests/dragon/utils/__init__.py | 0 tests/dragon/utils/channel.py | 125 +++ tests/dragon/utils/msg_pump.py | 225 +++++ tests/dragon/utils/worker.py | 104 +++ tests/mli/__init__.py | 0 tests/mli/channel.py | 125 +++ tests/mli/feature_store.py | 144 +++ tests/mli/test_integrated_torch_worker.py | 275 ++++++ tests/mli/test_service.py | 290 ++++++ tests/mli/worker.py | 104 +++ tests/on_wlm/test_dragon.py | 2 +- tests/test_config.py | 28 +- tests/test_dragon_comm_utils.py | 257 ++++++ tests/test_dragon_installer.py | 142 ++- tests/test_dragon_launcher.py | 30 +- tests/test_dragon_run_policy.py | 9 - tests/test_dragon_run_request.py | 380 ++++---- tests/test_dragon_runsettings.py | 119 +++ tests/test_dragon_step.py | 13 + tests/test_message_handler/__init__.py | 0 .../test_message_handler/test_build_model.py | 72 ++ .../test_build_model_key.py | 47 + .../test_build_request_attributes.py | 55 ++ .../test_build_tensor_desc.py | 90 ++ .../test_build_tensor_key.py | 46 + .../test_output_descriptor.py | 78 ++ tests/test_message_handler/test_request.py | 449 ++++++++++ tests/test_message_handler/test_response.py | 191 ++++ tests/test_node_prioritizer.py | 553 ++++++++++++ 131 files changed, 18278 insertions(+), 427 deletions(-) create mode 100644 ex/high_throughput_inference/mli_driver.py create mode 100644 ex/high_throughput_inference/mock_app.py create mode 100644 ex/high_throughput_inference/mock_app_redis.py create mode 100644 ex/high_throughput_inference/redis_driver.py create mode 100644 ex/high_throughput_inference/standalone_worker_manager.py create mode 100644 smartsim/_core/entrypoints/service.py create mode 100644 smartsim/_core/launcher/dragon/pqueue.py create mode 100644 smartsim/_core/mli/__init__.py create mode 100644 smartsim/_core/mli/client/__init__.py create mode 100644 smartsim/_core/mli/client/protoclient.py create 
mode 100644 smartsim/_core/mli/comm/channel/__init__.py create mode 100644 smartsim/_core/mli/comm/channel/channel.py create mode 100644 smartsim/_core/mli/comm/channel/dragon_channel.py create mode 100644 smartsim/_core/mli/comm/channel/dragon_fli.py create mode 100644 smartsim/_core/mli/comm/channel/dragon_util.py create mode 100644 smartsim/_core/mli/infrastructure/__init__.py create mode 100644 smartsim/_core/mli/infrastructure/comm/__init__.py create mode 100644 smartsim/_core/mli/infrastructure/comm/broadcaster.py create mode 100644 smartsim/_core/mli/infrastructure/comm/consumer.py create mode 100644 smartsim/_core/mli/infrastructure/comm/event.py create mode 100644 smartsim/_core/mli/infrastructure/comm/producer.py create mode 100644 smartsim/_core/mli/infrastructure/control/__init__.py create mode 100644 smartsim/_core/mli/infrastructure/control/device_manager.py create mode 100644 smartsim/_core/mli/infrastructure/control/error_handling.py create mode 100644 smartsim/_core/mli/infrastructure/control/listener.py create mode 100644 smartsim/_core/mli/infrastructure/control/request_dispatcher.py create mode 100644 smartsim/_core/mli/infrastructure/control/worker_manager.py create mode 100644 smartsim/_core/mli/infrastructure/environment_loader.py create mode 100644 smartsim/_core/mli/infrastructure/storage/__init__.py create mode 100644 smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py create mode 100644 smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py create mode 100644 smartsim/_core/mli/infrastructure/storage/dragon_util.py create mode 100644 smartsim/_core/mli/infrastructure/storage/feature_store.py create mode 100644 smartsim/_core/mli/infrastructure/worker/__init__.py create mode 100644 smartsim/_core/mli/infrastructure/worker/torch_worker.py create mode 100644 smartsim/_core/mli/infrastructure/worker/worker.py create mode 100644 smartsim/_core/mli/message_handler.py create mode 100644 smartsim/_core/mli/mli_schemas/data/data_references.capnp create mode 100644 smartsim/_core/mli/mli_schemas/data/data_references_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/model/__init__.py create mode 100644 smartsim/_core/mli/mli_schemas/model/model.capnp create mode 100644 smartsim/_core/mli/mli_schemas/model/model_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/model/model_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/request/request.capnp create mode 100644 smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp create mode 100644 smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/request/request_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/request/request_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/response/response.capnp create mode 100644 smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp create mode 100644 smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/response/response_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/response/response_capnp.pyi 
create mode 100644 smartsim/_core/mli/mli_schemas/tensor/tensor.capnp create mode 100644 smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi create mode 100644 smartsim/_core/utils/timings.py create mode 100644 tests/dragon/__init__.py create mode 100644 tests/dragon/channel.py create mode 100644 tests/dragon/conftest.py create mode 100644 tests/dragon/feature_store.py create mode 100644 tests/dragon/test_core_machine_learning_worker.py create mode 100644 tests/dragon/test_device_manager.py create mode 100644 tests/dragon/test_dragon_backend.py create mode 100644 tests/dragon/test_dragon_ddict_utils.py create mode 100644 tests/dragon/test_environment_loader.py create mode 100644 tests/dragon/test_error_handling.py create mode 100644 tests/dragon/test_event_consumer.py create mode 100644 tests/dragon/test_featurestore.py create mode 100644 tests/dragon/test_featurestore_base.py create mode 100644 tests/dragon/test_featurestore_integration.py create mode 100644 tests/dragon/test_inference_reply.py create mode 100644 tests/dragon/test_inference_request.py create mode 100644 tests/dragon/test_protoclient.py create mode 100644 tests/dragon/test_reply_building.py create mode 100644 tests/dragon/test_request_dispatcher.py create mode 100644 tests/dragon/test_torch_worker.py create mode 100644 tests/dragon/test_worker_manager.py create mode 100644 tests/dragon/utils/__init__.py create mode 100644 tests/dragon/utils/channel.py create mode 100644 tests/dragon/utils/msg_pump.py create mode 100644 tests/dragon/utils/worker.py create mode 100644 tests/mli/__init__.py create mode 100644 tests/mli/channel.py create mode 100644 tests/mli/feature_store.py create mode 100644 tests/mli/test_integrated_torch_worker.py create mode 100644 tests/mli/test_service.py create mode 100644 tests/mli/worker.py create mode 100644 tests/test_dragon_comm_utils.py create mode 100644 tests/test_message_handler/__init__.py create mode 100644 tests/test_message_handler/test_build_model.py create mode 100644 tests/test_message_handler/test_build_model_key.py create mode 100644 tests/test_message_handler/test_build_request_attributes.py create mode 100644 tests/test_message_handler/test_build_tensor_desc.py create mode 100644 tests/test_message_handler/test_build_tensor_key.py create mode 100644 tests/test_message_handler/test_output_descriptor.py create mode 100644 tests/test_message_handler/test_request.py create mode 100644 tests/test_message_handler/test_response.py create mode 100644 tests/test_node_prioritizer.py diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index e3c808410b..9b988520a4 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -54,7 +54,7 @@ jobs: strategy: fail-fast: false matrix: - subset: [backends, slow_tests, group_a, group_b] + subset: [backends, slow_tests, group_a, group_b, dragon] os: [macos-12, macos-14, ubuntu-22.04] # Operating systems compiler: [8] # GNU compiler version rai: [1.2.7] # Redis AI versions @@ -109,8 +109,24 @@ jobs: python -m pip install .[dev,mypy] - name: Install ML Runtimes + if: matrix.subset != 'dragon' run: smart build --device cpu -v + + - name: Install ML Runtimes (with dragon) + if: matrix.subset == 'dragon' + env: + SMARTSIM_DRAGON_TOKEN: ${{ secrets.DRAGON_TOKEN }} + run: | + if [ -n "${SMARTSIM_DRAGON_TOKEN}" ]; then + smart build --device cpu -v --dragon-repo dragonhpc/dragon-nightly --dragon-version 0.10 + else + smart build 
--device cpu -v --dragon + fi + SP=$(python -c 'import site; print(site.getsitepackages()[0])')/smartsim/_core/config/dragon/.env + LLP=$(cat $SP | grep LD_LIBRARY_PATH | awk '{split($0, array, "="); print array[2]}') + echo "LD_LIBRARY_PATH=$LLP:$LD_LIBRARY_PATH" >> $GITHUB_ENV + - name: Run mypy run: | make check-mypy @@ -134,9 +150,16 @@ jobs: echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ ./tests/backends + # Run pytest (dragon subtests) + - name: Run Dragon Pytest + if: (matrix.subset == 'dragon' && matrix.os == 'ubuntu-22.04') + run: | + echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV + dragon -s py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ -m ${{ matrix.subset }} ./tests + # Run pytest (test subsets) - name: Run Pytest - if: "!contains(matrix.subset, 'backends')" # if not running backend tests + if: (matrix.subset != 'backends' && matrix.subset != 'dragon') # if not running backend tests or dragon tests run: | echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ -m ${{ matrix.subset }} ./tests diff --git a/Makefile b/Makefile index 457bb040ac..4e64033d63 100644 --- a/Makefile +++ b/Makefile @@ -164,22 +164,22 @@ tutorials-prod: # help: test - Run all tests .PHONY: test test: - @python -m pytest --ignore=tests/full_wlm/ + @python -m pytest --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-verbose - Run all tests verbosely .PHONY: test-verbose test-verbose: - @python -m pytest -vv --ignore=tests/full_wlm/ + @python -m pytest -vv --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-debug - Run all tests with debug output .PHONY: test-debug test-debug: - @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ + @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-cov - Run all tests with coverage .PHONY: test-cov test-cov: - @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ + @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-full - Run all WLM tests with Python coverage (full test suite) @@ -192,3 +192,8 @@ test-full: .PHONY: test-wlm test-wlm: @python -m pytest -vv tests/full_wlm/ tests/on_wlm + +# help: test-dragon - Run dragon-specific tests +.PHONY: test-dragon +test-dragon: + @dragon pytest tests/dragon diff --git a/conftest.py b/conftest.py index 991c0d17b6..54a47f9e23 100644 --- a/conftest.py +++ b/conftest.py @@ -93,6 +93,7 @@ test_hostlist = None has_aprun = shutil.which("aprun") is not None + def get_account() -> str: return test_account @@ -227,7 +228,6 @@ def kill_all_test_spawned_processes() -> None: print("Not all processes were killed after test") - def get_hostlist() -> t.Optional[t.List[str]]: global test_hostlist if not test_hostlist: diff --git a/doc/changelog.md b/doc/changelog.md index 8f93a1ae2c..bca9209f7a 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -9,6 +9,39 @@ Jump to: ## SmartSim +### MLI branch + +Description + +- Implement asynchronous notifications for shared data +- 
Quick bug fix in _validate +- Add helper methods to MLI classes +- Update error handling for consistency +- Parameterize installation of dragon package with `smart build` +- Update docstrings +- Filenames conform to snake case +- Update SmartSim environment variables using new naming convention +- Refactor `exception_handler` +- Add RequestDispatcher and the possibility of batching inference requests +- Enable hostname selection for dragon tasks +- Remove pydantic dependency from MLI code +- Update MLI environment variables using new naming convention +- Reduce a copy by using torch.from_numpy instead of torch.tensor +- Enable dynamic feature store selection +- Fix dragon package installation bug +- Adjust schemas for better performance +- Add TorchWorker first implementation and mock inference app example +- Add error handling in Worker Manager pipeline +- Add EnvironmentConfigLoader for ML Worker Manager +- Add Model schema with model metadata included +- Removed device from schemas, MessageHandler and tests +- Add ML worker manager, sample worker, and feature store +- Add schemas and MessageHandler class for de/serialization of + inference requests and response messages + + +### Develop + To be released at some point in the future Description diff --git a/doc/installation_instructions/basic.rst b/doc/installation_instructions/basic.rst index 226ccb0854..73fbceb253 100644 --- a/doc/installation_instructions/basic.rst +++ b/doc/installation_instructions/basic.rst @@ -305,6 +305,20 @@ For example, to install dragon alongside the RedisAI CPU backends, you can run smart build --device cpu --dragon # install Dragon, PT and TF for cpu +``smart build`` supports installing a specific version of dragon. It exposes the +parameters ``--dragon-repo`` and ``--dragon-version``, which can be used alone or +in combination to customize the Dragon installation. For example: + +.. code-block:: bash + + # using the --dragon-repo and --dragon-version flags to customize the Dragon installation + smart build --device cpu --dragon-repo userfork/dragon # install Dragon from a specific repo + smart build --device cpu --dragon-version 0.10 # install a specific Dragon release + + # combining both flags + smart build --device cpu --dragon-repo userfork/dragon --dragon-version 0.91 + + .. note:: Dragon is only supported on Linux systems. For further information, you can read :ref:`the dedicated documentation page `. 
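For context, these flags map onto the new ``DragonInstallRequest`` added by this patch in ``smartsim/_core/_cli/scripts/dragon_install.py``. A minimal sketch of the equivalent programmatic call; the working directory and fork name below are illustrative, not defaults:

```python
import pathlib

from smartsim._core._cli.scripts.dragon_install import (
    DragonInstallRequest,
    install_dragon,
)

# illustrative scratch directory for downloaded release assets
working_dir = pathlib.Path("/tmp/smartsim-dragon")

# mirrors `smart build --dragon-repo userfork/dragon --dragon-version 0.10`;
# a non-default repo requires the SMARTSIM_DRAGON_TOKEN env var to be set,
# and a malformed version such as "v0.10" raises ValueError
request = DragonInstallRequest(working_dir, "userfork/dragon", "0.10")

# returns 0 on success and a non-zero code (1 or 2) on failure
return_code = install_dragon(request)
```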
diff --git a/ex/high_throughput_inference/mli_driver.py b/ex/high_throughput_inference/mli_driver.py new file mode 100644 index 0000000000..36f427937c --- /dev/null +++ b/ex/high_throughput_inference/mli_driver.py @@ -0,0 +1,77 @@ +import os +import base64 +import cloudpickle +import sys +from smartsim import Experiment +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim.status import TERMINAL_STATUSES +from smartsim.settings import DragonRunSettings +import time +import typing as t + +DEVICE = "gpu" +NUM_RANKS = 4 +NUM_WORKERS = 1 +filedir = os.path.dirname(__file__) +worker_manager_script_name = os.path.join(filedir, "standalone_worker_manager.py") +app_script_name = os.path.join(filedir, "mock_app.py") +model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt") + +transport: t.Literal["hsta", "tcp"] = "hsta" + +os.environ["SMARTSIM_DRAGON_TRANSPORT"] = transport + +exp_path = os.path.join(filedir, f"MLI_proto_{transport.upper()}") +os.makedirs(exp_path, exist_ok=True) +exp = Experiment("MLI_proto", launcher="dragon", exp_path=exp_path) + +torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii") + +worker_manager_rs: DragonRunSettings = exp.create_run_settings( + sys.executable, + [ + worker_manager_script_name, + "--device", + DEVICE, + "--worker_class", + torch_worker_str, + "--batch_size", + str(NUM_RANKS//NUM_WORKERS), + "--batch_timeout", + str(0.00), + "--num_workers", + str(NUM_WORKERS) + ], +) + +aff = [] + +worker_manager_rs.set_cpu_affinity(aff) + +worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs) +worker_manager.attach_generator_files(to_copy=[worker_manager_script_name]) + +app_rs: DragonRunSettings = exp.create_run_settings( + sys.executable, + exe_args=[app_script_name, "--device", DEVICE, "--log_max_batchsize", str(6)], +) +app_rs.set_tasks_per_node(NUM_RANKS) + + +app = exp.create_model("app", run_settings=app_rs) +app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) + +exp.generate(worker_manager, app, overwrite=True) +exp.start(worker_manager, app, block=False) + +while True: + if exp.get_status(app)[0] in TERMINAL_STATUSES: + time.sleep(10) + exp.stop(worker_manager) + break + if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES: + time.sleep(10) + exp.stop(app) + break + +print("Exiting.") diff --git a/ex/high_throughput_inference/mock_app.py b/ex/high_throughput_inference/mock_app.py new file mode 100644 index 0000000000..c3b3eaaf4c --- /dev/null +++ b/ex/high_throughput_inference/mock_app.py @@ -0,0 +1,142 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# isort: off
+import dragon
+from dragon import fli
+from dragon.channels import Channel
+import dragon.channels
+from dragon.data.ddict.ddict import DDict
+from dragon.globalservices.api_setup import connect_to_infrastructure
+from dragon.utils import b64decode, b64encode
+
+# isort: on
+
+import argparse
+import io
+
+import torch
+
+from smartsim.log import get_logger
+
+torch.set_num_interop_threads(16)
+torch.set_num_threads(1)
+
+logger = get_logger("App")
+logger.info("Started app")
+
+from collections import OrderedDict
+
+from smartsim.log import get_logger, log_to_file
+from smartsim._core.mli.client.protoclient import ProtoClient
+
+logger = get_logger("App")
+
+
+CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
+
+
+class ResNetWrapper:
+    """Wrapper around a pre-trained ResNet model."""
+    def __init__(self, name: str, model: str):
+        """Initialize the instance.
+
+        :param name: The name to use for the model
+        :param model: The path to the pre-trained PyTorch model"""
+        self._model = torch.jit.load(model)
+        self._name = name
+        buffer = io.BytesIO()
+        scripted = torch.jit.trace(self._model, self.get_batch())
+        torch.jit.save(scripted, buffer)
+        self._serialized_model = buffer.getvalue()
+
+    def get_batch(self, batch_size: int = 32):
+        """Create a random batch of data with the correct dimensions to
+        invoke a ResNet model.
+
+        :param batch_size: The desired number of samples to produce
+        :returns: A PyTorch tensor"""
+        return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)
+
+    @property
+    def model(self) -> bytes:
+        """The content of a model file.
+
+        :returns: The model bytes"""
+        return self._serialized_model
+
+    @property
+    def name(self) -> str:
+        """The name applied to the model.
+ + :returns: The name""" + return self._name + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser("Mock application") + parser.add_argument("--device", default="cpu", type=str) + parser.add_argument("--log_max_batchsize", default=8, type=int) + args = parser.parse_args() + + resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt") + + client = ProtoClient(timing_on=True) + client.set_model(resnet.name, resnet.model) + + if CHECK_RESULTS_AND_MAKE_ALL_SLOWER: + # TODO: adapt to non-Nvidia devices + torch_device = args.device.replace("gpu", "cuda") + pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to( + torch_device + ) + + TOTAL_ITERATIONS = 100 + + for log2_bsize in range(args.log_max_batchsize + 1): + b_size: int = 2**log2_bsize + logger.info(f"Batch size: {b_size}") + for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)): + logger.info(f"Iteration: {iteration_number}") + sample_batch = resnet.get_batch(b_size) + remote_result = client.run_model(resnet.name, sample_batch) + logger.info(client.perf_timer.get_last("total_time")) + if CHECK_RESULTS_AND_MAKE_ALL_SLOWER: + local_res = pt_model(sample_batch.to(torch_device)) + err_norm = torch.linalg.vector_norm( + torch.flatten(remote_result).to(torch_device) + - torch.flatten(local_res), + ord=1, + ).cpu() + res_norm = torch.linalg.vector_norm(remote_result, ord=1).item() + local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item() + logger.info( + f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}" + ) + torch.cuda.synchronize() + + client.perf_timer.print_timings(to_file=True) diff --git a/ex/high_throughput_inference/mock_app_redis.py b/ex/high_throughput_inference/mock_app_redis.py new file mode 100644 index 0000000000..8978bcea23 --- /dev/null +++ b/ex/high_throughput_inference/mock_app_redis.py @@ -0,0 +1,90 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
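+# Review note: this file is the Redis/SmartRedis counterpart of mock_app.py;
+# it appears intended to push the same ResNet batches through a RedisAI
+# backend so its timings (PerfTimer prefix "redis{rank}_") can be compared
+# against the Dragon-based MLI pipeline.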
+
+import argparse
+import io
+import numpy
+import time
+import torch
+from mpi4py import MPI
+from smartsim.log import get_logger
+from smartsim._core.utils.timings import PerfTimer
+from smartredis import Client
+
+logger = get_logger("App")
+
+class ResNetWrapper():
+    def __init__(self, name: str, model: str):
+        self._model = torch.jit.load(model)
+        self._name = name
+        buffer = io.BytesIO()
+        scripted = torch.jit.trace(self._model, self.get_batch())
+        torch.jit.save(scripted, buffer)
+        self._serialized_model = buffer.getvalue()
+
+    def get_batch(self, batch_size: int=32):
+        return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)
+
+    @property
+    def model(self):
+        return self._serialized_model
+
+    @property
+    def name(self):
+        return self._name
+
+if __name__ == "__main__":
+
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+
+    parser = argparse.ArgumentParser("Mock application")
+    parser.add_argument("--device", default="cpu")
+    args = parser.parse_args()
+
+    resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt")
+
+    client = Client(cluster=False, address=None)
+    client.set_model(resnet.name, resnet.model, backend='TORCH', device=args.device.upper())
+
+    perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=True, prefix=f"redis{rank}_")
+
+    total_iterations = 100
+    timings=[]
+    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
+        logger.info(f"Batch size: {batch_size}")
+        for iteration_number in range(total_iterations + int(batch_size==1)):
+            perf_timer.start_timings("batch_size", batch_size)
+            logger.info(f"Iteration: {iteration_number}")
+            input_name = f"batch_{rank}"
+            output_name = f"result_{rank}"
+            client.put_tensor(name=input_name, data=resnet.get_batch(batch_size).numpy())
+            client.run_model(name=resnet.name, inputs=[input_name], outputs=[output_name])
+            result = client.get_tensor(name=output_name)
+            perf_timer.end_timings()
+
+
+    perf_timer.print_timings(True)
diff --git a/ex/high_throughput_inference/redis_driver.py b/ex/high_throughput_inference/redis_driver.py
new file mode 100644
index 0000000000..ff57725d40
--- /dev/null
+++ b/ex/high_throughput_inference/redis_driver.py
@@ -0,0 +1,66 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys +from smartsim import Experiment +from smartsim.status import TERMINAL_STATUSES +import time + +DEVICE = "gpu" +filedir = os.path.dirname(__file__) +app_script_name = os.path.join(filedir, "mock_app_redis.py") +model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt") + + +exp_path = os.path.join(filedir, "redis_ai_multi") +os.makedirs(exp_path, exist_ok=True) +exp = Experiment("redis_ai_multi", launcher="slurm", exp_path=exp_path) + +db = exp.create_database(interface="hsn0") + +app_rs = exp.create_run_settings( + sys.executable, exe_args = [app_script_name, "--device", DEVICE] + ) +app_rs.set_nodes(1) +app_rs.set_tasks(4) +app = exp.create_model("app", run_settings=app_rs) +app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) + +exp.generate(db, app, overwrite=True) + +exp.start(db, app, block=False) + +while True: + if exp.get_status(app)[0] in TERMINAL_STATUSES: + exp.stop(db) + break + if exp.get_status(db)[0] in TERMINAL_STATUSES: + exp.stop(app) + break + time.sleep(5) + +print("Exiting.") \ No newline at end of file diff --git a/ex/high_throughput_inference/standalone_worker_manager.py b/ex/high_throughput_inference/standalone_worker_manager.py new file mode 100644 index 0000000000..b4527bc5d2 --- /dev/null +++ b/ex/high_throughput_inference/standalone_worker_manager.py @@ -0,0 +1,218 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
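+# Review note: standalone entry point for the MLI pipeline. It creates the
+# worker FLI channel, publishes its descriptor (and the backbone's) through
+# environment variables, then launches a RequestDispatcher plus one or more
+# WorkerManager services as dragon processes with explicit CPU/GPU affinity.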
+ + +import dragon + +# pylint disable=import-error +import dragon.infrastructure.policy as dragon_policy +import dragon.infrastructure.process_desc as dragon_process_desc +import dragon.native.process as dragon_process +from dragon import fli +from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict +from dragon.globalservices.api_setup import connect_to_infrastructure +from dragon.managed_memory import MemoryPool +from dragon.utils import b64decode, b64encode + +# pylint enable=import-error + +# isort: off +# isort: on + +import argparse +import base64 +import multiprocessing as mp +import os +import socket +import time +import typing as t + +import cloudpickle + +from smartsim._core.entrypoints.service import Service +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.control.request_dispatcher import ( + RequestDispatcher, +) +from smartsim._core.mli.infrastructure.control.worker_manager import WorkerManager +from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim.log import get_logger + +logger = get_logger("Worker Manager Entry Point") + +mp.set_start_method("dragon") + +pid = os.getpid() +affinity = os.sched_getaffinity(pid) +logger.info(f"Entry point: {socket.gethostname()}, {affinity}") +logger.info(f"CPUS: {os.cpu_count()}") + + +def service_as_dragon_proc( + service: Service, cpu_affinity: list[int], gpu_affinity: list[int] +) -> dragon_process.Process: + + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + return dragon_process.Process( + target=service.execute, + args=[], + cwd=os.getcwd(), + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Worker Manager") + parser.add_argument( + "--device", + type=str, + default="gpu", + choices="gpu cpu".split(), + help="Device on which the inference takes place", + ) + parser.add_argument( + "--worker_class", + type=str, + required=True, + help="Serialized class of worker to run", + ) + parser.add_argument( + "--num_workers", type=int, default=1, help="Number of workers to run" + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="How many requests the workers will try to aggregate before processing them", + ) + parser.add_argument( + "--batch_timeout", + type=float, + default=0.001, + help="How much time (in seconds) should be waited before processing an incomplete aggregated request", + ) + args = parser.parse_args() + + connect_to_infrastructure() + ddict_str = os.environ[BackboneFeatureStore.MLI_BACKBONE] + + backbone = BackboneFeatureStore.from_descriptor(ddict_str) + + to_worker_channel = create_local() + to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli) + + backbone.worker_queue = 
to_worker_fli_comm_ch.descriptor + + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor + + arg_worker_type = cloudpickle.loads( + base64.b64decode(args.worker_class.encode("ascii")) + ) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + dispatcher = RequestDispatcher( + batch_timeout=args.batch_timeout, + batch_size=args.batch_size, + config_loader=config_loader, + worker_type=arg_worker_type, + ) + + wms = [] + worker_device = args.device + for wm_idx in range(args.num_workers): + + worker_manager = WorkerManager( + config_loader=config_loader, + worker_type=arg_worker_type, + as_service=True, + cooldown=10, + device=worker_device, + dispatcher_queue=dispatcher.task_queue, + ) + + wms.append(worker_manager) + + wm_affinity: list[int] = [] + disp_affinity: list[int] = [] + + # This is hardcoded for a specific type of node: + # the GPU-to-CPU mapping is taken from the nvidia-smi tool + # TODO can this be computed on the fly? + gpu_to_cpu_aff: dict[int, list[int]] = {} + gpu_to_cpu_aff[0] = list(range(48, 64)) + list(range(112, 128)) + gpu_to_cpu_aff[1] = list(range(32, 48)) + list(range(96, 112)) + gpu_to_cpu_aff[2] = list(range(16, 32)) + list(range(80, 96)) + gpu_to_cpu_aff[3] = list(range(0, 16)) + list(range(64, 80)) + + worker_manager_procs = [] + for worker_idx in range(args.num_workers): + wm_cpus = len(gpu_to_cpu_aff[worker_idx]) - 4 + wm_affinity = gpu_to_cpu_aff[worker_idx][:wm_cpus] + disp_affinity.extend(gpu_to_cpu_aff[worker_idx][wm_cpus:]) + worker_manager_procs.append( + service_as_dragon_proc( + worker_manager, cpu_affinity=wm_affinity, gpu_affinity=[worker_idx] + ) + ) + + dispatcher_proc = service_as_dragon_proc( + dispatcher, cpu_affinity=disp_affinity, gpu_affinity=[] + ) + + # TODO: use ProcessGroup and restart=True? 
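+    # Review note: the block below starts the dispatcher and worker managers
+    # together; the monitoring loop is intended to exit once any one process
+    # stops, letting the experiment driver handle teardown of the rest.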
+ all_procs = [dispatcher_proc, *worker_manager_procs] + + print(f"Dispatcher proc: {dispatcher_proc}") + for proc in all_procs: + proc.start() + + while all(proc.is_alive for proc in all_procs): + time.sleep(1) diff --git a/pyproject.toml b/pyproject.toml index 62df92f0c9..61e17891b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ markers = [ "group_a: fast test subset a", "group_b: fast test subset b", "slow_tests: tests that take a long duration to complete", + "dragon: tests that must be executed in a dragon runtime", ] [tool.isort] diff --git a/setup.py b/setup.py index 571974d284..cd5ace55db 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ class BuildError(Exception): pass + # Define needed dependencies for the installation extras_require = { @@ -176,6 +177,7 @@ class BuildError(Exception): "GitPython<=3.1.43", "protobuf<=3.20.3", "jinja2>=3.1.2", + "pycapnp==2.0.0", "watchdog>4,<5", "pydantic>2", "pyzmq>=25.1.2", diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 5d094b72f4..ec9ef4aa29 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -36,7 +36,13 @@ from tabulate import tabulate -from smartsim._core._cli.scripts.dragon_install import install_dragon +from smartsim._core._cli.scripts.dragon_install import ( + DEFAULT_DRAGON_REPO, + DEFAULT_DRAGON_VERSION, + DragonInstallRequest, + display_post_install_logs, + install_dragon, +) from smartsim._core._cli.utils import SMART_LOGGER_FORMAT from smartsim._core._install import builder from smartsim._core._install.buildenv import BuildEnv, DbEngine, Version_, Versioner @@ -67,22 +73,22 @@ def check_backends_install() -> bool: """Checks if backends have already been installed. Logs details on how to proceed forward - if the RAI_PATH environment variable is set or if + if the SMARTSIM_RAI_LIB environment variable is set or if backends have already been installed. """ - rai_path = os.environ.get("RAI_PATH", "") + rai_path = os.environ.get("SMARTSIM_RAI_LIB", "") installed = installed_redisai_backends() msg = "" if rai_path and installed: msg = ( f"There is no need to build. backends are already built and " - f"specified in the environment at 'RAI_PATH': {CONFIG.redisai}" + f"specified in the environment at 'SMARTSIM_RAI_LIB': {CONFIG.redisai}" ) elif rai_path and not installed: msg = ( - "Before running 'smart build', unset your RAI_PATH environment " - "variable with 'unset RAI_PATH'." + "Before running 'smart build', unset your SMARTSIM_RAI_LIB environment " + "variable with 'unset SMARTSIM_RAI_LIB'." 
) elif not rai_path and installed: msg = ( @@ -231,7 +237,7 @@ def _configure_keydb_build(versions: Versioner) -> None: CONFIG.conf_path = Path(CONFIG.core_path, "config", "keydb.conf") if not CONFIG.conf_path.resolve().is_file(): raise SSConfigError( - "Database configuration file at REDIS_CONF could not be found" + "Database configuration file at SMARTSIM_REDIS_CONF could not be found" ) @@ -245,6 +251,8 @@ def execute( keydb = args.keydb device = Device.from_str(args.device.lower()) is_dragon_requested = args.dragon + dragon_repo = args.dragon_repo + dragon_version = args.dragon_version if Path(CONFIG.build_path).exists(): logger.warning(f"Build path already exists, removing: {CONFIG.build_path}") @@ -294,12 +302,23 @@ def execute( logger.info("ML Packages") print(mlpackages) - if is_dragon_requested: + if is_dragon_requested or dragon_repo or dragon_version: install_to = CONFIG.core_path / ".dragon" - return_code = install_dragon(install_to) + + try: + request = DragonInstallRequest( + install_to, + dragon_repo, + dragon_version, + ) + return_code = install_dragon(request) + except ValueError as ex: + return_code = 2 + logger.error(" ".join(ex.args)) if return_code == 0: - logger.info("Dragon installation complete") + display_post_install_logs() + elif return_code == 1: logger.info("Dragon installation not supported on platform") else: @@ -358,6 +377,21 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: default=False, help="Install the dragon runtime", ) + parser.add_argument( + "--dragon-repo", + default=None, + type=str, + help=( + "Specify a git repo containing dragon release assets " + f"(e.g. {DEFAULT_DRAGON_REPO})" + ), + ) + parser.add_argument( + "--dragon-version", + default=None, + type=str, + help=f"Specify the dragon version to install (e.g. {DEFAULT_DRAGON_VERSION})", + ) parser.add_argument( "--skip-python-packages", action="store_true", diff --git a/smartsim/_core/_cli/scripts/dragon_install.py b/smartsim/_core/_cli/scripts/dragon_install.py index 8028b8ecfd..3a9358390b 100644 --- a/smartsim/_core/_cli/scripts/dragon_install.py +++ b/smartsim/_core/_cli/scripts/dragon_install.py @@ -1,10 +1,16 @@ import os import pathlib +import re +import shutil import sys import typing as t +from urllib.request import Request, urlopen from github import Github +from github.Auth import Token +from github.GitRelease import GitRelease from github.GitReleaseAsset import GitReleaseAsset +from github.Repository import Repository from smartsim._core._cli.utils import pip from smartsim._core._install.utils import retrieve @@ -15,20 +21,90 @@ logger = get_logger(__name__) +DEFAULT_DRAGON_REPO = "DragonHPC/dragon" +DEFAULT_DRAGON_VERSION = "0.9" +DEFAULT_DRAGON_VERSION_TAG = f"v{DEFAULT_DRAGON_VERSION}" +_GH_TOKEN = "SMARTSIM_DRAGON_TOKEN" -def create_dotenv(dragon_root_dir: pathlib.Path) -> None: + +class DragonInstallRequest: + """Encapsulates a request to install the dragon package""" + + def __init__( + self, + working_dir: pathlib.Path, + repo_name: t.Optional[str] = None, + version: t.Optional[str] = None, + ) -> None: + """Initialize an install request. + + :param working_dir: A path to store temporary files used during installation + :param repo_name: The name of a repository to install from, e.g. DragonHPC/dragon + :param version: The version to install, e.g. 
0.10
+        """
+
+        self.working_dir = working_dir
+        """A path to store temporary files used during installation"""
+
+        self.repo_name = repo_name or DEFAULT_DRAGON_REPO
+        """The name of a repository to install from, e.g. DragonHPC/dragon"""
+
+        self.pkg_version = version or DEFAULT_DRAGON_VERSION
+        """The version to install, e.g. 0.10"""
+
+        self._check()
+
+    def _check(self) -> None:
+        """Perform validation of this instance
+
+        :raises ValueError: if any value fails validation"""
+        if not self.repo_name or len(self.repo_name.split("/")) != 2:
+            raise ValueError(
+                "Invalid dragon repository name. Example: `dragonhpc/dragon`"
+            )
+
+        # version must match the dragon release version format `X.YZ` (no `v` prefix)
+        match = re.match(r"^\d\.\d+$", self.pkg_version)
+        if not self.pkg_version or not match:
+            raise ValueError("Invalid dragon version. Examples: `0.9, 0.91, 0.10`")
+
+        # attempting to retrieve from a non-default repository requires an auth token
+        if self.repo_name.lower() != DEFAULT_DRAGON_REPO.lower() and not self.raw_token:
+            raise ValueError(
+                f"An access token must be available to access {self.repo_name}. "
+                f"Set the `{_GH_TOKEN}` env var to pass your access token."
+            )
+
+    @property
+    def raw_token(self) -> t.Optional[str]:
+        """Returns the raw access token from the environment, if available"""
+        return os.environ.get(_GH_TOKEN, None)
+
+
+def get_auth_token(request: DragonInstallRequest) -> t.Optional[Token]:
+    """Create a Github.Auth.Token if an access token can be found
+    in the environment
+
+    :param request: details of a request for the installation of the dragon package
+    :returns: an auth token if one can be built, otherwise `None`"""
+    if gh_token := request.raw_token:
+        return Token(gh_token)
+    return None
+
+
+def create_dotenv(dragon_root_dir: pathlib.Path, dragon_version: str) -> None:
     """Create a .env file with required environment variables for the Dragon runtime"""
     dragon_root = str(dragon_root_dir)
-    dragon_inc_dir = str(dragon_root_dir / "include")
-    dragon_lib_dir = str(dragon_root_dir / "lib")
-    dragon_bin_dir = str(dragon_root_dir / "bin")
+    dragon_inc_dir = dragon_root + "/include"
+    dragon_lib_dir = dragon_root + "/lib"
+    dragon_bin_dir = dragon_root + "/bin"
 
     dragon_vars = {
         "DRAGON_BASE_DIR": dragon_root,
-        "DRAGON_ROOT_DIR": dragon_root,  # note: same as base_dir
+        "DRAGON_ROOT_DIR": dragon_root,
         "DRAGON_INCLUDE_DIR": dragon_inc_dir,
         "DRAGON_LIB_DIR": dragon_lib_dir,
-        "DRAGON_VERSION": dragon_pin(),
+        "DRAGON_VERSION": dragon_version,
         "PATH": dragon_bin_dir,
         "LD_LIBRARY_PATH": dragon_lib_dir,
     }
@@ -48,12 +124,6 @@ def python_version() -> str:
     return f"py{sys.version_info.major}.{sys.version_info.minor}"
 
 
-def dragon_pin() -> str:
-    """Return a string indicating the pinned major/minor version of the dragon
-    package to install"""
-    return "0.9"
-
-
 def _platform_filter(asset_name: str) -> bool:
     """Return True if the asset name matches naming standard for current
     platform (Cray or non-Cray). Otherwise, returns False.
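Because release tags are ``v``-prefixed while the request expects the bare number, the version check above is easy to trip over. A standalone sketch of the pattern used in ``DragonInstallRequest._check``:

```python
import re

# same pattern as DragonInstallRequest._check
version_pattern = re.compile(r"^\d\.\d+$")

for candidate in ("0.9", "0.91", "0.10", "v0.10", "0.9.1"):
    print(f"{candidate!r}: matches={bool(version_pattern.match(candidate))}")

# only the bare major.minor forms ('0.9', '0.91', '0.10') match; tag-style
# 'v0.10' and patch-style '0.9.1' trigger the "Invalid dragon version" error
```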
@@ -75,67 +145,125 @@
     return python_version() in asset_name
 
 
-def _pin_filter(asset_name: str) -> bool:
+def _pin_filter(asset_name: str, dragon_version: str) -> bool:
     """Return true if the supplied value contains a dragon version pin match
 
-    :param asset_name: A value to inspect for keywords indicating a dragon version
+    :param asset_name: the asset name to inspect for keywords indicating a dragon version
+    :param dragon_version: the dragon version to match
     :returns: True if supplied value is correct for current dragon version"""
-    return f"dragon-{dragon_pin()}" in asset_name
+    return f"dragon-{dragon_version}" in asset_name
+
+
+def _get_all_releases(dragon_repo: Repository) -> t.Collection[GitRelease]:
+    """Retrieve all available releases for the configured dragon repository
+
+    :param dragon_repo: A GitHub repository object for the dragon package
+    :returns: A list of GitRelease"""
+    all_releases = [release for release in list(dragon_repo.get_releases())]
+    return all_releases
 
 
-def _get_release_assets() -> t.Collection[GitReleaseAsset]:
+def _get_release_assets(request: DragonInstallRequest) -> t.Collection[GitReleaseAsset]:
     """Retrieve a collection of available assets for all releases that satisfy
     the dragon version pin
 
+    :param request: details of a request for the installation of the dragon package
     :returns: A collection of release assets"""
-    git = Github()
-
-    dragon_repo = git.get_repo("DragonHPC/dragon")
+    auth = get_auth_token(request)
+    git = Github(auth=auth)
+    dragon_repo = git.get_repo(request.repo_name)
 
     if dragon_repo is None:
         raise SmartSimCLIActionCancelled("Unable to locate dragon repo")
 
-    # find any releases matching our pinned version requirement
-    tags = [tag for tag in dragon_repo.get_tags() if dragon_pin() in tag.name]
-    # repo.get_latest_release fails if only pre-release results are returned
-    pin_releases = list(dragon_repo.get_release(tag.name) for tag in tags)
-    releases = sorted(pin_releases, key=lambda r: r.published_at, reverse=True)
+    all_releases = sorted(
+        _get_all_releases(dragon_repo), key=lambda r: r.published_at, reverse=True
+    )
 
-    # take the most recent release for the given pin
-    assets = releases[0].assets
+    # filter the list of releases to include only the target version
+    releases = [
+        release
+        for release in all_releases
+        if request.pkg_version in release.title
+        or request.pkg_version in release.tag_name
+    ]
+
+    releases = sorted(releases, key=lambda r: r.published_at, reverse=True)
+
+    if not releases:
+        release_titles = ", ".join(release.title for release in all_releases)
+        raise SmartSimCLIActionCancelled(
+            f"Unable to find a release for dragon version {request.pkg_version}. "
+            f"Available releases: {release_titles}"
+        )
+
+    assets: t.List[GitReleaseAsset] = []
+
+    # install the latest release of the target version (including pre-release)
+    for release in releases:
+        # delay in attaching release assets may leave us with an empty list, retry
+        # with the next available release
+        if assets := list(release.get_assets()):
+            logger.debug(f"Found assets for dragon release {release.title}")
+            break
+        else:
+            logger.debug(f"No assets for dragon release {release.title}. 
Retrying.") + + if not assets: + raise SmartSimCLIActionCancelled( + f"Unable to find assets for dragon release {release.title}" + ) return assets -def filter_assets(assets: t.Collection[GitReleaseAsset]) -> t.Optional[GitReleaseAsset]: +def filter_assets( + request: DragonInstallRequest, assets: t.Collection[GitReleaseAsset] +) -> t.Optional[GitReleaseAsset]: """Filter the available release assets so that HSTA agents are used when run on a Cray EX platform + :param request: details of a request for the installation of the dragon package :param assets: The collection of dragon release assets to filter :returns: An asset meeting platform & version filtering requirements""" # Expect cray & non-cray assets that require a filter, e.g. # 'dragon-0.8-py3.9.4.1-bafaa887f.tar.gz', # 'dragon-0.8-py3.9.4.1-CRAYEX-ac132fe95.tar.gz' - asset = next( - ( - asset - for asset in assets - if _version_filter(asset.name) - and _platform_filter(asset.name) - and _pin_filter(asset.name) - ), - None, + all_assets = [asset.name for asset in assets] + + assets = list( + asset + for asset in assets + if _version_filter(asset.name) and _pin_filter(asset.name, request.pkg_version) ) + + if len(assets) == 0: + available = "\n\t".join(all_assets) + logger.warning( + f"Please specify a dragon version (e.g. {DEFAULT_DRAGON_VERSION}) " + f"of an asset available in the repository:\n\t{available}" + ) + return None + + asset: t.Optional[GitReleaseAsset] = None + + # Apply platform filter if we have multiple matches for python/dragon version + if len(assets) > 0: + asset = next((asset for asset in assets if _platform_filter(asset.name)), None) + + if not asset: + asset = assets[0] + logger.warning(f"Platform-specific package not found. Using {asset.name}") + return asset -def retrieve_asset_info() -> GitReleaseAsset: +def retrieve_asset_info(request: DragonInstallRequest) -> GitReleaseAsset: """Find a release asset that meets all necessary filtering criteria - :param dragon_pin: identify the dragon version to install (e.g. 
dragon-0.8)
+    :param request: details of a request for the installation of the dragon package
     :returns: A GitHub release asset"""
-    assets = _get_release_assets()
-    asset = filter_assets(assets)
+    assets = _get_release_assets(request)
+    asset = filter_assets(request, assets)
 
     platform_result = check_platform()
     if not platform_result.is_cray:
@@ -150,42 +278,79 @@
     return asset
 
 
-def retrieve_asset(working_dir: pathlib.Path, asset: GitReleaseAsset) -> pathlib.Path:
+def retrieve_asset(
+    request: DragonInstallRequest, asset: GitReleaseAsset
+) -> pathlib.Path:
     """Retrieve the physical file associated to a given GitHub release asset
 
-    :param working_dir: location in file system where assets should be written
+    :param request: details of a request for the installation of the dragon package
     :param asset: GitHub release asset to retrieve
-    :returns: path to the downloaded asset"""
-    if working_dir.exists() and list(working_dir.rglob("*.whl")):
-        return working_dir
+    :returns: path to the directory containing the extracted release asset
+    :raises SmartSimCLIActionCancelled: if the asset cannot be downloaded or extracted
+    """
+    download_dir = request.working_dir / str(asset.id)
+
+    # remove any leftover files from a previous download of this asset
+    # before downloading and extracting a fresh copy
+    cleanup(download_dir)
+    download_dir.mkdir(parents=True, exist_ok=True)
+
+    # grab a copy of the complete asset
+    asset_path = download_dir / str(asset.name)
+
+    # use the asset URL instead of the browser_download_url to enable
+    # using auth for private repositories
+    headers: t.Dict[str, str] = {"Accept": "application/octet-stream"}
+
+    if request.raw_token:
+        headers["Authorization"] = f"Bearer {request.raw_token}"
+
+    try:
+        # a github asset endpoint causes a redirect. 
the first request + # receives a pre-signed URL to the asset to pass on to retrieve + dl_request = Request(asset.url, headers=headers) + response = urlopen(dl_request) + presigned_url = response.url + + logger.debug(f"Retrieved asset {asset.name} metadata from {asset.url}") + except Exception: + logger.exception(f"Unable to download {asset.name} from: {asset.url}") + presigned_url = asset.url + + # extract the asset + try: + retrieve(presigned_url, asset_path) - retrieve(asset.browser_download_url, working_dir) + logger.debug(f"Extracted {asset.name} to {download_dir}") + except Exception as ex: + raise SmartSimCLIActionCancelled( + f"Unable to extract {asset.name} from {download_dir}" + ) from ex - logger.debug(f"Retrieved {asset.browser_download_url} to {working_dir}") - return working_dir + return download_dir -def install_package(asset_dir: pathlib.Path) -> int: +def install_package(request: DragonInstallRequest, asset_dir: pathlib.Path) -> int: """Install the package found in `asset_dir` into the current python environment - :param asset_dir: path to a decompressed archive contents for a release asset""" - wheels = asset_dir.rglob("*.whl") - wheel_path = next(wheels, None) - if not wheel_path: - logger.error(f"No wheel found for package in {asset_dir}") + :param request: details of a request for the installation of the dragon package + :param asset_dir: path to a decompressed archive contents for a release asset + :returns: Integer return code, 0 for success, non-zero on failures""" + found_wheels = list(asset_dir.rglob("*.whl")) + if not found_wheels: + logger.error(f"No wheel(s) found for package in {asset_dir}") return 1 - create_dotenv(wheel_path.parent) - - while wheel_path is not None: - logger.info(f"Installing package: {wheel_path.absolute()}") + create_dotenv(found_wheels[0].parent, request.pkg_version) - try: - pip("install", "--force-reinstall", str(wheel_path), "numpy<2") - wheel_path = next(wheels, None) - except Exception: - logger.error(f"Unable to install from {asset_dir}") - return 1 + try: + wheels = list(map(str, found_wheels)) + for wheel_path in wheels: + logger.info(f"Installing package: {wheel_path}") + pip("install", wheel_path) + except Exception: + logger.error(f"Unable to install from {asset_dir}") + return 1 return 0 @@ -196,36 +361,83 @@ def cleanup( """Delete the downloaded asset and any files extracted during installation :param archive_path: path to a downloaded archive for a release asset""" - if archive_path: - archive_path.unlink(missing_ok=True) - logger.debug(f"Deleted archive: {archive_path}") + if not archive_path: + return + + if archive_path.exists() and archive_path.is_file(): + archive_path.unlink() + archive_path = archive_path.parent + if archive_path.exists() and archive_path.is_dir(): + shutil.rmtree(archive_path, ignore_errors=True) + logger.debug(f"Deleted temporary files in: {archive_path}") -def install_dragon(extraction_dir: t.Union[str, os.PathLike[str]]) -> int: + +def install_dragon(request: DragonInstallRequest) -> int: """Retrieve a dragon runtime appropriate for the current platform and install to the current python environment - :param extraction_dir: path for download and extraction of assets + + :param request: details of a request for the installation of the dragon package :returns: Integer return code, 0 for success, non-zero on failures""" if sys.platform == "darwin": logger.debug(f"Dragon not supported on platform: {sys.platform}") return 1 - extraction_dir = pathlib.Path(extraction_dir) - filename: 
t.Optional[pathlib.Path] = None
    asset_dir: t.Optional[pathlib.Path] = None

    try:
-        asset_info = retrieve_asset_info()
-        asset_dir = retrieve_asset(extraction_dir, asset_info)
+        asset_info = retrieve_asset_info(request)
+        if asset_info is not None:
+            asset_dir = retrieve_asset(request, asset_info)
+            return install_package(request, asset_dir)
-        return install_package(asset_dir)
+    except SmartSimCLIActionCancelled as ex:
+        logger.warning(*ex.args)
    except Exception as ex:
-        logger.error("Unable to install dragon runtime", exc_info=ex)
-    finally:
-        cleanup(filename)
+        logger.error("Unable to install dragon runtime", exc_info=True)

    return 2

+def display_post_install_logs() -> None:
+    """Display post-installation instructions for the user"""
+
+    examples = {
+        "ofi-include": "/opt/cray/include",
+        "ofi-build-lib": "/opt/cray/lib64",
+        "ofi-runtime-lib": "/opt/cray/lib64",
+    }
+
+    config = ":".join(f"{k}={v}" for k, v in examples.items())
+    example_msg1 = "dragon-config -a \\"
+    example_msg2 = f' "{config}"'
+
+    logger.info(
+        "************************** Dragon Package Installed *****************************"
+    )
+    logger.info("To enable Dragon to use HSTA (default: TCP), configure the following:")
+
+    for key in examples:
+        logger.info(f"\t{key}")
+
+    logger.info("Example:")
+    logger.info(example_msg1)
+    logger.info(example_msg2)
+    logger.info(
+        "*********************************************************************************"
+    )
+
+
 if __name__ == "__main__":
-    sys.exit(install_dragon(CONFIG.core_path / ".dragon"))
+    # path for download and extraction of assets
+    extraction_dir = CONFIG.core_path / ".dragon"
+    dragon_repo = DEFAULT_DRAGON_REPO
+    dragon_version = DEFAULT_DRAGON_VERSION
+
+    request = DragonInstallRequest(
+        extraction_dir,
+        dragon_repo,
+        dragon_version,
+    )
+
+    sys.exit(install_dragon(request))
diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py
index 17036e825e..957f2b6ef6 100644
--- a/smartsim/_core/_install/builder.py
+++ b/smartsim/_core/_install/builder.py
@@ -246,7 +246,9 @@ def build_from_git(self, git_url: str, branch: str) -> None:
        bin_path = Path(dependency_path, "bin").resolve()
        try:
            database_exe = next(bin_path.glob("*-server"))
-            database = Path(os.environ.get("REDIS_PATH", database_exe)).resolve()
+            database = Path(
+                os.environ.get("SMARTSIM_REDIS_SERVER_EXE", database_exe)
+            ).resolve()
            _ = expand_exe_path(str(database))
        except (TypeError, FileNotFoundError) as e:
            raise BuildError("Installation of redis-server failed!") from e
@@ -254,7 +256,9 @@ def build_from_git(self, git_url: str, branch: str) -> None:
        # validate install -- redis-cli
        try:
            redis_cli_exe = next(bin_path.glob("*-cli"))
-            redis_cli = Path(os.environ.get("REDIS_CLI_PATH", redis_cli_exe)).resolve()
+            redis_cli = Path(
+                os.environ.get("SMARTSIM_REDIS_CLI_EXE", redis_cli_exe)
+            ).resolve()
            _ = expand_exe_path(str(redis_cli))
        except (TypeError, FileNotFoundError) as e:
            raise BuildError("Installation of redis-cli failed!") from e
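The `REDIS_*` to `SMARTSIM_*` renames here (and in `config.py` below) change only the variable names, not the lookup behavior. A quick way to confirm an override is picked up; the paths are placeholders, and expect an `SSConfigError` (or a lookup error) unless they point at real binaries:

    import os

    # hypothetical locations of a locally built Redis
    os.environ["SMARTSIM_REDIS_SERVER_EXE"] = "/opt/redis/bin/redis-server"
    os.environ["SMARTSIM_REDIS_CLI_EXE"] = "/opt/redis/bin/redis-cli"

    from smartsim._core.config.config import get_config

    cfg = get_config()
    print(cfg.database_exe)  # resolved via SMARTSIM_REDIS_SERVER_EXE
    print(cfg.database_cli)  # resolved via SMARTSIM_REDIS_CLI_EXE

diff --git a/smartsim/_core/config/config.py b/smartsim/_core/config/config.py
index 03c284edb3..c8b4ff17b9 100644
--- a/smartsim/_core/config/config.py
+++ b/smartsim/_core/config/config.py
@@ -40,19 +40,19 @@
 # These values can be set through environment variables to
 # override the default behavior of SmartSim.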
# -# RAI_PATH +# SMARTSIM_RAI_LIB # - Path to the RAI shared library # - Default: /smartsim/smartsim/_core/lib/redisai.so # -# REDIS_CONF +# SMARTSIM_REDIS_CONF # - Path to the redis.conf file # - Default: /SmartSim/smartsim/_core/config/redis.conf # -# REDIS_PATH +# SMARTSIM_REDIS_SERVER_EXE # - Path to the redis-server executable # - Default: /SmartSim/smartsim/_core/bin/redis-server # -# REDIS_CLI_PATH +# SMARTSIM_REDIS_CLI_EXE # - Path to the redis-cli executable # - Default: /SmartSim/smartsim/_core/bin/redis-cli # @@ -120,20 +120,20 @@ def build_path(self) -> Path: @property def redisai(self) -> str: rai_path = self.lib_path / "redisai.so" - redisai = Path(os.environ.get("RAI_PATH", rai_path)).resolve() + redisai = Path(os.environ.get("SMARTSIM_RAI_LIB", rai_path)).resolve() if not redisai.is_file(): raise SSConfigError( "RedisAI dependency not found. Build with `smart` cli " - "or specify RAI_PATH" + "or specify SMARTSIM_RAI_LIB" ) return str(redisai) @property def database_conf(self) -> str: - conf = Path(os.environ.get("REDIS_CONF", self.conf_path)).resolve() + conf = Path(os.environ.get("SMARTSIM_REDIS_CONF", self.conf_path)).resolve() if not conf.is_file(): raise SSConfigError( - "Database configuration file at REDIS_CONF could not be found" + "Database configuration file at SMARTSIM_REDIS_CONF could not be found" ) return str(conf) @@ -141,24 +141,29 @@ def database_conf(self) -> str: def database_exe(self) -> str: try: database_exe = next(self.bin_path.glob("*-server")) - database = Path(os.environ.get("REDIS_PATH", database_exe)).resolve() + database = Path( + os.environ.get("SMARTSIM_REDIS_SERVER_EXE", database_exe) + ).resolve() exe = expand_exe_path(str(database)) return exe except (TypeError, FileNotFoundError) as e: raise SSConfigError( - "Specified database binary at REDIS_PATH could not be used" + "Specified database binary at SMARTSIM_REDIS_SERVER_EXE " + "could not be used" ) from e @property def database_cli(self) -> str: try: redis_cli_exe = next(self.bin_path.glob("*-cli")) - redis_cli = Path(os.environ.get("REDIS_CLI_PATH", redis_cli_exe)).resolve() + redis_cli = Path( + os.environ.get("SMARTSIM_REDIS_CLI_EXE", redis_cli_exe) + ).resolve() exe = expand_exe_path(str(redis_cli)) return exe except (TypeError, FileNotFoundError) as e: raise SSConfigError( - "Specified Redis binary at REDIS_CLI_PATH could not be used" + "Specified Redis binary at SMARTSIM_REDIS_CLI_EXE could not be used" ) from e @property @@ -178,7 +183,7 @@ def dragon_dotenv(self) -> Path: def dragon_server_path(self) -> t.Optional[str]: return os.getenv( "SMARTSIM_DRAGON_SERVER_PATH", - os.getenv("SMARTSIM_DRAGON_SERVER_PATH_EXP", None), + os.getenv("_SMARTSIM_DRAGON_SERVER_PATH_EXP", None), ) @property @@ -306,10 +311,6 @@ def smartsim_key_path(self) -> str: default_path = Path.home() / ".smartsim" / "keys" return os.environ.get("SMARTSIM_KEY_PATH", str(default_path)) - @property - def dragon_pin(self) -> str: - return "0.9" - @lru_cache(maxsize=128, typed=False) def get_config() -> Config: diff --git a/smartsim/_core/entrypoints/service.py b/smartsim/_core/entrypoints/service.py new file mode 100644 index 0000000000..719c2a60fe --- /dev/null +++ b/smartsim/_core/entrypoints/service.py @@ -0,0 +1,185 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import datetime +import time +import typing as t +from abc import ABC, abstractmethod + +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class Service(ABC): + """Core API for standalone entrypoint scripts. Makes use of overridable hook + methods to modify behaviors (event loop, automatic shutdown, cooldown) as + well as simple hooks for status changes""" + + def __init__( + self, + as_service: bool = False, + cooldown: float = 0, + loop_delay: float = 0, + health_check_frequency: float = 0, + ) -> None: + """Initialize the Service + + :param as_service: Determines lifetime of the service. When `True`, calling + execute on the service will run continuously until shutdown criteria are met. + Otherwise, `execute` performs a single pass through the service lifecycle and + automatically exits (regardless of the result of `_can_shutdown`). + :param cooldown: Period of time (in seconds) to allow the service to run + after a shutdown is permitted. Enables the service to avoid restarting if + new work is discovered. A value of 0 disables the cooldown. + :param loop_delay: Duration (in seconds) of a forced delay between + iterations of the event loop + :param health_check_frequency: Time (in seconds) between calls to a + health check handler. A value of 0 triggers the health check on every + iteration. + """ + self._as_service = as_service + """Determines lifetime of the service. When `True`, calling + `execute` on the service will run continuously until shutdown criteria are met. + Otherwise, `execute` performs a single pass through the service lifecycle and + automatically exits (regardless of the result of `_can_shutdown`).""" + self._cooldown = abs(cooldown) + """Period of time (in seconds) to allow the service to run + after a shutdown is permitted. Enables the service to avoid restarting if + new work is discovered. A value of 0 disables the cooldown.""" + self._loop_delay = abs(loop_delay) + """Duration (in seconds) of a forced delay between + iterations of the event loop""" + self._health_check_frequency = health_check_frequency + """Time (in seconds) between calls to a + health check handler. A value of 0 triggers the health check on every + iteration.""" + self._last_health_check = time.time() + """The timestamp of the latest health check""" + + @abstractmethod + def _on_iteration(self) -> None: + """The user-defined event handler. 
Executed repeatedly until shutdown
+        conditions are satisfied and cooldown is elapsed.
+        """
+
+    @abstractmethod
+    def _can_shutdown(self) -> bool:
+        """Return true when the criteria to shut down the service are met."""
+
+    def _on_start(self) -> None:
+        """Empty hook method for use by subclasses. Called on initial entry into
+        Service `execute` event loop before `_on_iteration` is invoked."""
+        logger.debug(f"Starting {self.__class__.__name__}")
+
+    def _on_shutdown(self) -> None:
+        """Empty hook method for use by subclasses. Called immediately after exiting
+        the main event loop during automatic shutdown."""
+        logger.debug(f"Shutting down {self.__class__.__name__}")
+
+    def _on_health_check(self) -> None:
+        """Empty hook method for use by subclasses. Invoked based on the
+        value of `self._health_check_frequency`."""
+        logger.debug(f"Performing health check for {self.__class__.__name__}")
+
+    def _on_cooldown_elapsed(self) -> None:
+        """Empty hook method for use by subclasses. Called on every event loop
+        iteration immediately upon exceeding the cooldown period"""
+        logger.debug(f"Cooldown exceeded by {self.__class__.__name__}")
+
+    def _on_delay(self) -> None:
+        """Empty hook method for use by subclasses. Called on every event loop
+        iteration immediately before executing a delay before the next iteration"""
+        logger.debug(f"{self.__class__.__name__} delaying {self._loop_delay}s before next iteration")
+
+    def _log_cooldown(self, elapsed: float) -> None:
+        """Log the remaining cooldown time, if any"""
+        remaining = self._cooldown - elapsed
+        if remaining > 0:
+            logger.debug(f"{abs(remaining):.2f}s remains of {self._cooldown}s cooldown")
+        else:
+            logger.info(f"exceeded cooldown {self._cooldown}s by {abs(remaining):.2f}s")
+
+    def execute(self) -> None:
+        """The main event loop of a service host. Evaluates shutdown criteria and
+        combines with a cooldown period to allow automatic service termination.
+        Responsible for executing calls to the subclass implementation of `_on_iteration`"""
+
+        try:
+            self._on_start()
+        except Exception:
+            logger.exception("Unable to start service.")
+            return
+
+        running = True
+        cooldown_start: t.Optional[datetime.datetime] = None
+
+        while running:
+            try:
+                self._on_iteration()
+            except Exception:
+                running = False
+                logger.exception(
+                    "Failure in event loop resulted in service termination"
+                )
+
+            if self._health_check_frequency >= 0:
+                hc_elapsed = time.time() - self._last_health_check
+                if hc_elapsed >= self._health_check_frequency:
+                    self._on_health_check()
+                    self._last_health_check = time.time()
+
+            # allow immediate shutdown if not set to run as a service
+            if not self._as_service:
+                running = False
+                continue
+
+            # reset cooldown period if shutdown criteria are not met
+            if not self._can_shutdown():
+                cooldown_start = None
+
+            # start tracking cooldown elapsed once eligible to quit
+            if cooldown_start is None:
+                cooldown_start = datetime.datetime.now()
+
+            # change running state if cooldown period is exceeded
+            if self._cooldown > 0:
+                elapsed = datetime.datetime.now() - cooldown_start
+                running = elapsed.total_seconds() < self._cooldown
+                self._log_cooldown(elapsed.total_seconds())
+                if not running:
+                    self._on_cooldown_elapsed()
+            elif self._cooldown < 1 and self._can_shutdown():
+                running = False
+
+            if self._loop_delay:
+                self._on_delay()
+                time.sleep(self._loop_delay)
+
+        try:
+            self._on_shutdown()
+        except Exception:
+            logger.exception("Service shutdown may not have completed.")
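The lifecycle above can be exercised with a toy subclass; this example is illustrative and not part of the patch:

    from smartsim._core.entrypoints.service import Service

    class CountdownService(Service):
        """Toy service: five units of work, then eligible for shutdown."""

        def __init__(self) -> None:
            # run as a service with a 1s cooldown and a small per-iteration delay
            super().__init__(as_service=True, cooldown=1, loop_delay=0.1)
            self._remaining = 5

        def _on_iteration(self) -> None:
            self._remaining -= 1  # the unit of work performed each loop pass

        def _can_shutdown(self) -> bool:
            return self._remaining <= 0

    if __name__ == "__main__":
        CountdownService().execute()  # returns ~1s after the work is exhausted

diff --git a/smartsim/_core/launcher/dragon/dragonBackend.py b/smartsim/_core/launcher/dragon/dragonBackend.py
index 4aba60d558..5e01299141 100644
--- a/smartsim/_core/launcher/dragon/dragonBackend.py
+++ b/smartsim/_core/launcher/dragon/dragonBackend.py
@@ -26,6 +26,8 @@
 import collections
 import functools
 import itertools
+import os
+import socket
 import time
 import typing as t
 from dataclasses import dataclass, field
@@ -34,15 +36,27 @@
 from tabulate import tabulate

-# pylint: disable=import-error
+# pylint: disable=import-error,C0302,R0915
 # isort: off
+
 import dragon.infrastructure.connection as dragon_connection
 import dragon.infrastructure.policy as dragon_policy
-import dragon.native.group_state as dragon_group_state
+import dragon.infrastructure.process_desc as dragon_process_desc
+
 import dragon.native.process as dragon_process
 import dragon.native.process_group as dragon_process_group
 import dragon.native.machine as dragon_machine

+from smartsim._core.launcher.dragon.pqueue import NodePrioritizer, PrioritizerFilter
+from smartsim._core.mli.infrastructure.control.listener import (
+    ConsumerRegistrationListener,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.dragon_util import create_ddict
+from smartsim.error.errors import SmartSimError
+
 # pylint: enable=import-error
 # isort: on

 from ...._core.config import get_config
@@ -68,8 +82,8 @@
 class DragonStatus(str, Enum):
-    ERROR = str(dragon_group_state.Error())
-    RUNNING = str(dragon_group_state.Running())
+    ERROR = "Error"
+    RUNNING = "Running"

    def __str__(self) -> str:
        return self.value
@@ -86,7 +100,7 @@ class ProcessGroupInfo:
    return_codes: t.Optional[t.List[int]] = None
    """List of return codes of completed processes"""
    hosts: t.List[str] = field(default_factory=list)
-    """List of hosts on which the Process Group """
+    """List of hosts on which the Process Group should be executed"""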
    redir_workers: t.Optional[dragon_process_group.ProcessGroup] = None
    """Workers used to redirect stdout and stderr to file"""
@@ -143,6 +157,11 @@ class DragonBackend:
    by threads spawned by it.
    """

+    _DEFAULT_NUM_MGR_PER_NODE = 2
+    """The default number of manager processes for each feature store node"""
+    _DEFAULT_MEM_PER_NODE = 512 * 1024**2
+    """The default memory capacity (in bytes) to allocate for a feature store node"""
+
    def __init__(self, pid: int) -> None:
        self._pid = pid
        """PID of dragon executable which launched this server"""
@@ -153,7 +172,6 @@ def __init__(self, pid: int) -> None:
        self._step_ids = (f"{create_short_id_str()}-{id}" for id in itertools.count())
        """Incremental ID to assign to new steps prior to execution"""
-        self._initialize_hosts()
        self._queued_steps: "collections.OrderedDict[str, DragonRunRequest]" = (
            collections.OrderedDict()
        )
@@ -177,16 +195,26 @@
        """Whether the server frontend should shut down when the backend does"""
        self._shutdown_initiation_time: t.Optional[float] = None
        """The time at which the server initiated shutdown"""
-        smartsim_config = get_config()
-        self._cooldown_period = (
-            smartsim_config.telemetry_frequency * 2 + 5
-            if smartsim_config.telemetry_enabled
-            else 5
-        )
-        """Time in seconds needed to server to complete shutdown"""
+        self._cooldown_period = self._initialize_cooldown()
+        """Time in seconds needed by the server to complete shutdown"""
+        self._backbone: t.Optional[BackboneFeatureStore] = None
+        """The backbone feature store"""
+        self._listener: t.Optional[dragon_process.Process] = None
+        """The standalone process executing the event consumer"""
+
+        self._nodes: t.List["dragon_machine.Node"] = []
+        """Node capability information for hosts in the allocation"""
+        self._hosts: t.List[str] = []
+        """List of hosts available in allocation"""
+        self._cpus: t.List[int] = []
+        """List of cpu-count by node"""
+        self._gpus: t.List[int] = []
+        """List of gpu-count by node"""
+        self._allocated_hosts: t.Dict[str, t.Set[str]] = {}
+        """Mapping with hostnames as keys and a set of running step IDs as the value"""

-        self._view = DragonBackendView(self)
-        logger.debug(self._view.host_desc)
+        self._initialize_hosts()
+        self._prioritizer = NodePrioritizer(self._nodes, self._queue_lock)

    @property
    def hosts(self) -> list[str]:
@@ -194,34 +222,39 @@ def __init__(self, pid: int) -> None:
        return self._hosts

    @property
-    def allocated_hosts(self) -> dict[str, str]:
+    def allocated_hosts(self) -> dict[str, t.Set[str]]:
+        """A map of host names to the step ids executing on each host
+
+        :returns: Dictionary with host name as key and a set of step ids as value"""
        with self._queue_lock:
            return self._allocated_hosts

    @property
-    def free_hosts(self) -> t.Deque[str]:
+    def free_hosts(self) -> t.Sequence[str]:
+        """Find hosts that do not have a step assigned
+
+        :returns: List of host names"""
        with self._queue_lock:
-            return self._free_hosts
+            return list(map(lambda x: x.hostname, self._prioritizer.unassigned()))

    @property
    def group_infos(self) -> dict[str, ProcessGroupInfo]:
+        """Find information pertaining to process groups for launched steps
+
+        :returns: Dictionary with step id as key and group information as value"""
        with self._queue_lock:
            return self._group_infos

    def _initialize_hosts(self) -> None:
+        """Prepare metadata about the allocation"""
        with self._queue_lock:
            self._nodes = [
                dragon_machine.Node(node) for node in dragon_machine.System().nodes
            ]
-            self._hosts: t.List[str] = sorted(node.hostname for node in self._nodes)
+            self._hosts =
sorted(node.hostname for node in self._nodes) self._cpus = [node.num_cpus for node in self._nodes] self._gpus = [node.num_gpus for node in self._nodes] - - """List of hosts available in allocation""" - self._free_hosts: t.Deque[str] = collections.deque(self._hosts) - """List of hosts on which steps can be launched""" - self._allocated_hosts: t.Dict[str, str] = {} - """Mapping of hosts on which a step is already running to step ID""" + self._allocated_hosts = collections.defaultdict(set) def __str__(self) -> str: return self.status_message @@ -230,21 +263,19 @@ def __str__(self) -> str: def status_message(self) -> str: """Message with status of available nodes and history of launched jobs. - :returns: Status message + :returns: a status message """ - return ( - "Dragon server backend update\n" - f"{self._view.host_table}\n{self._view.step_table}" - ) + view = DragonBackendView(self) + return "Dragon server backend update\n" f"{view.host_table}\n{view.step_table}" def _heartbeat(self) -> None: + """Update the value of the last heartbeat to the current time.""" self._last_beat = self.current_time @property def cooldown_period(self) -> int: - """Time (in seconds) the server will wait before shutting down - - when exit conditions are met (see ``should_shutdown()`` for further details). + """Time (in seconds) the server will wait before shutting down when + exit conditions are met (see ``should_shutdown()`` for further details). """ return self._cooldown_period @@ -278,6 +309,8 @@ def should_shutdown(self) -> bool: and it requested immediate shutdown, or if it did not request immediate shutdown, but all jobs have been executed. In both cases, a cooldown period may need to be waited before shutdown. + + :returns: `True` if the server should terminate, otherwise `False` """ if self._shutdown_requested and self._can_shutdown: return self._has_cooled_down @@ -285,7 +318,9 @@ def should_shutdown(self) -> bool: @property def current_time(self) -> float: - """Current time for DragonBackend object, in seconds since the Epoch""" + """Current time for DragonBackend object, in seconds since the Epoch + + :returns: the current timestamp""" return time.time() def _can_honor_policy( @@ -293,63 +328,149 @@ def _can_honor_policy( ) -> t.Tuple[bool, t.Optional[str]]: """Check if the policy can be honored with resources available in the allocation. 
-        :param request: DragonRunRequest containing policy information
+
+        :param request: `DragonRunRequest` to validate
        :returns: Tuple indicating if the policy can be honored and
        an optional error message"""
        # ensure the policy can be honored
        if request.policy:
+            logger.debug(f"{request.policy=}{self._cpus=}{self._gpus=}")
+
            if request.policy.cpu_affinity:
                # make sure some node has enough CPUs
-                available = max(self._cpus)
+                last_available = max(self._cpus or [-1])
                requested = max(request.policy.cpu_affinity)
-
-                if requested >= available:
+                if not any(self._cpus) or requested >= last_available:
                    return False, "Cannot satisfy request, not enough CPUs available"
-
            if request.policy.gpu_affinity:
                # make sure some node has enough GPUs
-                available = max(self._gpus)
+                last_available = max(self._gpus or [-1])
                requested = max(request.policy.gpu_affinity)
-
-                if requested >= available:
+                if not any(self._gpus) or requested >= last_available:
+                    logger.warning(
+                        f"failed check w/{self._gpus=}, {requested=}, {last_available=}"
+                    )
                    return False, "Cannot satisfy request, not enough GPUs available"
-
        return True, None

    def _can_honor(self, request: DragonRunRequest) -> t.Tuple[bool, t.Optional[str]]:
-        """Check if request can be honored with resources available in the allocation.
-
-        Currently only checks for total number of nodes,
-        in the future it will also look at other constraints
-        such as memory, accelerators, and so on.
+        """Check if request can be honored with resources available in
+        the allocation. Currently only checks for total number of nodes,
+        in the future it will also look at other constraints such as memory,
+        accelerators, and so on.
+
+        :param request: `DragonRunRequest` to validate
+        :returns: Tuple indicating if the request can be honored and
+        an optional error message
        """
-        if request.nodes > len(self._hosts):
-            message = f"Cannot satisfy request. Requested {request.nodes} nodes, "
-            message += f"but only {len(self._hosts)} nodes are available."
-            return False, message
-        if self._shutdown_requested:
-            message = "Cannot satisfy request, server is shutting down."
-            return False, message
+        honorable, err = self._can_honor_state(request)
+        if not honorable:
+            return False, err

        honorable, err = self._can_honor_policy(request)
        if not honorable:
            return False, err

+        honorable, err = self._can_honor_hosts(request)
+        if not honorable:
+            return False, err
+
+        return True, None
+
+    def _can_honor_hosts(
+        self, request: DragonRunRequest
+    ) -> t.Tuple[bool, t.Optional[str]]:
+        """Check that any hosts requested by name exist and that enough of
+        them are available in the allocation to satisfy the request.
+
+        :param request: `DragonRunRequest` to validate
+        :returns: Tuple indicating if the request can be honored and
+        an optional error message"""
+        all_hosts = frozenset(self._hosts)
+        num_nodes = request.nodes
+
+        # fail if requesting more nodes than the total number available
+        if num_nodes > len(all_hosts):
+            message = f"Cannot satisfy request. {num_nodes} requested nodes"
+            message += f" exceeds {len(all_hosts)} available."
+ return False, message + + requested_hosts = all_hosts + if request.hostlist: + requested_hosts = frozenset( + {host.strip() for host in request.hostlist.split(",")} + ) + + valid_hosts = all_hosts.intersection(requested_hosts) + invalid_hosts = requested_hosts - valid_hosts + + logger.debug(f"{num_nodes=}{valid_hosts=}{invalid_hosts=}") + + if invalid_hosts: + logger.warning(f"Some invalid hostnames were requested: {invalid_hosts}") + + # fail if requesting specific hostnames and there aren't enough available + if num_nodes > len(valid_hosts): + message = f"Cannot satisfy request. Requested {num_nodes} nodes, " + message += f"but only {len(valid_hosts)} named hosts are available." + return False, message + + return True, None + + def _can_honor_state( + self, _request: DragonRunRequest + ) -> t.Tuple[bool, t.Optional[str]]: + """Check if the current state of the backend process inhibits executing + the request. + :param _request: the DragonRunRequest to verify + :returns: Tuple indicating if the request can be honored and + an optional error message""" + if self._shutdown_requested: + message = "Cannot satisfy request, server is shutting down." + return False, message + return True, None def _allocate_step( self, step_id: str, request: DragonRunRequest ) -> t.Optional[t.List[str]]: + """Identify the hosts on which the request will be executed + :param step_id: The identifier of a step that will be executed on the host + :param request: The request to be executed + :returns: A list of selected hostnames""" + # ensure at least one host is selected num_hosts: int = request.nodes + with self._queue_lock: - if num_hosts <= 0 or num_hosts > len(self._free_hosts): + if num_hosts <= 0 or num_hosts > len(self._hosts): + logger.debug( + f"The number of requested hosts ({num_hosts}) is invalid or" + f" cannot be satisfied with {len(self._hosts)} available nodes" + ) return None - to_allocate = [] - for _ in range(num_hosts): - host = self._free_hosts.popleft() - self._allocated_hosts[host] = step_id - to_allocate.append(host) + + hosts = [] + if request.hostlist: + # convert the comma-separated argument into a real list + hosts = [host for host in request.hostlist.split(",") if host] + + filter_on: t.Optional[PrioritizerFilter] = None + if request.policy and request.policy.gpu_affinity: + filter_on = PrioritizerFilter.GPU + + nodes = self._prioritizer.next_n(num_hosts, filter_on, step_id, hosts) + + if len(nodes) < num_hosts: + # exit if the prioritizer can't identify enough nodes + return None + + to_allocate = [node.hostname for node in nodes] + + for hostname in to_allocate: + # track assigning this step to each node + self._allocated_hosts[hostname].add(step_id) + return to_allocate @staticmethod @@ -389,6 +510,7 @@ def _create_redirect_workers( return grp_redir def _stop_steps(self) -> None: + """Trigger termination of all currently executing steps""" self._heartbeat() with self._queue_lock: while len(self._stop_requests) > 0: @@ -427,18 +549,96 @@ def _stop_steps(self) -> None: self._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED self._group_infos[step_id].return_codes = [-9] + def _create_backbone(self) -> BackboneFeatureStore: + """ + Creates a BackboneFeatureStore if one does not exist. Updates + environment variables of this process to include the backbone + descriptor. 
+ + :returns: The backbone feature store + """ + if self._backbone is None: + backbone_storage = create_ddict( + len(self._hosts), + self._DEFAULT_NUM_MGR_PER_NODE, + self._DEFAULT_MEM_PER_NODE, + ) + + self._backbone = BackboneFeatureStore( + backbone_storage, allow_reserved_writes=True + ) + + # put the backbone descriptor in the env vars + os.environ.update(self._backbone.get_env()) + + return self._backbone + + @staticmethod + def _initialize_cooldown() -> int: + """Load environment configuration and determine the correct cooldown + period to apply to the backend process. + + :returns: The calculated cooldown (in seconds) + """ + smartsim_config = get_config() + return ( + smartsim_config.telemetry_frequency * 2 + 5 + if smartsim_config.telemetry_enabled + else 5 + ) + + def start_event_listener( + self, cpu_affinity: list[int], gpu_affinity: list[int] + ) -> dragon_process.Process: + """Start a standalone event listener. + + :param cpu_affinity: The CPU affinity for the process + :param gpu_affinity: The GPU affinity for the process + :returns: The dragon Process managing the process + :raises SmartSimError: If the backbone is not provided + """ + if self._backbone is None: + raise SmartSimError("Backbone feature store is not available") + + service = ConsumerRegistrationListener( + self._backbone, 1.0, 2.0, as_service=True, health_check_frequency=90 + ) + + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + process = dragon_process.Process( + target=service.execute, + args=[], + cwd=os.getcwd(), + env={ + **os.environ, + **self._backbone.get_env(), + }, + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) + process.start() + return process + @staticmethod def create_run_policy( request: DragonRequest, node_name: str ) -> "dragon_policy.Policy": """Create a dragon Policy from the request and node name + :param request: DragonRunRequest containing policy information :param node_name: Name of the node on which the process will run :returns: dragon_policy.Policy object mapped from request properties""" if isinstance(request, DragonRunRequest): run_request: DragonRunRequest = request - affinity = dragon_policy.Policy.Affinity.DEFAULT cpu_affinity: t.List[int] = [] gpu_affinity: t.List[int] = [] @@ -446,25 +646,20 @@ def create_run_policy( if run_request.policy is not None: # Affinities are not mutually exclusive. 
If specified, both are used if run_request.policy.cpu_affinity: - affinity = dragon_policy.Policy.Affinity.SPECIFIC cpu_affinity = run_request.policy.cpu_affinity if run_request.policy.gpu_affinity: - affinity = dragon_policy.Policy.Affinity.SPECIFIC gpu_affinity = run_request.policy.gpu_affinity logger.debug( - f"Affinity strategy: {affinity}, " f"CPU affinity mask: {cpu_affinity}, " f"GPU affinity mask: {gpu_affinity}" ) - if affinity != dragon_policy.Policy.Affinity.DEFAULT: - return dragon_policy.Policy( - placement=dragon_policy.Policy.Placement.HOST_NAME, - host_name=node_name, - affinity=affinity, - cpu_affinity=cpu_affinity, - gpu_affinity=gpu_affinity, - ) + return dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=node_name, + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) return dragon_policy.Policy( placement=dragon_policy.Policy.Placement.HOST_NAME, @@ -472,7 +667,9 @@ def create_run_policy( ) def _start_steps(self) -> None: + """Start all new steps created since the last update.""" self._heartbeat() + with self._queue_lock: started = [] for step_id, request in self._queued_steps.items(): @@ -482,10 +679,8 @@ def _start_steps(self) -> None: logger.debug(f"Step id {step_id} allocated on {hosts}") - global_policy = dragon_policy.Policy( - placement=dragon_policy.Policy.Placement.HOST_NAME, - host_name=hosts[0], - ) + global_policy = self.create_run_policy(request, hosts[0]) + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) grp = dragon_process_group.ProcessGroup( restart=False, pmi_enabled=request.pmi_enabled, policy=global_policy ) @@ -498,10 +693,15 @@ def _start_steps(self) -> None: target=request.exe, args=request.exe_args, cwd=request.path, - env={**request.current_env, **request.env}, + env={ + **request.current_env, + **request.env, + **(self._backbone.get_env() if self._backbone else {}), + }, stdout=dragon_process.Popen.PIPE, stderr=dragon_process.Popen.PIPE, policy=local_policy, + options=options, ) grp.add_process(nproc=request.tasks_per_node, template=tmp_proc) @@ -567,9 +767,11 @@ def _start_steps(self) -> None: logger.error(e) def _refresh_statuses(self) -> None: + """Query underlying management system for step status and update + stored assigned and unassigned task information""" self._heartbeat() with self._queue_lock: - terminated = [] + terminated: t.Set[str] = set() for step_id in self._running_steps: group_info = self._group_infos[step_id] grp = group_info.process_group @@ -603,11 +805,15 @@ def _refresh_statuses(self) -> None: ) if group_info.status in TERMINAL_STATUSES: - terminated.append(step_id) + terminated.add(step_id) if terminated: logger.debug(f"{terminated=}") + # remove all the terminated steps from all hosts + for host in list(self._allocated_hosts.keys()): + self._allocated_hosts[host].difference_update(terminated) + for step_id in terminated: self._running_steps.remove(step_id) self._completed_steps.append(step_id) @@ -615,15 +821,20 @@ def _refresh_statuses(self) -> None: if group_info is not None: for host in group_info.hosts: logger.debug(f"Releasing host {host}") - try: - self._allocated_hosts.pop(host) - except KeyError: + if host not in self._allocated_hosts: logger.error(f"Tried to free a non-allocated host: {host}") - self._free_hosts.append(host) + else: + # remove any hosts that have had all their steps terminated + if not self._allocated_hosts[host]: + self._allocated_hosts.pop(host) + self._prioritizer.decrement(host, step_id) group_info.process_group = None 
group_info.redir_workers = None def _update_shutdown_status(self) -> None: + """Query the status of running tasks and update the status + of any that have completed. + """ self._heartbeat() with self._queue_lock: self._can_shutdown |= ( @@ -637,12 +848,18 @@ def _update_shutdown_status(self) -> None: ) def _should_print_status(self) -> bool: + """Determine if status messages should be printed based off the last + update. Returns `True` to trigger prints, `False` otherwise. + """ if self.current_time - self._last_update_time > 10: self._last_update_time = self.current_time return True return False def _update(self) -> None: + """Trigger all update queries and update local state database""" + self._create_backbone() + self._stop_steps() self._start_steps() self._refresh_statuses() @@ -650,6 +867,9 @@ def _update(self) -> None: def _kill_all_running_jobs(self) -> None: with self._queue_lock: + if self._listener and self._listener.is_alive: + self._listener.kill() + for step_id, group_info in self._group_infos.items(): if group_info.status not in TERMINAL_STATUSES: self._stop_requests.append(DragonStopRequest(step_id=step_id)) @@ -730,8 +950,14 @@ def _(self, request: DragonShutdownRequest) -> DragonShutdownResponse: class DragonBackendView: - def __init__(self, backend: DragonBackend): + def __init__(self, backend: DragonBackend) -> None: + """Initialize the instance + + :param backend: A dragon backend used to produce the view""" self._backend = backend + """A dragon backend used to produce the view""" + + logger.debug(self.host_desc) @property def host_desc(self) -> str: @@ -793,9 +1019,7 @@ def step_table(self) -> str: @property def host_table(self) -> str: """Table representation of current state of nodes available - - in the allocation. - """ + in the allocation.""" headers = ["Host", "Status"] hosts = self._backend.hosts free_hosts = self._backend.free_hosts diff --git a/smartsim/_core/launcher/dragon/dragonConnector.py b/smartsim/_core/launcher/dragon/dragonConnector.py index 0cd68c24e9..1144b7764e 100644 --- a/smartsim/_core/launcher/dragon/dragonConnector.py +++ b/smartsim/_core/launcher/dragon/dragonConnector.py @@ -71,17 +71,23 @@ class DragonConnector: def __init__(self) -> None: self._context: zmq.Context[t.Any] = zmq.Context.instance() + """ZeroMQ context used to share configuration across requests""" self._context.setsockopt(zmq.REQ_CORRELATE, 1) self._context.setsockopt(zmq.REQ_RELAXED, 1) self._authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None + """ZeroMQ authenticator used to secure queue access""" config = get_config() self._reset_timeout(config.dragon_server_timeout) self._dragon_head_socket: t.Optional[zmq.Socket[t.Any]] = None + """ZeroMQ socket exposing the connection to the DragonBackend""" self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] = None + """A handle to the process executing the DragonBackend""" # Returned by dragon head, useful if shutdown is to be requested # but process was started by another connector self._dragon_head_pid: t.Optional[int] = None + """Process ID of the process executing the DragonBackend""" self._dragon_server_path = config.dragon_server_path + """Path to a dragon installation""" logger.debug(f"Dragon Server path was set to {self._dragon_server_path}") self._env_vars: t.Dict[str, str] = {} if self._dragon_server_path is None: @@ -95,7 +101,7 @@ def __init__(self) -> None: @property def is_connected(self) -> bool: - """Whether the Connector established a connection to the server + """Whether the Connector 
established a connection to the server.

        :return: True if connected
        """
@@ -104,12 +110,18 @@ def is_connected(self) -> bool:
    @property
    def can_monitor(self) -> bool:
        """Whether the Connector knows the PID of the dragon server head process
-        and can monitor its status
+        and can monitor its status.

        :return: True if the server can be monitored"""
        return self._dragon_head_pid is not None

    def _handshake(self, address: str) -> None:
+        """Perform the handshake process with the DragonBackend and
+        confirm two-way communication is established.
+
+        :param address: The address of the head node socket to initiate a
+        handshake with
+        """
        self._dragon_head_socket = dragonSockets.get_secure_socket(
            self._context, zmq.REQ, False
        )
@@ -132,6 +144,11 @@
        ) from e

    def _reset_timeout(self, timeout: int = get_config().dragon_server_timeout) -> None:
+        """Reset the timeout applied to the ZMQ context. If an authenticator is
+        enabled, also update the authenticator timeouts.
+
+        :param timeout: The timeout value to apply to ZMQ sockets
+        """
        self._context.setsockopt(zmq.SNDTIMEO, value=timeout)
        self._context.setsockopt(zmq.RCVTIMEO, value=timeout)
        if self._authenticator is not None and self._authenticator.thread is not None:
@@ -183,11 +200,19 @@
    @staticmethod
    def _get_dragon_log_level() -> str:
+        """Maps the log level from SmartSim to a valid log level
+        for a dragon process.
+
+        :returns: The dragon log level string
+        """
        smartsim_to_dragon = defaultdict(lambda: "NONE")
        smartsim_to_dragon["developer"] = "INFO"
        return smartsim_to_dragon.get(get_config().log_level, "NONE")

    def _connect_to_existing_server(self, path: Path) -> None:
+        """Connects to an existing DragonBackend using address information from
+        a persisted dragon log file.
+        """
        config = get_config()
        dragon_config_log = path / config.dragon_log_filename

@@ -217,6 +242,11 @@
        return

    def _start_connector_socket(self, socket_addr: str) -> zmq.Socket[t.Any]:
+        """Instantiate the ZMQ socket to be used by the connector.
+
+        :param socket_addr: The socket address the connector should bind to
+        :returns: The bound socket
+        """
        config = get_config()
        connector_socket: t.Optional[zmq.Socket[t.Any]] = None
        self._reset_timeout(config.dragon_server_startup_timeout)
@@ -245,9 +275,14 @@

        with open(config.dragon_dotenv, encoding="utf-8") as dot_env:
            for kvp in dot_env.readlines():
-                split = kvp.strip().split("=", maxsplit=1)
-                key, value = split[0], split[-1]
-                self._env_vars[key] = value
+                if not kvp:
+                    continue
+
+                # skip any commented lines
+                if not kvp.startswith("#"):
+                    split = kvp.strip().split("=", maxsplit=1)
+                    key, value = split[0], split[-1]
+                    self._env_vars[key] = value

        return self._env_vars
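The parsing rules in `load_persisted_env` (skip blanks and `#` comments, split on the first `=`) can be checked in isolation; the file content below is a made-up example:

    # stand-in for the lines read from the dragon .env file
    dot_env_lines = ["# comment\n", "DRAGON_BASE_DIR=/opt/dragon\n", "\n"]

    env_vars = {}
    for kvp in dot_env_lines:
        if not kvp.strip():
            continue  # ignore empty lines
        if not kvp.startswith("#"):
            split = kvp.strip().split("=", maxsplit=1)
            key, value = split[0], split[-1]
            env_vars[key] = value

    print(env_vars)  # {'DRAGON_BASE_DIR': '/opt/dragon'}

@@ -418,6 +453,15 @@ def send_request(self, request: DragonRequest, flags: int = 0) -> DragonResponse:
    @staticmethod
    def _parse_launched_dragon_server_info_from_iterable(
        stream: t.Iterable[str], num_dragon_envs: t.Optional[int] = None
    ) -> t.List[t.Dict[str, str]]:
+        """Parses dragon backend connection information from a stream.
+
+        :param stream: The stream to inspect. Usually the stdout of the
+        DragonBackend process
+        :param num_dragon_envs: The expected number of dragon environments
+        to parse from the stream.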
+        :returns: A list of dictionaries, one per environment, containing
+        the parsed server information
+        """
        lines = (line.strip() for line in stream)
        lines = (line for line in lines if line)
        tokenized = (line.split(maxsplit=1) for line in lines)
@@ -444,6 +488,15 @@
    def _parse_launched_dragon_server_info_from_files(
        file_paths: t.List[t.Union[str, "os.PathLike[str]"]],
        num_dragon_envs: t.Optional[int] = None,
    ) -> t.List[t.Dict[str, str]]:
+        """Read a known log file into a stream and parse dragon server configuration
+        from the stream.
+
+        :param file_paths: Path to a file containing dragon server configuration
+        :param num_dragon_envs: The expected number of dragon environments to be found
+        in the file
+        :returns: The parsed server configuration, one item per
+        discovered dragon environment
+        """
        with fileinput.FileInput(file_paths) as ifstream:
            dragon_envs = cls._parse_launched_dragon_server_info_from_iterable(
                ifstream, num_dragon_envs
@@ -458,6 +511,15 @@
    def _send_req_with_socket(
        send_flags: int = 0,
        recv_flags: int = 0,
    ) -> DragonResponse:
+        """Sends a synchronous request through a ZMQ socket.
+
+        :param socket: Socket to send on
+        :param request: The request to send
+        :param send_flags: Configuration to apply to the send operation
+        :param recv_flags: Configuration to apply to the recv operation; used to
+        allow the receiver to immediately respond to the sent request.
+        :returns: The response from the target
+        """
        client = dragonSockets.as_client(socket)
        with DRG_LOCK:
            logger.debug(f"Sending {type(request).__name__}: {request}")
@@ -469,6 +531,13 @@

 def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT:
+    """Verify that objects can be sent as messages acceptable to the target.
+
+    :param obj: The message to test
+    :param typ: The type that is acceptable
+    :returns: The original `obj` if it is of the requested type
+    :raises TypeError: If the object fails the test and is not
+    an instance of the desired type"""
    if not isinstance(obj, typ):
        raise TypeError(f"Expected schema of type `{typ}`, but got {type(obj)}")
    return obj
@@ -520,6 +589,12 @@ def _dragon_cleanup(

 def _resolve_dragon_path(fallback: t.Union[str, "os.PathLike[str]"]) -> Path:
+    """Determine the applicable dragon server path for the connector
+
+    :param fallback: A default dragon server path to use if one is not
+    found in the runtime configuration
+    :returns: The path to the dragon libraries
+    """
    dragon_server_path = get_config().dragon_server_path or os.path.join(
        fallback, ".smartsim", "dragon"
    )
diff --git a/smartsim/_core/launcher/dragon/dragonLauncher.py b/smartsim/_core/launcher/dragon/dragonLauncher.py
index 9078fed54f..75ca675225 100644
--- a/smartsim/_core/launcher/dragon/dragonLauncher.py
+++ b/smartsim/_core/launcher/dragon/dragonLauncher.py
@@ -170,6 +170,7 @@ def run(self, step: Step) -> t.Optional[str]:
        merged_env = self._connector.merge_persisted_env(os.environ.copy())
        nodes = int(run_args.get("nodes", None) or 1)
        tasks_per_node = int(run_args.get("tasks-per-node", None) or 1)
+        hosts = run_args.get("host-list", None)

        policy = DragonRunPolicy.from_run_args(run_args)

@@ -187,6 +188,7 @@ def run(self, step: Step) -> t.Optional[str]:
                    output_file=out,
                    error_file=err,
                    policy=policy,
+                    hostlist=hosts,
                )
            ),
            DragonRunResponse,
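The new `host-list` run argument threads through the launcher above and `dragonStep` below. Assuming the companion `DragonRunSettings.set_hostlist` helper added elsewhere in this patch, requesting specific hosts would look roughly like this (host names are placeholders):

    from smartsim.settings import DragonRunSettings

    rs = DragonRunSettings(exe="echo", exe_args=["hello"])
    rs.set_nodes(2)
    # hypothetical usage of the hostlist support added by this patch;
    # the names become the comma-separated "host-list" run argument
    rs.set_hostlist(["node-001", "node-002"])

diff --git a/smartsim/_core/launcher/dragon/pqueue.py b/smartsim/_core/launcher/dragon/pqueue.py
new file mode 100644
index 0000000000..8c14a828f5
--- /dev/null
+++ b/smartsim/_core/launcher/dragon/pqueue.py
@@ -0,0 +1,461 @@
+# BSD 2-Clause License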
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# import collections
+import enum
+import heapq
+import threading
+import typing as t
+
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class Node(t.Protocol):
+    """Base Node API required to support the NodePrioritizer"""
+
+    @property
+    def hostname(self) -> str:
+        """The hostname of the node"""
+
+    @property
+    def num_cpus(self) -> int:
+        """The number of CPUs in the node"""
+
+    @property
+    def num_gpus(self) -> int:
+        """The number of GPUs in the node"""
+
+
+class NodeReferenceCount(t.Protocol):
+    """Contains details pertaining to references to a node"""
+
+    @property
+    def hostname(self) -> str:
+        """The hostname of the node"""
+
+    @property
+    def num_refs(self) -> int:
+        """The number of jobs assigned to the node"""
+
+
+class _TrackedNode:
+    """Node API required to have support in the NodePrioritizer"""
+
+    def __init__(self, node: Node) -> None:
+        self._node = node
+        """The node being tracked"""
+        self._num_refs = 0
+        """The number of references to the tracked node"""
+        self._assigned_tasks: t.Set[str] = set()
+        """The unique identifiers of processes using this node"""
+        self._is_dirty = False
+        """Flag indicating that tracking information has been modified"""
+
+    @property
+    def hostname(self) -> str:
+        """Returns the hostname of the node"""
+        return self._node.hostname
+
+    @property
+    def num_cpus(self) -> int:
+        """Returns the number of CPUs in the node"""
+        return self._node.num_cpus
+
+    @property
+    def num_gpus(self) -> int:
+        """Returns the number of GPUs attached to the node"""
+        return self._node.num_gpus
+
+    @property
+    def num_refs(self) -> int:
+        """Returns the number of processes currently running on the node"""
+        return self._num_refs
+
+    @property
+    def is_assigned(self) -> bool:
+        """Returns `True` if one or more references are currently counted, `False` otherwise"""
+        return self._num_refs > 0
+
+    @property
+    def assigned_tasks(self) -> t.Set[str]:
+        """Returns the set of unique IDs for currently running processes"""
+        return self._assigned_tasks
+
+    @property
+    def is_dirty(self) -> bool:
+        """Returns a flag indicating if the reference
+        counter has changed. `True`
+        if references have been added or removed, `False` otherwise."""
+        return self._is_dirty
+
+    def clean(self) -> None:
+        """Marks the node as unmodified"""
+        self._is_dirty = False
+
+    def add(
+        self,
+        tracking_id: t.Optional[str] = None,
+    ) -> None:
+        """Update the node to indicate the addition of a process that must be
+        reference counted.
+
+        :param tracking_id: a unique task identifier executing on the node
+        to add
+        :raises ValueError: if tracking_id is already assigned to this node"""
+        if tracking_id in self.assigned_tasks:
+            raise ValueError("Attempted adding task more than once")
+
+        self._num_refs = self._num_refs + 1
+        if tracking_id:
+            self._assigned_tasks = self._assigned_tasks.union({tracking_id})
+        self._is_dirty = True
+
+    def remove(
+        self,
+        tracking_id: t.Optional[str] = None,
+    ) -> None:
+        """Update the reference counter to indicate the removal of a process.
+
+        :param tracking_id: a unique task identifier executing on the node
+        to remove"""
+        self._num_refs = max(self._num_refs - 1, 0)
+        if tracking_id:
+            self._assigned_tasks = self._assigned_tasks - {tracking_id}
+        self._is_dirty = True
+
+    def __lt__(self, other: "_TrackedNode") -> bool:
+        """Comparison operator used to evaluate the ordering of nodes within
+        the prioritizer. This comparison only considers reference counts.
+
+        :param other: Another node to compare against
+        :returns: True if this node has fewer references than the other node"""
+        if self.num_refs < other.num_refs:
+            return True
+
+        return False
+
+
+class PrioritizerFilter(str, enum.Enum):
+    """A filter used to select a subset of nodes to be queried"""
+
+    CPU = enum.auto()
+    GPU = enum.auto()
+
+
+class NodePrioritizer:
+    def __init__(self, nodes: t.List[Node], lock: threading.RLock) -> None:
+        """Initialize the prioritizer
+
+        :param nodes: node attribute information for initializing the prioritizer
+        :param lock: a lock used to ensure threadsafe operations
+        :raises SmartSimError: if the nodes collection is empty
+        """
+        if not nodes:
+            raise SmartSimError("Missing nodes to prioritize")
+
+        self._lock = lock
+        """Lock used to ensure thread safe changes of the reference counters"""
+        self._cpu_refs: t.List[_TrackedNode] = []
+        """Track reference counts to CPU-only nodes"""
+        self._gpu_refs: t.List[_TrackedNode] = []
+        """Track reference counts to GPU nodes"""
+        self._nodes: t.Dict[str, _TrackedNode] = {}
+
+        self._initialize_reference_counters(nodes)
+
+    def _initialize_reference_counters(self, nodes: t.List[Node]) -> None:
+        """Perform initialization of reference counters for nodes in the allocation
+
+        :param nodes: node attribute information for initializing the prioritizer"""
+        for node in nodes:
+            # create a set of reference counters for the nodes
+            tracked = _TrackedNode(node)
+
+            self._nodes[node.hostname] = tracked  # for O(1) access
+
+            if node.num_gpus:
+                self._gpu_refs.append(tracked)
+            else:
+                self._cpu_refs.append(tracked)
+
+    def increment(
+        self, host: str, tracking_id: t.Optional[str] = None
+    ) -> NodeReferenceCount:
+        """Directly increment the reference count of a given node and ensure the
+        ref counter is marked as dirty to trigger a reordering on retrieval
+
+        :param host: a hostname that should have a reference counter selected
+        :param tracking_id: a unique task identifier executing on the node
+        to add"""
+        with self._lock:
+            tracked_node = self._nodes[host]
+            tracked_node.add(tracking_id)
+            return tracked_node
+
+    def _heapify_all_refs(self) -> t.List[_TrackedNode]:
+        """Combine the CPU and GPU nodes into a single heap
+
+        :returns: list of all reference counters"""
+        refs = [*self._cpu_refs, *self._gpu_refs]
+        heapq.heapify(refs)
+        return refs
+
+    def get_tracking_info(self, host: str) -> NodeReferenceCount:
+        """Returns the reference counter information for a single node
+
+        :param host: a hostname that should have a reference counter selected
+        :returns: a reference counter for the node
+        :raises ValueError: if the hostname is not in the set of managed nodes"""
+        if host not in self._nodes:
+            raise ValueError("The supplied hostname was not found")
+
+        return self._nodes[host]
+
+    def decrement(
+        self, host: str, tracking_id: t.Optional[str] = None
+    ) -> NodeReferenceCount:
+        """Directly decrement the reference count of a given node and ensure the
+        ref counter is marked as dirty to trigger a reordering
+
+        :param host: a hostname that should have a reference counter decremented
+        :param tracking_id: unique task identifier to remove"""
+        with self._lock:
+            tracked_node = self._nodes[host]
+            tracked_node.remove(tracking_id)
+
+            return tracked_node
+
+    def _create_sub_heap(
+        self,
+        hosts: t.Optional[t.List[str]] = None,
+        filter_on: t.Optional[PrioritizerFilter] = None,
+    ) -> t.List[_TrackedNode]:
+        """Create a new heap from the primary heap with user-specified nodes
+
+        :param hosts: a list of hostnames used to filter the available nodes
+        :param filter_on: an optional filter used to select the CPU or GPU heap
+        as the source
+        :returns: a newly heapified list of the selected reference counters
+        """
+        nodes_tracking_info: t.List[_TrackedNode] = []
+        heap = self._get_filtered_heap(filter_on)
+
+        # Collect all the tracking info for the requested nodes...
+        for node in heap:
+            if not hosts or node.hostname in hosts:
+                nodes_tracking_info.append(node)
+
+        # ... and use it to create a new heap from a specified subset of nodes
+        heapq.heapify(nodes_tracking_info)
+
+        return nodes_tracking_info
+
+    def unassigned(
+        self, heap: t.Optional[t.List[_TrackedNode]] = None
+    ) -> t.Sequence[Node]:
+        """Select nodes that are currently not assigned a task
+
+        :param heap: a subset of the node heap to consider
+        :returns: a list of reference counters for all unassigned nodes"""
+        if heap is None:
+            heap = list(self._nodes.values())
+
+        nodes: t.List[_TrackedNode] = []
+        for item in heap:
+            if item.num_refs == 0:
+                nodes.append(item)
+        return nodes
+
+    def assigned(
+        self, heap: t.Optional[t.List[_TrackedNode]] = None
+    ) -> t.Sequence[Node]:
+        """Helper method to identify the nodes that are currently assigned
+
+        :param heap: a subset of the node heap to consider
+        :returns: a list of reference counters for all assigned nodes"""
+        if heap is None:
+            heap = list(self._nodes.values())
+
+        nodes: t.List[_TrackedNode] = []
+        for item in heap:
+            if item.num_refs > 0:
+                nodes.append(item)
+        return nodes
+
+    def _check_satisfiable_n(
+        self, num_items: int, heap: t.Optional[t.List[_TrackedNode]] = None
+    ) -> bool:
+        """Validates that a request for some number of nodes `n` can be
+        satisfied by the prioritizer given the set of nodes available
+
+        :param num_items: the desired number of nodes to allocate
+        :param heap: a subset of the node heap to consider
+        :returns: True if the request can be fulfilled, False otherwise"""
+        num_nodes = len(self._nodes.keys())
+
+        if num_items < 1:
+            msg = "Cannot handle request; request requires a positive integer"
+            logger.warning(msg)
+            return False
+
+        if num_nodes < num_items:
+            msg = f"Cannot satisfy request for {num_items} nodes; {num_nodes} in pool"
+            logger.warning(msg)
+            return False
+
+
+        num_open = len(self.unassigned(heap))
+        if num_open < num_items:
+            msg = f"Cannot satisfy request for {num_items} nodes; {num_open} available"
+            logger.warning(msg)
+            return False
+
+        return True
+
+    def _get_next_unassigned_node(
+        self,
+        heap: t.List[_TrackedNode],
+        tracking_id: t.Optional[str] = None,
+    ) -> t.Optional[Node]:
+        """Finds the next node with no running processes and
+        ensures that any elements that were directly updated are updated in
+        the priority structure before being made available
+
+        :param heap: a subset of the node heap to consider
+        :param tracking_id: unique task identifier to track
+        :returns: a reference counter for an available node if an unassigned node
+        exists, `None` otherwise"""
+        tracking_info: t.Optional[_TrackedNode] = None
+
+        with self._lock:
+            # re-sort the heap to handle any tracking changes
+            if any(node.is_dirty for node in heap):
+                heapq.heapify(heap)
+
+            # grab the min node from the heap
+            tracking_info = heapq.heappop(heap)
+
+            # the node is available if it has no assigned tasks
+            is_assigned = tracking_info.is_assigned
+            if not is_assigned:
+                # track the new process on the node
+                tracking_info.add(tracking_id)
+
+            # add the node that was popped back into the heap
+            heapq.heappush(heap, tracking_info)
+
+            # mark all nodes as clean now that everything is updated & sorted
+            for node in heap:
+                node.clean()
+
+            # next available must only return previously unassigned nodes
+            if is_assigned:
+                return None
+
+            return tracking_info
+
+    def _get_next_n_available_nodes(
+        self,
+        num_items: int,
+        heap: t.List[_TrackedNode],
+        tracking_id: t.Optional[str] = None,
+    ) -> t.List[Node]:
+        """Find the next N available nodes with the fewest references using
+        the supplied heap to target a specific node capability
+
+        :param num_items: number of nodes to reserve
+        :param heap: a subset of the node heap to consider
+        :param tracking_id: unique task identifier to track
+        :returns: a list of reference counters for the reserved nodes; an empty
+        list if the request cannot be satisfied
+        :raises ValueError: if the number of requested nodes is not a positive integer
+        """
+        next_nodes: t.List[Node] = []
+
+        if num_items < 1:
+            raise ValueError(f"Number of items requested {num_items} is invalid")
+
+        if not self._check_satisfiable_n(num_items, heap):
+            return next_nodes
+
+        while len(next_nodes) < num_items:
+            if next_node := self._get_next_unassigned_node(heap, tracking_id):
+                next_nodes.append(next_node)
+                continue
+            break
+
+        return next_nodes
+
+    def _get_filtered_heap(
+        self, filter_on: t.Optional[PrioritizerFilter] = None
+    ) -> t.List[_TrackedNode]:
+        """Helper method to select the set of nodes to include in a filtered
+        heap.
+
+        :param filter_on: The filter determining which collection of nodes to select.
+        If no filter is supplied, all nodes are returned
+        :returns: The nodes to include in the filtered heap"""
+        if filter_on == PrioritizerFilter.GPU:
+            return self._gpu_refs
+        if filter_on == PrioritizerFilter.CPU:
+            return self._cpu_refs
+
+        return self._heapify_all_refs()
+
+    def next(
+        self,
+        filter_on: t.Optional[PrioritizerFilter] = None,
+        tracking_id: t.Optional[str] = None,
+        hosts: t.Optional[t.List[str]] = None,
+    ) -> t.Optional[Node]:
+        """Find the next unassigned node using the supplied filter to target
+        a specific node capability
+
+        :param filter_on: the filter indicating the subset of nodes to query
+        :param tracking_id: unique task identifier to track
+        :param hosts: a list of hostnames used to filter the available nodes
+        :returns: a reference counter for an available node if an unassigned node
+        exists, `None` otherwise"""
+        if results := self.next_n(1, filter_on, tracking_id, hosts):
+            return results[0]
+        return None
+
+    def next_n(
+        self,
+        num_items: int = 1,
+        filter_on: t.Optional[PrioritizerFilter] = None,
+        tracking_id: t.Optional[str] = None,
+        hosts: t.Optional[t.List[str]] = None,
+    ) -> t.List[Node]:
+        """Find the next N available nodes with the fewest references using
+        the supplied filter to target a specific node capability
+
+        :param num_items: number of nodes to reserve
+        :param filter_on: the filter indicating the subset of nodes to query
+        :param tracking_id: unique task identifier to track
+        :param hosts: a list of hostnames used to filter the available nodes
+        :returns: Collection of reserved nodes; an empty collection if the
+        request cannot be satisfied
+        :raises ValueError: if the number of requested nodes is not a positive
+        integer"""
+        heap = self._create_sub_heap(hosts, filter_on)
+        return self._get_next_n_available_nodes(num_items, heap, tracking_id)
diff --git a/smartsim/_core/launcher/step/dragonStep.py b/smartsim/_core/launcher/step/dragonStep.py
index dd93d7910c..8583ceeb1b 100644
--- a/smartsim/_core/launcher/step/dragonStep.py
+++ b/smartsim/_core/launcher/step/dragonStep.py
@@ -169,6 +169,7 @@ def _write_request_file(self) -> str:
             env = run_settings.env_vars
             nodes = int(run_args.get("nodes", None) or 1)
             tasks_per_node = int(run_args.get("tasks-per-node", None) or 1)
+            hosts_csv = run_args.get("host-list", None)
 
             policy = DragonRunPolicy.from_run_args(run_args)
 
@@ -187,6 +188,7 @@ def _write_request_file(self) -> str:
                 output_file=out,
                 error_file=err,
                 policy=policy,
+                hostlist=hosts_csv,
             )
             requests.append(request_registry.to_string(request))
         with open(request_file, "w", encoding="utf-8") as script_file:
diff --git a/smartsim/_core/mli/__init__.py b/smartsim/_core/mli/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/client/__init__.py b/smartsim/_core/mli/client/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/client/protoclient.py b/smartsim/_core/mli/client/protoclient.py
new file mode 100644
index 0000000000..46598a8171
--- /dev/null
+++ b/smartsim/_core/mli/client/protoclient.py
@@ -0,0 +1,348 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2.
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# isort: off +# pylint: disable=unused-import,import-error +import dragon +import dragon.channels +from dragon.globalservices.api_setup import connect_to_infrastructure + +try: + from mpi4py import MPI # type: ignore[import-not-found] +except Exception: + MPI = None + print("Unable to import `mpi4py` package") + +# isort: on +# pylint: enable=unused-import,import-error + +import numbers +import os +import time +import typing as t +from collections import OrderedDict + +import numpy +import torch + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim._core.utils.timings import PerfTimer +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +_TimingDict = OrderedDict[str, list[str]] + + +logger = get_logger("App") +logger.info("Started app") +CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False + + +class ProtoClient: + """Proof of concept implementation of a client enabling user applications + to interact with MLI resources.""" + + _DEFAULT_BACKBONE_TIMEOUT = 1.0 + """A default timeout period applied to connection attempts with the + backbone feature store.""" + + _DEFAULT_WORK_QUEUE_SIZE = 500 + """A default number of events to be buffered in the work queue before + triggering QueueFull exceptions.""" + + _EVENT_SOURCE = "proto-client" + """A user-friendly name for this class instance to identify + the client as the publisher of an event.""" + + @staticmethod + def _attach_to_backbone() -> BackboneFeatureStore: + """Use the supplied environment variables to attach + to a pre-existing backbone featurestore. Requires the + environment to contain `_SMARTSIM_INFRA_BACKBONE` + environment variable. 
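+
+        A minimal sketch of the expected setup (the variable name is the
+        value of `BackboneFeatureStore.MLI_BACKBONE`; the descriptor value
+        shown is a hypothetical placeholder):
+
+        .. code-block:: python
+
+            os.environ["_SMARTSIM_INFRA_BACKBONE"] = "<base64-descriptor>"
+            backbone = ProtoClient._attach_to_backbone()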
+ + :returns: The attached backbone featurestore + :raises SmartSimError: If the backbone descriptor is not contained + in the appropriate environment variable + """ + descriptor = os.environ.get(BackboneFeatureStore.MLI_BACKBONE, None) + if descriptor is None or not descriptor: + raise SmartSimError( + "Missing required backbone configuration in environment: " + f"{BackboneFeatureStore.MLI_BACKBONE}" + ) + + backbone = t.cast( + BackboneFeatureStore, BackboneFeatureStore.from_descriptor(descriptor) + ) + return backbone + + def _attach_to_worker_queue(self) -> DragonFLIChannel: + """Wait until the backbone contains the worker queue configuration, + then attach an FLI to the given worker queue. + + :returns: The attached FLI channel + :raises SmartSimError: if the required configuration is not found in the + backbone feature store + """ + + descriptor = "" + try: + # NOTE: without wait_for, this MUST be in the backbone.... + config = self._backbone.wait_for( + [BackboneFeatureStore.MLI_WORKER_QUEUE], self.backbone_timeout + ) + descriptor = str(config[BackboneFeatureStore.MLI_WORKER_QUEUE]) + except Exception as ex: + logger.info( + f"Unable to retrieve {BackboneFeatureStore.MLI_WORKER_QUEUE} " + "to attach to the worker queue." + ) + raise SmartSimError("Unable to locate worker queue using backbone") from ex + + return DragonFLIChannel.from_descriptor(descriptor) + + def _create_broadcaster(self) -> EventBroadcaster: + """Create an EventBroadcaster that broadcasts events to + all MLI components registered to consume them. + + :returns: An EventBroadcaster instance + """ + broadcaster = EventBroadcaster( + self._backbone, DragonCommChannel.from_descriptor + ) + return broadcaster + + def __init__( + self, + timing_on: bool, + backbone_timeout: float = _DEFAULT_BACKBONE_TIMEOUT, + ) -> None: + """Initialize the client instance. + + :param timing_on: Flag indicating if timing information should be + written to file + :param backbone_timeout: Maximum wait time (in seconds) allowed to attach to the + worker queue + :raises SmartSimError: If unable to attach to a backbone featurestore + :raises ValueError: If an invalid backbone timeout is specified + """ + if MPI is not None: + # TODO: determine a way to make MPI work in the test environment + # - consider catching the import exception and defaulting rank to 0 + comm = MPI.COMM_WORLD + rank: int = comm.Get_rank() + else: + rank = 0 + + if backbone_timeout <= 0: + raise ValueError( + f"Invalid backbone timeout provided: {backbone_timeout}. " + "The value must be greater than zero." + ) + self._backbone_timeout = max(backbone_timeout, 0.1) + + connect_to_infrastructure() + + self._backbone = self._attach_to_backbone() + self._backbone.wait_timeout = self.backbone_timeout + self._to_worker_fli = self._attach_to_worker_queue() + + self._from_worker_ch = create_local(self._DEFAULT_WORK_QUEUE_SIZE) + self._to_worker_ch = create_local(self._DEFAULT_WORK_QUEUE_SIZE) + + self._publisher = self._create_broadcaster() + + self.perf_timer: PerfTimer = PerfTimer( + debug=False, timing_on=timing_on, prefix=f"a{rank}_" + ) + self._start: t.Optional[float] = None + self._interm: t.Optional[float] = None + self._timings: _TimingDict = OrderedDict() + self._timing_on = timing_on + + @property + def backbone_timeout(self) -> float: + """The timeout (in seconds) applied to retrievals + from the backbone feature store. 
+
+        :returns: A float indicating the number of seconds to allow"""
+        return self._backbone_timeout
+
+    def _add_label_to_timings(self, label: str) -> None:
+        """Adds a new label into the timing dictionary to prepare for
+        receiving timing events.
+
+        :param label: The label to create storage for
+        """
+        if label not in self._timings:
+            self._timings[label] = []
+
+    @staticmethod
+    def _format_number(number: t.Union[numbers.Number, float]) -> str:
+        """Utility function for formatting numbers consistently for logs.
+
+        :param number: The number to convert to a formatted string
+        :returns: The formatted string containing the number
+        """
+        return f"{number:0.4e}"
+
+    def start_timings(self, batch_size: numbers.Number) -> None:
+        """Configure the client to begin storing timing information.
+
+        :param batch_size: The batch size to record alongside the
+        timing information
+        """
+        if self._timing_on:
+            self._add_label_to_timings("batch_size")
+            self._timings["batch_size"].append(self._format_number(batch_size))
+            self._start = time.perf_counter()
+            self._interm = time.perf_counter()
+
+    def end_timings(self) -> None:
+        """Stop storing timing information and record the total elapsed time."""
+        if self._timing_on and self._start is not None:
+            self._add_label_to_timings("total_time")
+            self._timings["total_time"].append(
+                self._format_number(time.perf_counter() - self._start)
+            )
+
+    def measure_time(self, label: str) -> None:
+        """Measures elapsed time since the last recorded signal.
+
+        :param label: The label to measure time for
+        """
+        if self._timing_on and self._interm is not None:
+            self._add_label_to_timings(label)
+            self._timings[label].append(
+                self._format_number(time.perf_counter() - self._interm)
+            )
+            self._interm = time.perf_counter()
+
+    def print_timings(self, to_file: bool = False) -> None:
+        """Print timing information to standard output. If `to_file`
+        is `True`, also write results to a file.
+
+        :param to_file: If `True`, also saves timing information
+        to the files `timings.npy` and `timings.txt`
+        """
+        print(" ".join(self._timings.keys()))
+
+        # materialize the values view as a list so numpy builds a 2-D array
+        value_array = numpy.array(list(self._timings.values()), dtype=float)
+        value_array = numpy.transpose(value_array)
+        for i in range(value_array.shape[0]):
+            print(" ".join(self._format_number(value) for value in value_array[i]))
+        if to_file:
+            numpy.save("timings.npy", value_array)
+            numpy.savetxt("timings.txt", value_array)
+
+    def run_model(self, model: t.Union[bytes, str], batch: torch.Tensor) -> t.Any:
+        """Execute a batch of inference requests with the supplied ML model.
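+
+        A usage sketch (illustrative only; assumes an attached client and a
+        model previously stored under the key `model_key`):
+
+        .. code-block:: python
+
+            client = ProtoClient(timing_on=False)
+            batch = torch.randn(32, 3, 224, 224)
+            result = client.run_model("model_key", batch)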
+ + :param model: The raw bytes or path to a pytorch model + :param batch: The tensor batch to perform inference on + :returns: The inference results + :raises ValueError: if the worker queue is not configured properly + in the environment variables + """ + tensors = [batch.numpy()] + self.perf_timer.start_timings("batch_size", batch.shape[0]) + built_tensor_desc = MessageHandler.build_tensor_descriptor( + "c", "float32", list(batch.shape) + ) + self.perf_timer.measure_time("build_tensor_descriptor") + if isinstance(model, str): + model_arg = MessageHandler.build_model_key(model, self._backbone.descriptor) + else: + model_arg = MessageHandler.build_model( + model, "resnet-50", "1.0" + ) # type: ignore + request = MessageHandler.build_request( + reply_channel=self._from_worker_ch.descriptor, + model=model_arg, + inputs=[built_tensor_desc], + outputs=[], + output_descriptors=[], + custom_attributes=None, + ) + self.perf_timer.measure_time("build_request") + request_bytes = MessageHandler.serialize_request(request) + self.perf_timer.measure_time("serialize_request") + + if self._to_worker_fli is None: + raise ValueError("No worker queue available.") + + # pylint: disable-next=protected-access + with self._to_worker_fli._channel.sendh( # type: ignore + timeout=None, + stream_channel=self._to_worker_ch.channel, + ) as to_sendh: + to_sendh.send_bytes(request_bytes) + self.perf_timer.measure_time("send_request") + for tensor in tensors: + to_sendh.send_bytes(tensor.tobytes()) # TODO NOT FAST ENOUGH!!! + logger.info(f"Message size: {len(request_bytes)} bytes") + + self.perf_timer.measure_time("send_tensors") + with self._from_worker_ch.channel.recvh(timeout=None) as from_recvh: + resp = from_recvh.recv_bytes(timeout=None) + self.perf_timer.measure_time("receive_response") + response = MessageHandler.deserialize_response(resp) + self.perf_timer.measure_time("deserialize_response") + + # recv depending on the len(response.result.descriptors)? + data_blob: bytes = from_recvh.recv_bytes(timeout=None) + self.perf_timer.measure_time("receive_tensor") + result = torch.from_numpy( + numpy.frombuffer( + data_blob, + dtype=str(response.result.descriptors[0].dataType), + ) + ) + self.perf_timer.measure_time("deserialize_tensor") + + self.perf_timer.end_timings() + return result + + def set_model(self, key: str, model: bytes) -> None: + """Write the supplied model to the feature store. + + :param key: The unique key used to identify the model + :param model: The raw bytes of the model to execute + """ + self._backbone[key] = model + + # notify components of a change in the data at this key + event = OnWriteFeatureStore(self._EVENT_SOURCE, self._backbone.descriptor, key) + self._publisher.send(event) diff --git a/smartsim/_core/mli/comm/channel/__init__.py b/smartsim/_core/mli/comm/channel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/comm/channel/channel.py b/smartsim/_core/mli/comm/channel/channel.py new file mode 100644 index 0000000000..104333ce7f --- /dev/null +++ b/smartsim/_core/mli/comm/channel/channel.py @@ -0,0 +1,82 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import typing as t
+import uuid
+from abc import ABC, abstractmethod
+
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class CommChannelBase(ABC):
+    """Base class for abstracting a message passing mechanism"""
+
+    def __init__(
+        self,
+        descriptor: str,
+        name: t.Optional[str] = None,
+    ) -> None:
+        """Initialize the CommChannel instance.
+
+        :param descriptor: Channel descriptor
+        :param name: An optional, user-friendly name used in channel-related logging
+        """
+        self._descriptor = descriptor
+        """An opaque identifier used to connect to an underlying communication channel"""
+        self._name = name or str(uuid.uuid4())
+        """A user-friendly identifier for channel-related logging"""
+
+    @abstractmethod
+    def send(self, value: bytes, timeout: float = 0.001) -> None:
+        """Send a message through the underlying communication channel.
+
+        :param value: The value to send
+        :param timeout: Maximum time to wait (in seconds) for messages to send
+        :raises SmartSimError: If sending message fails
+        """
+
+    @abstractmethod
+    def recv(self, timeout: float = 0.001) -> t.List[bytes]:
+        """Receives message(s) through the underlying communication channel.
+
+        :param timeout: Maximum time to wait (in seconds) for messages to arrive
+        :returns: The received message(s)
+        """
+
+    @property
+    def descriptor(self) -> str:
+        """Return the channel descriptor for the underlying dragon channel.
+
+        :returns: String-encoded channel descriptor
+        """
+        return self._descriptor
+
+    def __str__(self) -> str:
+        """Build a string representation of the channel useful for printing."""
+        classname = type(self).__name__
+        return f"{classname}('{self._name}', '{self._descriptor}')"
diff --git a/smartsim/_core/mli/comm/channel/dragon_channel.py b/smartsim/_core/mli/comm/channel/dragon_channel.py
new file mode 100644
index 0000000000..110f19258a
--- /dev/null
+++ b/smartsim/_core/mli/comm/channel/dragon_channel.py
@@ -0,0 +1,127 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +import dragon.channels as dch + +import smartsim._core.mli.comm.channel.channel as cch +import smartsim._core.mli.comm.channel.dragon_util as drg_util +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class DragonCommChannel(cch.CommChannelBase): + """Passes messages by writing to a Dragon channel.""" + + def __init__(self, channel: "dch.Channel") -> None: + """Initialize the DragonCommChannel instance. + + :param channel: A channel to use for communications + """ + descriptor = drg_util.channel_to_descriptor(channel) + super().__init__(descriptor) + self._channel = channel + """The underlying dragon channel used by this CommChannel for communications""" + + @property + def channel(self) -> "dch.Channel": + """The underlying communication channel. + + :returns: The channel + """ + return self._channel + + def send(self, value: bytes, timeout: float = 0.001) -> None: + """Send a message through the underlying communication channel. + + :param value: The value to send + :param timeout: Maximum time to wait (in seconds) for messages to send + :raises SmartSimError: If sending message fails + """ + try: + with self._channel.sendh(timeout=timeout) as sendh: + sendh.send_bytes(value, blocking=False) + logger.debug(f"DragonCommChannel {self.descriptor} sent message") + except Exception as e: + raise SmartSimError( + f"Error sending via DragonCommChannel {self.descriptor}" + ) from e + + def recv(self, timeout: float = 0.001) -> t.List[bytes]: + """Receives message(s) through the underlying communication channel. + + :param timeout: Maximum time to wait (in seconds) for messages to arrive + :returns: The received message(s) + """ + with self._channel.recvh(timeout=timeout) as recvh: + messages: t.List[bytes] = [] + + try: + message_bytes = recvh.recv_bytes(timeout=timeout) + messages.append(message_bytes) + logger.debug(f"DragonCommChannel {self.descriptor} received message") + except dch.ChannelEmpty: + # emptied the queue, ok to swallow this ex + logger.debug(f"DragonCommChannel exhausted: {self.descriptor}") + except dch.ChannelRecvTimeout: + logger.debug(f"Timeout exceeded on channel.recv: {self.descriptor}") + + return messages + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "DragonCommChannel": + """A factory method that creates an instance from a descriptor string. + + :param descriptor: The descriptor that uniquely identifies the resource. 
+
+        :returns: An attached DragonCommChannel
+        :raises SmartSimError: If creation of comm channel fails
+        """
+        try:
+            channel = drg_util.descriptor_to_channel(descriptor)
+            return DragonCommChannel(channel)
+        except Exception as ex:
+            raise SmartSimError(
+                f"Failed to create dragon comm channel: {descriptor}"
+            ) from ex
+
+    @classmethod
+    def from_local(cls, _descriptor: t.Optional[str] = None) -> "DragonCommChannel":
+        """A factory method that creates a local channel instance.
+
+        :param _descriptor: Unused placeholder
+        :returns: An attached DragonCommChannel"""
+        try:
+            channel = drg_util.create_local()
+            return DragonCommChannel(channel)
+        except Exception:
+            logger.error("Failed to create local dragon comm channel", exc_info=True)
+            raise
diff --git a/smartsim/_core/mli/comm/channel/dragon_fli.py b/smartsim/_core/mli/comm/channel/dragon_fli.py
new file mode 100644
index 0000000000..5fb0790a84
--- /dev/null
+++ b/smartsim/_core/mli/comm/channel/dragon_fli.py
@@ -0,0 +1,158 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# isort: off
+from dragon import fli
+from dragon.channels import Channel
+
+# isort: on
+
+import typing as t
+
+import smartsim._core.mli.comm.channel.channel as cch
+import smartsim._core.mli.comm.channel.dragon_util as drg_util
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class DragonFLIChannel(cch.CommChannelBase):
+    """Passes messages by writing to a Dragon FLI Channel."""
+
+    def __init__(
+        self,
+        fli_: fli.FLInterface,
+        buffer_size: int = drg_util.DEFAULT_CHANNEL_BUFFER_SIZE,
+    ) -> None:
+        """Initialize the DragonFLIChannel instance.
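+
+        Construction sketch (assumes a running dragon runtime; `main_fli` is
+        a hypothetical, previously created `FLInterface`):
+
+        .. code-block:: python
+
+            channel = DragonFLIChannel(main_fli)
+            channel.send(b"payload", timeout=0.5)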
+
+        :param fli_: The FLInterface to use as the underlying communications channel
+        :param buffer_size: Maximum number of sent messages that can be buffered
+        """
+        descriptor = drg_util.channel_to_descriptor(fli_)
+        super().__init__(descriptor)
+
+        self._channel: t.Optional["Channel"] = None
+        """The underlying dragon Channel used by a sender-side DragonFLIChannel
+        to attach to the main FLI channel"""
+
+        self._fli = fli_
+        """The underlying dragon FLInterface used by this CommChannel for communications"""
+        self._buffer_size: int = buffer_size
+        """Maximum number of messages that can be buffered before sending"""
+
+    def send(self, value: bytes, timeout: float = 0.001) -> None:
+        """Send a message through the underlying communication channel.
+
+        :param value: The value to send
+        :param timeout: Maximum time to wait (in seconds) for messages to send
+        :raises SmartSimError: If sending message fails
+        """
+        try:
+            if self._channel is None:
+                self._channel = drg_util.create_local(self._buffer_size)
+
+            with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
+                sendh.send_bytes(value, timeout=timeout)
+                logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
+        except Exception as e:
+            self._channel = None
+            raise SmartSimError(
+                f"Error sending via DragonFLIChannel {self.descriptor}"
+            ) from e
+
+    def send_multiple(
+        self,
+        values: t.Sequence[bytes],
+        timeout: float = 0.001,
+    ) -> None:
+        """Send multiple messages through the underlying communication channel
+        using a single stream channel.
+
+        :param values: The values to send
+        :param timeout: Maximum time to wait (in seconds) for messages to send
+        :raises SmartSimError: If sending messages fails
+        """
+        try:
+            if self._channel is None:
+                self._channel = drg_util.create_local(self._buffer_size)
+
+            with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
+                for value in values:
+                    sendh.send_bytes(value)
+                logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
+        except Exception as e:
+            self._channel = None
+            raise SmartSimError(
+                f"Error sending via DragonFLIChannel {self.descriptor} {e}"
+            ) from e
+
+    def recv(self, timeout: float = 0.001) -> t.List[bytes]:
+        """Receives message(s) through the underlying communication channel.
+
+        :param timeout: Maximum time to wait (in seconds) for messages to arrive
+        :returns: The received message(s)
+        :raises SmartSimError: If receiving message(s) fails
+        """
+        messages = []
+        eot = False
+        with self._fli.recvh(timeout=timeout) as recvh:
+            while not eot:
+                try:
+                    message, _ = recvh.recv_bytes(timeout=timeout)
+                    messages.append(message)
+                    logger.debug(f"DragonFLIChannel {self.descriptor} received message")
+                except fli.FLIEOT:
+                    eot = True
+                    logger.debug(f"DragonFLIChannel exhausted: {self.descriptor}")
+                except Exception as e:
+                    raise SmartSimError(
+                        f"Error receiving messages: DragonFLIChannel {self.descriptor}"
+                    ) from e
+        return messages
+
+    @classmethod
+    def from_descriptor(
+        cls,
+        descriptor: str,
+    ) -> "DragonFLIChannel":
+        """A factory method that creates an instance from a descriptor string.
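+
+        Illustrative round trip (assumes `existing` is an attached
+        DragonFLIChannel in another process):
+
+        .. code-block:: python
+
+            attached = DragonFLIChannel.from_descriptor(existing.descriptor)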
+ + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached DragonFLIChannel + :raises SmartSimError: If creation of DragonFLIChannel fails + :raises ValueError: If the descriptor is invalid + """ + if not descriptor: + raise ValueError("Invalid descriptor provided") + + try: + return DragonFLIChannel(fli_=drg_util.descriptor_to_fli(descriptor)) + except Exception as e: + raise SmartSimError( + f"Error while creating DragonFLIChannel: {descriptor}" + ) from e diff --git a/smartsim/_core/mli/comm/channel/dragon_util.py b/smartsim/_core/mli/comm/channel/dragon_util.py new file mode 100644 index 0000000000..8517979ec4 --- /dev/null +++ b/smartsim/_core/mli/comm/channel/dragon_util.py @@ -0,0 +1,131 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import binascii +import typing as t + +import dragon.channels as dch +import dragon.fli as fli +import dragon.managed_memory as dm + +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + +DEFAULT_CHANNEL_BUFFER_SIZE = 500 +"""Maximum number of messages that can be buffered. DragonCommChannel will +raise an exception if no clients consume messages before the buffer is filled.""" + +LAST_OFFSET = 0 +"""The last offset used to create a local channel. This is used to avoid +unnecessary retries when creating a local channel.""" + + +def channel_to_descriptor(channel: t.Union[dch.Channel, fli.FLInterface]) -> str: + """Convert a dragon channel to a descriptor string. + + :param channel: The dragon channel to convert + :returns: The descriptor string + :raises ValueError: If a dragon channel is not provided + """ + if channel is None: + raise ValueError("Channel is not available to create a descriptor") + + serialized_ch = channel.serialize() + return base64.b64encode(serialized_ch).decode("utf-8") + + +def pool_to_descriptor(pool: dm.MemoryPool) -> str: + """Convert a dragon memory pool to a descriptor string. 
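+
+    A short sketch of intended use (assumes `pool` is an allocated dragon
+    `MemoryPool` instance):
+
+    .. code-block:: python
+
+        descriptor = pool_to_descriptor(pool)
+        # the base64 string can now be shared with other processes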
+
+    :param pool: The memory pool to convert
+    :returns: The descriptor string
+    :raises ValueError: If a memory pool is not provided
+    """
+    if pool is None:
+        raise ValueError("Memory pool is not available to create a descriptor")
+
+    serialized_pool = pool.serialize()
+    return base64.b64encode(serialized_pool).decode("utf-8")
+
+
+def descriptor_to_fli(descriptor: str) -> "fli.FLInterface":
+    """Create and attach a new FLI instance given
+    the string-encoded descriptor.
+
+    :param descriptor: The descriptor of an FLI to attach to
+    :returns: The attached dragon FLI
+    :raises ValueError: If the descriptor is empty or incorrectly formatted
+    :raises SmartSimError: If attachment using the descriptor fails
+    """
+    if len(descriptor) < 1:
+        raise ValueError("Descriptors may not be empty")
+
+    try:
+        encoded = descriptor.encode("utf-8")
+        descriptor_ = base64.b64decode(encoded)
+        return fli.FLInterface.attach(descriptor_)
+    except binascii.Error as ex:
+        raise ValueError("The descriptor was not properly base64 encoded") from ex
+    except fli.DragonFLIError as ex:
+        raise SmartSimError("The descriptor did not address an available FLI") from ex
+
+
+def descriptor_to_channel(descriptor: str) -> dch.Channel:
+    """Create and attach a new Channel instance given
+    the string-encoded descriptor.
+
+    :param descriptor: The descriptor of a channel to attach to
+    :returns: The attached dragon Channel
+    :raises ValueError: If the descriptor is empty or incorrectly formatted
+    :raises SmartSimError: If attachment using the descriptor fails
+    """
+    if len(descriptor) < 1:
+        raise ValueError("Descriptors may not be empty")
+
+    try:
+        encoded = descriptor.encode("utf-8")
+        descriptor_ = base64.b64decode(encoded)
+        return dch.Channel.attach(descriptor_)
+    except binascii.Error as ex:
+        raise ValueError("The descriptor was not properly base64 encoded") from ex
+    except dch.ChannelError as ex:
+        raise SmartSimError(
+            "The descriptor did not address an available channel"
+        ) from ex
+
+
+def create_local(_capacity: int = 0) -> dch.Channel:
+    """Creates a Channel attached to the local memory pool. Replacement for
+    direct calls to `dch.Channel.make_process_local()` intended to allow
+    supplying a channel capacity.
+
+    :param _capacity: The requested channel capacity; currently unused by this
+    implementation, which always creates a default process-local channel
+    :returns: The instantiated channel
+    """
+    # NOTE: _capacity is accepted for API compatibility but is not yet applied
+    channel = dch.Channel.make_process_local()
+    return channel
diff --git a/smartsim/_core/mli/infrastructure/__init__.py b/smartsim/_core/mli/infrastructure/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/infrastructure/comm/__init__.py b/smartsim/_core/mli/infrastructure/comm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/infrastructure/comm/broadcaster.py b/smartsim/_core/mli/infrastructure/comm/broadcaster.py
new file mode 100644
index 0000000000..56dcf549f7
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/comm/broadcaster.py
@@ -0,0 +1,239 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2.
Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+import uuid
+from collections import defaultdict, deque
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim._core.mli.infrastructure.comm.event import EventBase
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class BroadcastResult(t.NamedTuple):
+    """Contains summary details about a broadcast."""
+
+    num_sent: int
+    """The total number of messages delivered across all consumers"""
+    num_failed: int
+    """The total number of messages not delivered across all consumers"""
+
+
+class EventBroadcaster:
+    """Performs fan-out publishing of system events."""
+
+    def __init__(
+        self,
+        backbone: BackboneFeatureStore,
+        channel_factory: t.Optional[t.Callable[[str], CommChannelBase]] = None,
+        name: t.Optional[str] = None,
+    ) -> None:
+        """Initialize the EventBroadcaster instance.
+
+        :param backbone: The MLI backbone feature store
+        :param channel_factory: Factory method to construct new channel instances
+        :param name: A user-friendly name for logging. If not provided, an
+        auto-generated GUID will be used
+        """
+        self._backbone = backbone
+        """The backbone feature store used to retrieve consumer descriptors"""
+        self._channel_factory = channel_factory
+        """A factory method used to instantiate channels from descriptors"""
+        self._channel_cache: t.Dict[str, t.Optional[CommChannelBase]] = defaultdict(
+            lambda: None
+        )
+        """A mapping of instantiated channels that can be re-used. Automatically
+        calls the channel factory if a descriptor is not already in the collection"""
+        self._event_buffer: t.Deque[EventBase] = deque()
+        """A buffer for storing events when a consumer list is not found"""
+        self._descriptors: t.Set[str]
+        """Stores the most recent list of broadcast consumers. Updated automatically
+        on each broadcast"""
+        self._name = name or str(uuid.uuid4())
+        """A unique identifier assigned to the broadcaster for logging"""
+
+    @property
+    def name(self) -> str:
+        """The friendly name assigned to the broadcaster.
+
+        :returns: The broadcaster name if one is assigned, otherwise a unique
+        id assigned by the system.
+        """
+        return self._name
+
+    @property
+    def num_buffered(self) -> int:
+        """Return the number of events currently buffered to send.
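+
+        Illustrative check (assumes `broadcaster` is an EventBroadcaster
+        holding queued events):
+
+        .. code-block:: python
+
+            if broadcaster.num_buffered:
+                logger.debug(f"{broadcaster.num_buffered} events pending")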
+ + :returns: Number of buffered events + """ + return len(self._event_buffer) + + def _save_to_buffer(self, event: EventBase) -> None: + """Places the event in the buffer to be sent once a consumer + list is available. + + :param event: The event to buffer + :raises ValueError: If the event cannot be buffered + """ + try: + self._event_buffer.append(event) + logger.debug(f"Buffered event {event=}") + except Exception as ex: + raise ValueError( + f"Unable to buffer event {event} in broadcaster {self.name}" + ) from ex + + def _log_broadcast_start(self) -> None: + """Logs broadcast statistics.""" + num_events = len(self._event_buffer) + num_copies = len(self._descriptors) + logger.debug( + f"Broadcast {num_events} events to {num_copies} consumers from {self.name}" + ) + + def _prune_unused_consumers(self) -> None: + """Performs maintenance on the channel cache by pruning any channel + that has been removed from the consumers list.""" + active_consumers = set(self._descriptors) + current_channels = set(self._channel_cache.keys()) + + # find any cached channels that are now unused + inactive_channels = current_channels.difference(active_consumers) + new_channels = active_consumers.difference(current_channels) + + for descriptor in inactive_channels: + self._channel_cache.pop(descriptor) + + logger.debug( + f"Pruning {len(inactive_channels)} stale consumers and" + f" found {len(new_channels)} new channels for {self.name}" + ) + + def _get_comm_channel(self, descriptor: str) -> CommChannelBase: + """Helper method to build and cache a comm channel. + + :param descriptor: The descriptor to pass to the channel factory + :returns: The instantiated channel + :raises SmartSimError: If the channel fails to attach + """ + comm_channel = self._channel_cache[descriptor] + if comm_channel is not None: + return comm_channel + + if self._channel_factory is None: + raise SmartSimError("No channel factory provided for consumers") + + try: + channel = self._channel_factory(descriptor) + self._channel_cache[descriptor] = channel + return channel + except Exception as ex: + msg = f"Unable to construct channel with descriptor: {descriptor}" + logger.error(msg, exc_info=True) + raise SmartSimError(msg) from ex + + def _get_next_event(self) -> t.Optional[EventBase]: + """Pop the next event to be sent from the queue. + + :returns: The next event to send if any events are enqueued, otherwise `None`. + """ + try: + return self._event_buffer.popleft() + except IndexError: + logger.debug(f"Broadcast buffer exhausted for {self.name}") + + return None + + def _broadcast(self, timeout: float = 0.001) -> BroadcastResult: + """Broadcasts all buffered events to registered event consumers. 
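+
+        A sketch of the typical entry point (callers use the public `send`
+        method, which buffers the event and then triggers this broadcast;
+        the event and key shown are illustrative):
+
+        .. code-block:: python
+
+            broadcaster = EventBroadcaster(
+                backbone, channel_factory=DragonCommChannel.from_descriptor
+            )
+            num_sent = broadcaster.send(
+                OnWriteFeatureStore("my-app", backbone.descriptor, "model-key")
+            )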
+ + :param timeout: Maximum time to wait (in seconds) for messages to send + :returns: BroadcastResult containing the number of messages that were + successfully and unsuccessfully sent for all consumers + :raises SmartSimError: If the channel fails to attach or broadcasting fails + """ + # allow descriptors to be empty since events are buffered + self._descriptors = set(x for x in self._backbone.notification_channels if x) + if not self._descriptors: + msg = f"No event consumers are registered for {self.name}" + logger.warning(msg) + return BroadcastResult(0, 0) + + self._prune_unused_consumers() + self._log_broadcast_start() + + num_listeners = len(self._descriptors) + num_sent = 0 + num_failures = 0 + + # send each event to every consumer + while event := self._get_next_event(): + logger.debug(f"Broadcasting {event=} to {num_listeners} listeners") + event_bytes = bytes(event) + + for i, descriptor in enumerate(self._descriptors): + comm_channel = self._get_comm_channel(descriptor) + + try: + comm_channel.send(event_bytes, timeout) + num_sent += 1 + except Exception: + msg = ( + f"Broadcast {i+1}/{num_listeners} for event {event.uid} to " + f"channel {descriptor} from {self.name} failed." + ) + logger.exception(msg) + num_failures += 1 + + return BroadcastResult(num_sent, num_failures) + + def send(self, event: EventBase, timeout: float = 0.001) -> int: + """Implementation of `send` method of the `EventPublisher` protocol. Publishes + the supplied event to all registered broadcast consumers. + + :param event: An event to publish + :param timeout: Maximum time to wait (in seconds) for messages to send + :returns: The total number of events successfully published to consumers + :raises ValueError: If event serialization fails + :raises AttributeError: If event cannot be serialized + :raises KeyError: If channel fails to attach using registered descriptors + :raises SmartSimError: If any unexpected error occurs during send + """ + try: + self._save_to_buffer(event) + result = self._broadcast(timeout) + return result.num_sent + except (KeyError, ValueError, AttributeError, SmartSimError): + raise + except Exception as ex: + raise SmartSimError("An unexpected failure occurred while sending") from ex diff --git a/smartsim/_core/mli/infrastructure/comm/consumer.py b/smartsim/_core/mli/infrastructure/comm/consumer.py new file mode 100644 index 0000000000..08b5c47852 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/comm/consumer.py @@ -0,0 +1,281 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pickle
+import time
+import typing as t
+import uuid
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.infrastructure.comm.event import (
+    EventBase,
+    OnCreateConsumer,
+    OnRemoveConsumer,
+    OnShutdownRequested,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class EventConsumer:
+    """Reads system events published to a communications channel."""
+
+    _BACKBONE_WAIT_TIMEOUT = 10.0
+    """Maximum time (in seconds) to wait for the backbone to register the consumer"""
+
+    def __init__(
+        self,
+        comm_channel: CommChannelBase,
+        backbone: BackboneFeatureStore,
+        filters: t.Optional[t.List[str]] = None,
+        name: t.Optional[str] = None,
+        event_handler: t.Optional[t.Callable[[EventBase], None]] = None,
+    ) -> None:
+        """Initialize the EventConsumer instance.
+
+        :param comm_channel: Communications channel to listen to for events
+        :param backbone: The MLI backbone feature store
+        :param filters: A list of event types to deliver. When empty, all
+        events will be delivered
+        :param name: A user-friendly name for logging. If not provided, an
+        auto-generated GUID will be used
+        :param event_handler: An optional callback invoked for each event
+        that passes the configured filters
+        """
+        self._comm_channel = comm_channel
+        """The comm channel used by the consumer to receive messages. The channel
+        descriptor will be published for senders to discover."""
+        self._backbone = backbone
+        """The backbone instance used to bootstrap the instance. The EventConsumer
+        uses the backbone to discover where it can publish its descriptor."""
+        self._global_filters = filters or []
+        """A set of global filters to apply to incoming events. Global filters are
+        combined with per-call filters. Filters act as an allow-list."""
+        self._name = name or str(uuid.uuid4())
+        """User-friendly name assigned to a consumer for logging. Automatically
+        assigned if not provided."""
+        self._event_handler = event_handler
+        """The function that should be executed when an event that passes
+        the filters is received."""
+        self.listening = True
+        """Flag indicating that the consumer is currently listening for new
+        events. Setting this flag to `False` will cause any active calls to
+        `listen` to terminate."""
+
+    @property
+    def descriptor(self) -> str:
+        """The descriptor of the underlying comm channel.
+
+        :returns: The comm channel descriptor"""
+        return self._comm_channel.descriptor
+
+    @property
+    def name(self) -> str:
+        """The friendly name assigned to the consumer.
+
+        :returns: The consumer name if one is assigned, otherwise a unique
+        id assigned by the system.
+        """
+        return self._name
+
+    def recv(
+        self,
+        filters: t.Optional[t.List[str]] = None,
+        timeout: float = 0.001,
+        batch_timeout: float = 1.0,
+    ) -> t.List[EventBase]:
+        """Receives available published event(s).
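+
+        Example poll (assumes `consumer` is a registered EventConsumer; the
+        filter value is the category constant defined in the event module):
+
+        .. code-block:: python
+
+            events = consumer.recv(
+                filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+                timeout=0.001,
+                batch_timeout=1.0,
+            )
+            for event in events:
+                print(event)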
+
+        :param filters: Additional filters to add to the global filters configured
+        on the EventConsumer instance
+        :param timeout: Maximum time to wait for a single message to arrive
+        :param batch_timeout: Maximum time to wait for messages to arrive; allows
+        multiple batches to be retrieved in one call to `recv`
+        :returns: A list of events that pass any configured filters
+        :raises ValueError: If a positive, non-zero value is not provided for the
+        timeout or batch_timeout.
+        """
+        if filters is None:
+            filters = []
+
+        if timeout is not None and timeout <= 0:
+            raise ValueError("request timeout must be a non-zero, positive value")
+
+        if batch_timeout is not None and batch_timeout <= 0:
+            raise ValueError("batch_timeout must be a non-zero, positive value")
+
+        filter_set = {*self._global_filters, *filters}
+        all_message_bytes: t.List[bytes] = []
+
+        # firehose as many messages as possible within the batch_timeout
+        start_at = time.time()
+        remaining = batch_timeout
+
+        batch_message_bytes = self._comm_channel.recv(timeout=timeout)
+        while batch_message_bytes:
+            all_message_bytes.extend(batch_message_bytes)
+            batch_message_bytes = []
+
+            # avoid getting stuck indefinitely waiting for the channel
+            elapsed = time.time() - start_at
+            remaining = batch_timeout - elapsed
+
+            if remaining > 0:
+                batch_message_bytes = self._comm_channel.recv(timeout=timeout)
+
+        events_received: t.List[EventBase] = []
+
+        # Timeout elapsed or no messages received - return the empty list
+        if not all_message_bytes:
+            return events_received
+
+        for message in all_message_bytes:
+            # skip empty messages that would fail to decode
+            if not message:
+                continue
+
+            event = pickle.loads(message)
+            if not event:
+                logger.warning(f"Consumer {self.name} is unable to unpickle message")
+                continue
+
+            # skip events that don't pass a filter
+            if filter_set and event.category not in filter_set:
+                continue
+
+            events_received.append(event)
+
+        return events_received
+
+    def _send_to_registrar(self, event: EventBase) -> None:
+        """Send an event directly to the registrar listener.
+
+        :param event: The event to send to the registrar
+        """
+        registrar_key = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER
+        config = self._backbone.wait_for([registrar_key], self._BACKBONE_WAIT_TIMEOUT)
+        registrar_descriptor = str(config.get(registrar_key, None))
+
+        if not registrar_descriptor:
+            logger.warning(
+                f"Unable to send {event.category} from {self.name}. "
+                "No registrar channel found."
+            )
+            return
+
+        logger.debug(f"Sending {event.category} from {self.name}")
+
+        registrar_channel = DragonCommChannel.from_descriptor(registrar_descriptor)
+        registrar_channel.send(bytes(event), timeout=1.0)
+
+        logger.debug(f"{event.category} from {self.name} sent")
+
+    def register(self) -> None:
+        """Send an event to register this consumer as a listener."""
+        descriptor = self._comm_channel.descriptor
+        event = OnCreateConsumer(self.name, descriptor, self._global_filters)
+
+        self._send_to_registrar(event)
+
+    def unregister(self) -> None:
+        """Send an event to un-register this consumer as a listener."""
+        descriptor = self._comm_channel.descriptor
+        event = OnRemoveConsumer(self.name, descriptor)
+
+        self._send_to_registrar(event)
+
+    def _on_handler_missing(self, event: EventBase) -> None:
+        """A "dead letter" event handler that is called to perform
+        processing on events before they're discarded.
+
+        :param event: The event to handle
+        """
+        logger.warning(
+            "No event handler is registered in consumer "
+            f"{self.name}. Discarding {event=}"
Discarding {event=}" + ) + + def listen_once(self, timeout: float = 0.001, batch_timeout: float = 1.0) -> None: + """Receives messages for the consumer a single time. Delivers + all messages that pass the consumer filters. Shutdown requests + are handled by a default event handler. + + + NOTE: Executes a single batch-retrieval to receive the maximum + number of messages available under batch timeout. To continually + listen, use `listen` in a non-blocking thread/process + + :param timeout: Maximum time to wait (in seconds) for a message to arrive + :param timeout: Maximum time to wait (in seconds) for a batch to arrive + """ + logger.info( + f"Consumer {self.name} listening with {timeout} second timeout" + f" on channel {self._comm_channel.descriptor}" + ) + + if not self._event_handler: + logger.info("Unable to handle messages. No event handler is registered.") + + incoming_messages = self.recv(timeout=timeout, batch_timeout=batch_timeout) + + if not incoming_messages: + logger.info(f"Consumer {self.name} received empty message list") + + for message in incoming_messages: + logger.info(f"Consumer {self.name} is handling event {message=}") + self._handle_shutdown(message) + + if self._event_handler: + self._event_handler(message) + else: + self._on_handler_missing(message) + + def _handle_shutdown(self, event: EventBase) -> bool: + """Handles shutdown requests sent to the consumer by setting the + `self.listener` property to `False`. + + :param event: The event to handle + :returns: A bool indicating if the event was a shutdown request + """ + if isinstance(event, OnShutdownRequested): + logger.debug(f"Shutdown requested from: {event.source}") + self.listening = False + return True + return False + + def listen(self, timeout: float = 0.001, batch_timeout: float = 1.0) -> None: + """Receives messages for the consumer until a shutdown request is received. + + :param timeout: Maximum time to wait (in seconds) for a message to arrive + :param batch_timeout: Maximum time to wait (in seconds) for a batch to arrive + """ + + logger.debug(f"Consumer {self.name} is now listening for events.") + + while self.listening: + self.listen_once(timeout, batch_timeout) + + logger.debug(f"Consumer {self.name} is no longer listening.") diff --git a/smartsim/_core/mli/infrastructure/comm/event.py b/smartsim/_core/mli/infrastructure/comm/event.py new file mode 100644 index 0000000000..ccef9f9b86 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/comm/event.py @@ -0,0 +1,162 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pickle +import typing as t +import uuid +from dataclasses import dataclass, field + +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +@dataclass +class EventBase: + """Core API for an event.""" + + category: str + """Unique category name for an event class""" + source: str + """A unique identifier for the publisher of the event""" + uid: str = field(default_factory=lambda: str(uuid.uuid4())) + """A unique identifier for this event""" + + def __bytes__(self) -> bytes: + """Default conversion to bytes for an event required to publish + messages using byte-oriented communication channels. + + :returns: This entity encoded as bytes""" + return pickle.dumps(self) + + def __str__(self) -> str: + """Convert the event to a string. + + :returns: A string representation of this instance""" + return f"{self.uid}|{self.category}" + + +class OnShutdownRequested(EventBase): + """Publish this event to trigger the listener to shutdown.""" + + SHUTDOWN: t.ClassVar[str] = "consumer-unregister" + """Unique category name for an event raised when a new consumer is unregistered""" + + def __init__(self, source: str) -> None: + """Initialize the event instance. + + :param source: A unique identifier for the publisher of the event + creating the event + """ + super().__init__(self.SHUTDOWN, source) + + +class OnCreateConsumer(EventBase): + """Publish this event when a new event consumer registration is required.""" + + descriptor: str + """Descriptor of the comm channel exposed by the consumer""" + filters: t.List[str] = field(default_factory=list) + """The collection of filters indicating messages of interest to this consumer""" + + CONSUMER_CREATED: t.ClassVar[str] = "consumer-created" + """Unique category name for an event raised when a new consumer is registered""" + + def __init__(self, source: str, descriptor: str, filters: t.Sequence[str]) -> None: + """Initialize the event instance. + + :param source: A unique identifier for the publisher of the event + :param descriptor: Descriptor of the comm channel exposed by the consumer + :param filters: Collection of filters indicating messages of interest + """ + super().__init__(self.CONSUMER_CREATED, source) + self.descriptor = descriptor + self.filters = list(filters) + + def __str__(self) -> str: + """Convert the event to a string. + + :returns: A string representation of this instance + """ + _filters = ",".join(self.filters) + return f"{str(super())}|{self.descriptor}|{_filters}" + + +class OnRemoveConsumer(EventBase): + """Publish this event when a consumer is shutting down and + should be removed from notification lists.""" + + descriptor: str + """Descriptor of the comm channel exposed by the consumer""" + + CONSUMER_REMOVED: t.ClassVar[str] = "consumer-removed" + """Unique category name for an event raised when a new consumer is unregistered""" + + def __init__(self, source: str, descriptor: str) -> None: + """Initialize the OnRemoveConsumer event. 
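A quick round-trip showing how these event classes travel over byte-oriented channels (a minimal sketch; the comm channel itself is elided and the source/descriptor strings are illustrative):

    import pickle

    from smartsim._core.mli.infrastructure.comm.event import OnCreateConsumer

    # EventBase.__bytes__ pickles the instance, so any channel that moves
    # raw bytes can carry an event
    event = OnCreateConsumer("my-source", "channel-descriptor", [])
    payload = bytes(event)

    # the receiving consumer unpickles the payload back into an event
    restored = pickle.loads(payload)
    assert restored.category == OnCreateConsumer.CONSUMER_CREATED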
+ + :param source: A unique identifier for the publisher of the event + :param descriptor: Descriptor of the comm channel exposed by the consumer + """ + super().__init__(self.CONSUMER_REMOVED, source) + self.descriptor = descriptor + + def __str__(self) -> str: + """Convert the event to a string. + + :returns: A string representation of this instance + """ + return f"{str(super())}|{self.descriptor}" + + +class OnWriteFeatureStore(EventBase): + """Publish this event when a feature store key is written.""" + + descriptor: str + """The descriptor of the feature store where the write occurred""" + key: str + """The key identifying where the write occurred""" + + FEATURE_STORE_WRITTEN: str = "feature-store-written" + """Event category for an event raised when a feature store key is written""" + + def __init__(self, source: str, descriptor: str, key: str) -> None: + """Initialize the OnWriteFeatureStore event. + + :param source: A unique identifier for the publisher of the event + :param descriptor: The descriptor of the feature store where the write occurred + :param key: The key identifying where the write occurred + """ + super().__init__(self.FEATURE_STORE_WRITTEN, source) + self.descriptor = descriptor + self.key = key + + def __str__(self) -> str: + """Convert the event to a string. + + :returns: A string representation of this instance + """ + return f"{str(super())}|{self.descriptor}|{self.key}" diff --git a/smartsim/_core/mli/infrastructure/comm/producer.py b/smartsim/_core/mli/infrastructure/comm/producer.py new file mode 100644 index 0000000000..2d8a7c14ad --- /dev/null +++ b/smartsim/_core/mli/infrastructure/comm/producer.py @@ -0,0 +1,44 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +from smartsim._core.mli.infrastructure.comm.event import EventBase +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class EventProducer(t.Protocol): + """Core API of a class that publishes events.""" + + def send(self, event: EventBase, timeout: float = 0.001) -> int: + """Send an event using the configured comm channel. 
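`EventProducer` is a `typing.Protocol`, so any class with a structurally matching `send` method satisfies it without inheriting from it. A minimal conforming sketch (`InMemoryProducer` is hypothetical, for illustration only):

    import typing as t

    from smartsim._core.mli.infrastructure.comm.event import EventBase

    class InMemoryProducer:
        # collects events in memory instead of writing to a comm channel
        def __init__(self) -> None:
            self.sent: t.List[EventBase] = []

        def send(self, event: EventBase, timeout: float = 0.001) -> int:
            # a real producer would write bytes(event) to its channel here
            self.sent.append(event)
            return 1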
+ + :param event: The event to send + :param timeout: Maximum time to wait (in seconds) for messages to send + :returns: The number of messages that were sent + """ diff --git a/smartsim/_core/mli/infrastructure/control/__init__.py b/smartsim/_core/mli/infrastructure/control/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/infrastructure/control/device_manager.py b/smartsim/_core/mli/infrastructure/control/device_manager.py new file mode 100644 index 0000000000..9334971f8c --- /dev/null +++ b/smartsim/_core/mli/infrastructure/control/device_manager.py @@ -0,0 +1,166 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t +from contextlib import _GeneratorContextManager, contextmanager + +from .....log import get_logger +from ..storage.feature_store import FeatureStore +from ..worker.worker import MachineLearningWorkerBase, RequestBatch + +logger = get_logger(__name__) + + +class WorkerDevice: + def __init__(self, name: str) -> None: + """Wrapper around a device to keep track of loaded Models and availability. + + :param name: Name used by the toolkit to identify this device, e.g. ``cuda:0`` + """ + self._name = name + """The name used by the toolkit to identify this device""" + self._models: dict[str, t.Any] = {} + """Dict of keys to models which are loaded on this device""" + + @property + def name(self) -> str: + """The identifier of the device represented by this object + + :returns: Name used by the toolkit to identify this device + """ + return self._name + + def add_model(self, key: str, model: t.Any) -> None: + """Add a reference to a model loaded on this device and assign it a key. + + :param key: The key under which the model is saved + :param model: The model which is added + """ + self._models[key] = model + + def remove_model(self, key: str) -> None: + """Remove the reference to a model loaded on this device. 
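The `WorkerDevice` above behaves like a small keyed model cache. A usage sketch (assumes an environment where the MLI modules import cleanly; the model object is a stand-in):

    from smartsim._core.mli.infrastructure.control.device_manager import WorkerDevice

    device = WorkerDevice("cuda:0")
    device.add_model("model-key", object())   # any loaded model object

    if "model-key" in device:                 # __contains__
        model = device.get_model("model-key")

    # the context manager yields the device, evicting the model on exit
    with device.get(key_to_remove="model-key"):
        pass                                  # run inference here

    assert "model-key" not in device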
+
+        :param key: The key of the model to remove
+        :raises KeyError: If key does not exist for removal
+        """
+        try:
+            self._models.pop(key)
+        except KeyError:
+            logger.warning(f"An unknown key was requested for removal: {key}")
+            raise
+
+    def get_model(self, key: str) -> t.Any:
+        """Get the model corresponding to a given key.
+
+        :param key: The model key
+        :returns: The model for the given key
+        :raises KeyError: If key does not exist
+        """
+        try:
+            return self._models[key]
+        except KeyError:
+            logger.warning(f"An unknown key was requested: {key}")
+            raise
+
+    def __contains__(self, key: str) -> bool:
+        """Check if a model with the given key is available on the device.
+
+        :param key: The key of the model to check for existence
+        :returns: Whether the model is available on the device
+        """
+        return key in self._models
+
+    @contextmanager
+    def get(self, key_to_remove: t.Optional[str]) -> t.Iterator["WorkerDevice"]:
+        """Get the WorkerDevice generator and optionally remove a model on exit.
+
+        :param key_to_remove: The key of the model to optionally remove
+        :returns: WorkerDevice generator
+        """
+        yield self
+        if key_to_remove is not None:
+            self.remove_model(key_to_remove)
+
+
+class DeviceManager:
+    def __init__(self, device: WorkerDevice):
+        """An object to manage devices such as GPUs and CPUs.
+
+        The main goal of the ``DeviceManager`` is to ensure that
+        the managed device is ready to be used by a worker to
+        run a given model.
+
+        :param device: The managed device
+        """
+        self._device = device
+        """Device managed by this object"""
+
+    def _load_model_on_device(
+        self,
+        worker: MachineLearningWorkerBase,
+        batch: RequestBatch,
+        feature_stores: dict[str, FeatureStore],
+    ) -> None:
+        """Load the model needed to execute a batch on the managed device.
+
+        The model is loaded by the worker.
+
+        :param worker: The worker that loads the model
+        :param batch: The batch for which the model is needed
+        :param feature_stores: Feature stores where the model could be stored
+        """
+
+        model_bytes = worker.fetch_model(batch, feature_stores)
+        loaded_model = worker.load_model(batch, model_bytes, self._device.name)
+        self._device.add_model(batch.model_id.key, loaded_model.model)
+
+    def get_device(
+        self,
+        worker: MachineLearningWorkerBase,
+        batch: RequestBatch,
+        feature_stores: dict[str, FeatureStore],
+    ) -> _GeneratorContextManager[WorkerDevice]:
+        """Get the device managed by this object.
+
+        The model needed to run the batch of requests is
+        guaranteed to be available on the device.
+
+        :param worker: The worker that wants to access the device
+        :param batch: The batch of requests
+        :param feature_stores: The feature stores on which part of the
+        data needed by the requests may be stored
+        :returns: A generator yielding the device
+        """
+        model_in_request = batch.has_raw_model
+
+        # Load model if not already loaded, or
+        # because it is sent with the request
+        if model_in_request or batch.model_id.key not in self._device:
+            self._load_model_on_device(worker, batch, feature_stores)
+
+        key_to_remove = batch.model_id.key if model_in_request else None
+        return self._device.get(key_to_remove)
diff --git a/smartsim/_core/mli/infrastructure/control/error_handling.py b/smartsim/_core/mli/infrastructure/control/error_handling.py
new file mode 100644
index 0000000000..a75f533a37
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/control/error_handling.py
@@ -0,0 +1,78 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
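How a worker manager is expected to drive the `DeviceManager` defined above (a sketch; `worker`, `batch`, and `feature_stores` stand in for objects built elsewhere in the pipeline):

    from smartsim._core.mli.infrastructure.control.device_manager import (
        DeviceManager,
        WorkerDevice,
    )

    manager = DeviceManager(WorkerDevice("cuda:0"))

    # get_device loads the model onto the device if needed and yields the
    # ready device; request-supplied models are evicted when the block exits
    with manager.get_device(
        worker=worker, batch=batch, feature_stores=feature_stores
    ) as device:
        model = device.get_model(batch.model_id.key)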
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +from .....log import get_logger +from ...comm.channel.channel import CommChannelBase +from ...message_handler import MessageHandler +from ...mli_schemas.response.response_capnp import ResponseBuilder + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +logger = get_logger(__file__) + + +def build_failure_reply(status: "Status", message: str) -> ResponseBuilder: + """ + Builds a failure response message. + + :param status: Status enum + :param message: Status message + :returns: Failure response + """ + return MessageHandler.build_response( + status=status, + message=message, + result=None, + custom_attributes=None, + ) + + +def exception_handler( + exc: Exception, + reply_channel: t.Optional[CommChannelBase], + failure_message: t.Optional[str], +) -> None: + """ + Logs exceptions and sends a failure response. + + :param exc: The exception to be logged + :param reply_channel: The channel used to send replies + :param failure_message: Failure message to log and send back + """ + logger.exception(exc) + if reply_channel: + if failure_message is None: + failure_message = str(exc) + + serialized_resp = MessageHandler.serialize_response( + build_failure_reply("fail", failure_message) + ) + reply_channel.send(serialized_resp) + else: + logger.warning("Unable to notify client of error without a reply channel") diff --git a/smartsim/_core/mli/infrastructure/control/listener.py b/smartsim/_core/mli/infrastructure/control/listener.py new file mode 100644 index 0000000000..56a7b12d34 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/control/listener.py @@ -0,0 +1,352 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# isort: off +# pylint: disable=import-error +# pylint: disable=unused-import +import socket +import dragon + +# pylint: enable=unused-import +# pylint: enable=import-error +# isort: on + +import argparse +import multiprocessing as mp +import os +import sys +import typing as t + +from smartsim._core.entrypoints.service import Service +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer +from smartsim._core.mli.infrastructure.comm.event import ( + EventBase, + OnCreateConsumer, + OnRemoveConsumer, + OnShutdownRequested, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class ConsumerRegistrationListener(Service): + """A long-running service that manages the list of consumers receiving + events that are broadcast. It hosts handlers for adding and removing consumers + """ + + def __init__( + self, + backbone: BackboneFeatureStore, + timeout: float, + batch_timeout: float, + as_service: bool = False, + cooldown: int = 0, + health_check_frequency: float = 60.0, + ) -> None: + """Initialize the EventListener. 
+
+        :param backbone: The backbone feature store
+        :param timeout: Maximum time (in seconds) to allow a single recv request to wait
+        :param batch_timeout: Maximum time (in seconds) to allow a batch of receives to
+        continue to build
+        :param as_service: Specifies run-once or run-until-complete behavior of service
+        :param cooldown: Number of seconds to wait before shutting down after
+        shutdown criteria are met
+        :param health_check_frequency: Time (in seconds) between invocations of the
+        health check handler
+        """
+        super().__init__(
+            as_service, cooldown, health_check_frequency=health_check_frequency
+        )
+        self._timeout = timeout
+        """Maximum time (in seconds) to allow a single recv request to wait"""
+        self._batch_timeout = batch_timeout
+        """Maximum time (in seconds) to allow a batch of receives to
+        continue to build"""
+        self._consumer: t.Optional[EventConsumer] = None
+        """The event consumer that handles receiving events"""
+        self._backbone = backbone
+        """A standalone, system-created feature store used to share internal
+        information among MLI components"""
+
+    def _on_start(self) -> None:
+        """Called on initial entry into Service `execute` event loop before
+        `_on_iteration` is invoked."""
+        super()._on_start()
+        self._create_eventing()
+
+    def _on_shutdown(self) -> None:
+        """Release dragon resources. Called immediately after exiting
+        the main event loop during automatic shutdown."""
+        super()._on_shutdown()
+
+        if not self._consumer:
+            return
+
+        # remove descriptor for this listener from the backbone if it's there
+        if registered_consumer := self._backbone.backend_channel:
+            # if there is a descriptor in the backbone and it's still this listener
+            if registered_consumer == self._consumer.descriptor:
+                logger.info(
+                    f"Listener clearing backend consumer {self._consumer.name} "
+                    "from backbone"
+                )
+
+                # unregister this listener in the backbone
+                self._backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+                # TODO: need the channel to be cleaned up
+                # self._consumer._comm_channel._channel.destroy()
+
+    def _on_iteration(self) -> None:
+        """Check the attached consumer for new messages a single time and
+        dispatch any registration events that arrive to the event handlers."""
+
+        if self._consumer is None:
+            logger.info("Unable to listen. No consumer available.")
+            return
+
+        self._consumer.listen_once(self._timeout, self._batch_timeout)
+
+    def _can_shutdown(self) -> bool:
+        """Determines if the event consumer is ready to stop listening.
+
+        :returns: True when criteria to shutdown the service are met, False otherwise
+        """
+
+        if self._backbone is None:
+            logger.info("Listener must shutdown. No backbone attached")
+            return True
+
+        if self._consumer is None:
+            logger.info("Listener must shutdown. No consumer channel created")
+            return True
+
+        if not self._consumer.listening:
+            logger.info(
+                f"Listener can shutdown. Consumer `{self._consumer.name}` "
+                "is not listening"
+            )
+            return True
+
+        return False
+
+    def _on_unregister(self, event: OnRemoveConsumer) -> None:
+        """Event handler for updating the backbone when event consumers
+        are unregistered.
+ + :param event: The event that was received + """ + notify_list = set(self._backbone.notification_channels) + + # remove the descriptor specified in the event + if event.descriptor in notify_list: + logger.debug(f"Removing notify consumer: {event.descriptor}") + notify_list.remove(event.descriptor) + + # push the updated list back into the backbone + self._backbone.notification_channels = list(notify_list) + + def _on_register(self, event: OnCreateConsumer) -> None: + """Event handler for updating the backbone when new event consumers + are registered. + + :param event: The event that was received + """ + notify_list = set(self._backbone.notification_channels) + logger.debug(f"Adding notify consumer: {event.descriptor}") + notify_list.add(event.descriptor) + self._backbone.notification_channels = list(notify_list) + + def _on_event_received(self, event: EventBase) -> None: + """Primary event handler for the listener. Distributes events to + type-specific handlers. + + :param event: The event that was received + """ + if self._backbone is None: + logger.info("Unable to handle event. Backbone is missing.") + + if isinstance(event, OnCreateConsumer): + self._on_register(event) + elif isinstance(event, OnRemoveConsumer): + self._on_unregister(event) + else: + logger.info( + "Consumer registration listener received an " + f"unexpected event: {event=}" + ) + + def _on_health_check(self) -> None: + """Check if this consumer has been replaced by a new listener + and automatically trigger a shutdown. Invoked based on the + value of `self._health_check_frequency`.""" + super()._on_health_check() + + try: + logger.debug("Retrieving registered listener descriptor") + descriptor = self._backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + except KeyError: + descriptor = None + if self._consumer: + self._consumer.listening = False + + if self._consumer and descriptor != self._consumer.descriptor: + logger.warning( + f"Consumer `{self._consumer.name}` for `ConsumerRegistrationListener` " + "is no longer registered. It will automatically shut down." + ) + self._consumer.listening = False + + def _publish_consumer(self) -> None: + """Publish the registrar consumer descriptor to the backbone.""" + if self._consumer is None: + logger.warning("No registrar consumer descriptor available to publisher") + return + + logger.debug(f"Publishing {self._consumer.descriptor} to backbone") + self._backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = ( + self._consumer.descriptor + ) + + def _create_eventing(self) -> EventConsumer: + """ + Create an event publisher and event consumer for communicating with + other MLI resources. + + NOTE: the backbone must be initialized before connecting eventing clients. 
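The registration flow these handlers service, seen from the consumer's side (a sketch; `consumer` is an EventConsumer built as shown earlier in this patch):

    # register() publishes OnCreateConsumer; the listener's _on_register
    # handler then adds this consumer's descriptor to the backbone's
    # notification list
    consumer.register()

    # receive broadcast events until a shutdown request arrives
    consumer.listen()

    # unregister() publishes OnRemoveConsumer; _on_unregister removes the
    # descriptor from the notification list again
    consumer.unregister()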
+ + :returns: The newly created EventConsumer instance + :raises SmartSimError: If a listener channel cannot be created + """ + + if self._consumer: + return self._consumer + + logger.info("Creating event consumer") + + dragon_channel = create_local(500) + event_channel = DragonCommChannel(dragon_channel) + + if not event_channel.descriptor: + raise SmartSimError( + "Unable to generate the descriptor for the event channel" + ) + + self._consumer = EventConsumer( + event_channel, + self._backbone, + [ + OnCreateConsumer.CONSUMER_CREATED, + OnRemoveConsumer.CONSUMER_REMOVED, + OnShutdownRequested.SHUTDOWN, + ], + name=f"ConsumerRegistrar.{socket.gethostname()}", + event_handler=self._on_event_received, + ) + self._publish_consumer() + + logger.info( + f"Backend consumer `{self._consumer.name}` created: " + f"{self._consumer.descriptor}" + ) + + return self._consumer + + +def _create_parser() -> argparse.ArgumentParser: + """ + Create an argument parser that contains the arguments + required to start the listener as a new process: + + --timeout + --batch_timeout + + :returns: A configured parser + """ + arg_parser = argparse.ArgumentParser(prog="ConsumerRegistrarEventListener") + + arg_parser.add_argument("--timeout", type=float, default=1.0) + arg_parser.add_argument("--batch_timeout", type=float, default=1.0) + + return arg_parser + + +def _connect_backbone() -> t.Optional[BackboneFeatureStore]: + """ + Load the backbone by retrieving the descriptor from environment variables. + + :returns: The backbone feature store + :raises SmartSimError: if a descriptor is not found + """ + descriptor = os.environ.get(BackboneFeatureStore.MLI_BACKBONE, "") + + if not descriptor: + return None + + logger.info(f"Listener backbone descriptor: {descriptor}\n") + + # `from_writable_descriptor` ensures we can update the backbone + return BackboneFeatureStore.from_writable_descriptor(descriptor) + + +if __name__ == "__main__": + mp.set_start_method("dragon") + + parser = _create_parser() + args = parser.parse_args() + + backbone_fs = _connect_backbone() + + if backbone_fs is None: + logger.error( + "Unable to attach to the backbone without the " + f"`{BackboneFeatureStore.MLI_BACKBONE}` environment variable." + ) + sys.exit(1) + + logger.debug(f"Listener attached to backbone: {backbone_fs.descriptor}") + + listener = ConsumerRegistrationListener( + backbone_fs, + float(args.timeout), + float(args.batch_timeout), + as_service=True, + ) + + logger.info(f"listener created? {listener}") + + try: + listener.execute() + sys.exit(0) + except Exception: + logger.exception("An error occurred in the event listener") + sys.exit(1) diff --git a/smartsim/_core/mli/infrastructure/control/request_dispatcher.py b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py new file mode 100644 index 0000000000..e22a2c8f62 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py @@ -0,0 +1,559 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
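The `__main__` block above means the registration listener runs as its own process, with the backbone descriptor passed through the environment rather than argv. A launch sketch (launching via `python -m` is an assumption here; in the full system the dragon backend starts this entrypoint, and `backbone_descriptor` is a placeholder for a real descriptor):

    import os
    import subprocess
    import sys

    env = dict(os.environ)
    env["_SMARTSIM_INFRA_BACKBONE"] = backbone_descriptor  # the MLI_BACKBONE key

    subprocess.Popen(
        [
            sys.executable,
            "-m",
            "smartsim._core.mli.infrastructure.control.listener",
            "--timeout",
            "1.0",
            "--batch_timeout",
            "1.0",
        ],
        env=env,
    )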
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# pylint: disable=import-error +# pylint: disable-next=unused-import +import dragon +import dragon.globalservices.pool as dragon_gs_pool +from dragon.managed_memory import MemoryPool +from dragon.mpbridge.queues import DragonQueue + +# pylint: enable=import-error + +# isort: off +# isort: on + +import multiprocessing as mp +import time +import typing as t +import uuid +from queue import Empty, Full, Queue + +from smartsim._core.entrypoints.service import Service + +from .....error import SmartSimError +from .....log import get_logger +from ....utils.timings import PerfTimer +from ..environment_loader import EnvironmentConfigLoader +from ..storage.feature_store import FeatureStore +from ..worker.worker import ( + InferenceRequest, + MachineLearningWorkerBase, + ModelIdentifier, + RequestBatch, +) +from .error_handling import exception_handler + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +logger = get_logger("Request Dispatcher") + + +class BatchQueue(Queue[InferenceRequest]): + def __init__( + self, batch_timeout: float, batch_size: int, model_id: ModelIdentifier + ) -> None: + """Queue used to store inference requests waiting to be batched and + sent to Worker Managers. + + :param batch_timeout: Time in seconds that has to be waited before flushing a + non-full queue. The time of the first item put is 0 seconds. + :param batch_size: Total capacity of the queue + :param model_id: Key of the model which needs to be executed on the queued + requests + """ + super().__init__(maxsize=batch_size) + self._batch_timeout = batch_timeout + """Time in seconds that has to be waited before flushing a non-full queue. + The time of the first item put is 0 seconds.""" + self._batch_size = batch_size + """Total capacity of the queue""" + self._first_put: t.Optional[float] = None + """Time at which the first item was put on the queue""" + self._disposable = False + """Whether the queue will not be used again and can be deleted. + A disposable queue is always full.""" + self._model_id: ModelIdentifier = model_id + """Key of the model which needs to be executed on the queued requests""" + self._uid = str(uuid.uuid4()) + """Unique ID of queue""" + + @property + def uid(self) -> str: + """ID of this queue. + + :returns: Queue ID + """ + return self._uid + + @property + def model_id(self) -> ModelIdentifier: + """Key of the model which needs to be run on the queued requests. + + :returns: Model key + """ + return self._model_id + + def put( + self, + item: InferenceRequest, + block: bool = False, + timeout: t.Optional[float] = 0.0, + ) -> None: + """Put an inference request in the queue. 
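The batching contract in practice: a `BatchQueue` reports `ready` when it is full or when `batch_timeout` has elapsed since the first `put`. A sketch (`model_id` is a ModelIdentifier and `make_request` is hypothetical, standing in for deserialized inference requests):

    import time

    from smartsim._core.mli.infrastructure.control.request_dispatcher import (
        BatchQueue,
    )

    queue = BatchQueue(batch_timeout=0.5, batch_size=4, model_id=model_id)

    queue.put(make_request())    # the timeout clock starts at the first put
    assert not queue.ready       # not full and the timeout has not elapsed

    time.sleep(0.6)
    assert queue.ready           # partial batch, but the timeout expired
    requests = queue.flush()     # drain everything currently queued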
+ + :param item: The request + :param block: Whether to block when trying to put the item + :param timeout: Time (in seconds) to wait if block==True + :raises Full: If an item cannot be put on the queue + """ + super().put(item, block=block, timeout=timeout) + if self._first_put is None: + self._first_put = time.time() + + @property + def _elapsed_time(self) -> float: + """Time elapsed since the first item was put on this queue. + + :returns: Time elapsed + """ + if self.empty() or self._first_put is None: + return 0 + return time.time() - self._first_put + + @property + def ready(self) -> bool: + """Check if the queue can be flushed. + + :returns: True if the queue can be flushed, False otherwise + """ + if self.empty(): + logger.debug("Request dispatcher queue is empty") + return False + + timed_out = False + if self._batch_timeout >= 0: + timed_out = self._elapsed_time >= self._batch_timeout + + if self.full(): + logger.debug("Request dispatcher ready to deliver full batch") + return True + + if timed_out: + logger.debug("Request dispatcher delivering partial batch") + return True + + return False + + def make_disposable(self) -> None: + """Set this queue as disposable, and never use it again after it gets + flushed.""" + self._disposable = True + + @property + def can_be_removed(self) -> bool: + """Determine whether this queue can be deleted and garbage collected. + + :returns: True if queue can be removed, False otherwise + """ + return self.empty() and self._disposable + + def flush(self) -> list[t.Any]: + """Get all requests from queue. + + :returns: Requests waiting to be executed + """ + num_items = self.qsize() + self._first_put = None + items = [] + for _ in range(num_items): + try: + items.append(self.get()) + except Empty: + break + + return items + + def full(self) -> bool: + """Check if the queue has reached its maximum capacity. + + :returns: True if the queue has reached its maximum capacity, + False otherwise + """ + if self._disposable: + return True + return self.qsize() >= self._batch_size + + def empty(self) -> bool: + """Check if the queue is empty. + + :returns: True if the queue has 0 elements, False otherwise + """ + return self.qsize() == 0 + + +class RequestDispatcher(Service): + def __init__( + self, + batch_timeout: float, + batch_size: int, + config_loader: EnvironmentConfigLoader, + worker_type: t.Type[MachineLearningWorkerBase], + mem_pool_size: int = 2 * 1024**3, + ) -> None: + """The RequestDispatcher intercepts inference requests, stages them in + queues and batches them together before making them available to Worker + Managers. 
+
+        :param batch_timeout: Maximum elapsed time before flushing a complete or
+        incomplete batch
+        :param batch_size: Total capacity of each batch queue
+        :param config_loader: Object to load configuration from environment
+        :param worker_type: Type of worker to instantiate to batch inputs
+        :param mem_pool_size: Size of the memory pool used to share batched input
+        tensors with worker managers
+        """
+        super().__init__(as_service=True, cooldown=1)
+        self._queues: dict[str, list[BatchQueue]] = {}
+        """Dict of all batch queues available for a given model id"""
+        self._active_queues: dict[str, BatchQueue] = {}
+        """Mapping telling which queue is the recipient of requests for a given model
+        key"""
+        self._batch_timeout = batch_timeout
+        """Time in seconds that has to be waited before flushing a non-full queue"""
+        self._batch_size = batch_size
+        """Total capacity of each batch queue"""
+        incoming_channel = config_loader.get_queue()
+        if incoming_channel is None:
+            raise SmartSimError("No incoming channel for dispatcher")
+        self._incoming_channel = incoming_channel
+        """The channel the dispatcher monitors for new tasks"""
+        self._outgoing_queue: DragonQueue = mp.Queue(maxsize=0)
+        """The queue on which batched inference requests are placed"""
+        self._feature_stores: t.Dict[str, FeatureStore] = {}
+        """A collection of attached feature stores"""
+        self._featurestore_factory = config_loader._featurestore_factory
+        """A factory method to create a desired feature store client type"""
+        self._backbone: t.Optional[FeatureStore] = config_loader.get_backbone()
+        """A standalone, system-created feature store used to share internal
+        information among MLI components"""
+        self._callback_factory = config_loader._callback_factory
+        """The type of communication channel to construct for callbacks"""
+        self._worker = worker_type()
+        """The worker used to batch inputs"""
+        self._mem_pool = MemoryPool.attach(dragon_gs_pool.create(mem_pool_size).sdesc)
+        """Memory pool used to share batched input tensors with the Worker Managers"""
+        self._perf_timer = PerfTimer(prefix="r_", debug=False, timing_on=True)
+        """Performance timer"""
+
+    @property
+    def has_featurestore_factory(self) -> bool:
+        """Check if the RequestDispatcher has a FeatureStore factory.
+
+        :returns: True if there is a FeatureStore factory, False otherwise
+        """
+        return self._featurestore_factory is not None
+
+    def _check_feature_stores(self, request: InferenceRequest) -> bool:
+        """Ensures that all feature stores required by the request are available.
+
+        :param request: The request to validate
+        :returns: False if feature store validation fails for the request, True
+        otherwise
+        """
+        # collect all feature stores required by the request
+        fs_model: t.Set[str] = set()
+        if request.model_key:
+            fs_model = {request.model_key.descriptor}
+        fs_inputs = {key.descriptor for key in request.input_keys}
+        fs_outputs = {key.descriptor for key in request.output_keys}
+
+        # identify which feature stores are requested and unknown
+        fs_desired = fs_model.union(fs_inputs).union(fs_outputs)
+        fs_actual = {item.descriptor for item in self._feature_stores.values()}
+        fs_missing = fs_desired - fs_actual
+
+        if not self.has_featurestore_factory:
+            logger.error("No feature store factory is configured. Unable to dispatch.")
Unable to dispatch.") + return False + + # create the feature stores we need to service request + if fs_missing: + logger.debug(f"Adding feature store(s): {fs_missing}") + for descriptor in fs_missing: + feature_store = self._featurestore_factory(descriptor) + self._feature_stores[descriptor] = feature_store + + return True + + # pylint: disable-next=no-self-use + def _check_model(self, request: InferenceRequest) -> bool: + """Ensure that a model is available for the request. + + :param request: The request to validate + :returns: False if model validation fails for the request, True otherwise + """ + if request.has_model_key or request.has_raw_model: + return True + + logger.error("Unable to continue without model bytes or feature store key") + return False + + # pylint: disable-next=no-self-use + def _check_inputs(self, request: InferenceRequest) -> bool: + """Ensure that inputs are available for the request. + + :param request: The request to validate + :returns: False if input validation fails for the request, True otherwise + """ + if request.has_input_keys or request.has_raw_inputs: + return True + + logger.error("Unable to continue without input bytes or feature store keys") + return False + + # pylint: disable-next=no-self-use + def _check_callback(self, request: InferenceRequest) -> bool: + """Ensure that a callback channel is available for the request. + + :param request: The request to validate + :returns: False if callback validation fails for the request, True otherwise + """ + if request.callback: + return True + + logger.error("No callback channel provided in request") + return False + + def _validate_request(self, request: InferenceRequest) -> bool: + """Ensure the request can be processed. + + :param request: The request to validate + :returns: False if the request fails any validation checks, True otherwise + """ + checks = [ + self._check_feature_stores(request), + self._check_model(request), + self._check_inputs(request), + self._check_callback(request), + ] + + return all(checks) + + def _on_iteration(self) -> None: + """This method is executed repeatedly until ``Service`` shutdown + conditions are satisfied and cooldown is elapsed.""" + try: + self._perf_timer.is_active = True + bytes_list: t.List[bytes] = self._incoming_channel.recv() + except Exception: + self._perf_timer.is_active = False + else: + if not bytes_list: + exception_handler( + ValueError("No request data found"), + None, + None, + ) + + logger.debug(f"Dispatcher is processing {len(bytes_list)} messages") + request_bytes = bytes_list[0] + tensor_bytes_list = bytes_list[1:] + self._perf_timer.start_timings() + + request = self._worker.deserialize_message( + request_bytes, self._callback_factory + ) + if request.has_input_meta and tensor_bytes_list: + request.raw_inputs = tensor_bytes_list + + self._perf_timer.measure_time("deserialize_message") + + if not self._validate_request(request): + exception_handler( + ValueError("Error validating the request"), + request.callback, + None, + ) + self._perf_timer.measure_time("validate_request") + else: + self._perf_timer.measure_time("validate_request") + self.dispatch(request) + self._perf_timer.measure_time("dispatch") + finally: + self.flush_requests() + self.remove_queues() + + self._perf_timer.end_timings() + + if self._perf_timer.max_length == 801 and self._perf_timer.is_active: + self._perf_timer.print_timings(True) + + def remove_queues(self) -> None: + """Remove references to queues that can be removed + and allow them to be garbage collected.""" + 
queue_lists_to_remove = [] + for key, queues in self._queues.items(): + queues_to_remove = [] + for queue in queues: + if queue.can_be_removed: + queues_to_remove.append(queue) + + for queue_to_remove in queues_to_remove: + queues.remove(queue_to_remove) + if ( + key in self._active_queues + and self._active_queues[key] == queue_to_remove + ): + del self._active_queues[key] + + if len(queues) == 0: + queue_lists_to_remove.append(key) + + for key in queue_lists_to_remove: + del self._queues[key] + + @property + def task_queue(self) -> DragonQueue: + """The queue on which batched requests are placed. + + :returns: The queue + """ + return self._outgoing_queue + + def _swap_queue(self, model_id: ModelIdentifier) -> None: + """Get an empty queue or create a new one + and make it the active one for a given model. + + :param model_id: The id of the model for which the + queue has to be swapped + """ + if model_id.key in self._queues: + for queue in self._queues[model_id.key]: + if not queue.full(): + self._active_queues[model_id.key] = queue + return + + new_queue = BatchQueue(self._batch_timeout, self._batch_size, model_id) + if model_id.key in self._queues: + self._queues[model_id.key].append(new_queue) + else: + self._queues[model_id.key] = [new_queue] + self._active_queues[model_id.key] = new_queue + return + + def dispatch(self, request: InferenceRequest) -> None: + """Assign a request to a batch queue. + + :param request: The request to place + """ + if request.has_raw_model: + logger.debug("Direct inference requested, creating tmp queue") + tmp_id = f"_tmp_{str(uuid.uuid4())}" + tmp_queue: BatchQueue = BatchQueue( + batch_timeout=0, + batch_size=1, + model_id=ModelIdentifier(key=tmp_id, descriptor="TMP"), + ) + self._active_queues[tmp_id] = tmp_queue + self._queues[tmp_id] = [tmp_queue] + tmp_queue.put(request) + tmp_queue.make_disposable() + return + + if request.model_key: + success = False + while not success: + try: + self._active_queues[request.model_key.key].put_nowait(request) + success = True + except (Full, KeyError): + self._swap_queue(request.model_key) + + def flush_requests(self) -> None: + """Get all requests from queues which are ready to be flushed. Place all + available request batches in the outgoing queue.""" + for queue_list in self._queues.values(): + for queue in queue_list: + if queue.ready: + self._perf_timer.measure_time("find_queue") + try: + batch = RequestBatch( + requests=queue.flush(), + inputs=None, + model_id=queue.model_id, + ) + finally: + self._perf_timer.measure_time("flush_requests") + try: + fetch_results = self._worker.fetch_inputs( + batch=batch, feature_stores=self._feature_stores + ) + except Exception as exc: + exception_handler( + exc, + None, + "Error fetching input.", + ) + continue + self._perf_timer.measure_time("fetch_input") + try: + transformed_inputs = self._worker.transform_input( + batch=batch, + fetch_results=fetch_results, + mem_pool=self._mem_pool, + ) + except Exception as exc: + exception_handler( + exc, + None, + "Error transforming input.", + ) + continue + + self._perf_timer.measure_time("transform_input") + batch.inputs = transformed_inputs + for request in batch.requests: + request.raw_inputs = [] + request.input_meta = [] + + try: + self._outgoing_queue.put(batch) + except Exception as exc: + exception_handler( + exc, + None, + "Error placing batch on task queue.", + ) + continue + self._perf_timer.measure_time("put") + + def _can_shutdown(self) -> bool: + """Determine whether the Service can be shut down. 
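The routing logic implemented by `dispatch` and `_swap_queue` above, in brief: a request carrying a raw model gets a disposable single-request queue, while keyed requests are coalesced onto the active queue for their model, swapping to a fresh queue when the active one fills. A sketch of the surrounding usage (`config_loader`, `worker_type`, and `keyed_request` stand in for objects constructed elsewhere in this patch):

    dispatcher = RequestDispatcher(
        batch_timeout=0.1,
        batch_size=8,
        config_loader=config_loader,
        worker_type=worker_type,
    )

    # keyed requests for the same model accumulate on one BatchQueue
    dispatcher.dispatch(keyed_request)

    # flush_requests() moves ready batches here for worker managers to consume
    batch = dispatcher.task_queue.get()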
+ + :returns: False + """ + return False + + def __del__(self) -> None: + """Destroy allocated memory resources.""" + # pool may be null if a failure occurs prior to successful attach + pool: t.Optional[MemoryPool] = getattr(self, "_mem_pool", None) + + if pool: + pool.destroy() diff --git a/smartsim/_core/mli/infrastructure/control/worker_manager.py b/smartsim/_core/mli/infrastructure/control/worker_manager.py new file mode 100644 index 0000000000..bf6fddb81d --- /dev/null +++ b/smartsim/_core/mli/infrastructure/control/worker_manager.py @@ -0,0 +1,330 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# pylint: disable=import-error +# pylint: disable-next=unused-import +import dragon + +# pylint: enable=import-error + +# isort: off +# isort: on + +import multiprocessing as mp +import time +import typing as t +from queue import Empty + +from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore + +from .....log import get_logger +from ....entrypoints.service import Service +from ....utils.timings import PerfTimer +from ...message_handler import MessageHandler +from ..environment_loader import EnvironmentConfigLoader +from ..worker.worker import ( + InferenceReply, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, +) +from .device_manager import DeviceManager, WorkerDevice +from .error_handling import build_failure_reply, exception_handler + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +logger = get_logger(__name__) + + +class WorkerManager(Service): + """An implementation of a service managing distribution of tasks to + machine learning workers.""" + + def __init__( + self, + config_loader: EnvironmentConfigLoader, + worker_type: t.Type[MachineLearningWorkerBase], + dispatcher_queue: "mp.Queue[RequestBatch]", + as_service: bool = False, + cooldown: int = 0, + device: t.Literal["cpu", "gpu"] = "cpu", + ) -> None: + """Initialize the WorkerManager. 
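How the two services are intended to be wired together: the dispatcher's `task_queue` is handed to the worker manager as its `dispatcher_queue`, forming the hand-off point between batching and execution. A sketch (assumes `config_loader` and `worker_type` are built as elsewhere in this patch; each service would normally run in its own process):

    dispatcher = RequestDispatcher(
        batch_timeout=0.1,
        batch_size=8,
        config_loader=config_loader,
        worker_type=worker_type,
    )

    worker_manager = WorkerManager(
        config_loader=config_loader,
        worker_type=worker_type,
        dispatcher_queue=dispatcher.task_queue,
        as_service=True,
        device="gpu",
    )

    worker_manager.execute()  # runs the Service event loop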
+ + :param config_loader: Environment config loader for loading queues + and feature stores + :param worker_type: The type of worker to manage + :param dispatcher_queue: Queue from which the batched requests are pulled + :param as_service: Specifies run-once or run-until-complete behavior of service + :param cooldown: Number of seconds to wait before shutting down after + shutdown criteria are met + :param device: The device on which the Worker should run. Every worker manager + is assigned one single GPU (if available), thus the device should have no index. + """ + super().__init__(as_service, cooldown) + + self._dispatcher_queue = dispatcher_queue + """The Dispatcher queue that the WorkerManager monitors for new batches""" + self._worker = worker_type() + """The ML Worker implementation""" + self._callback_factory = config_loader._callback_factory + """The type of communication channel to construct for callbacks""" + self._device = device + """Device on which workers need to run""" + self._cached_models: dict[str, t.Any] = {} + """Dictionary of previously loaded models""" + self._feature_stores: t.Dict[str, FeatureStore] = {} + """A collection of attached feature stores""" + self._featurestore_factory = config_loader._featurestore_factory + """A factory method to create a desired feature store client type""" + self._backbone: t.Optional[FeatureStore] = config_loader.get_backbone() + """A standalone, system-created feature store used to share internal + information among MLI components""" + self._device_manager: t.Optional[DeviceManager] = None + """Object responsible for model caching and device access""" + self._perf_timer = PerfTimer(prefix="w_", debug=False, timing_on=True) + """Performance timer""" + + @property + def has_featurestore_factory(self) -> bool: + """Check if the WorkerManager has a FeatureStore factory. + + :returns: True if there is a FeatureStore factory, False otherwise + """ + return self._featurestore_factory is not None + + def _on_start(self) -> None: + """Called on initial entry into Service `execute` event loop before + `_on_iteration` is invoked.""" + self._device_manager = DeviceManager(WorkerDevice(self._device)) + + def _check_feature_stores(self, batch: RequestBatch) -> bool: + """Ensures that all feature stores required by the request are available. + + :param batch: The batch of requests to validate + :returns: False if feature store validation fails for the batch, True otherwise + """ + # collect all feature stores required by the request + fs_model: t.Set[str] = set() + if batch.model_id.key: + fs_model = {batch.model_id.descriptor} + fs_inputs = {key.descriptor for key in batch.input_keys} + fs_outputs = {key.descriptor for key in batch.output_keys} + + # identify which feature stores are requested and unknown + fs_desired = fs_model.union(fs_inputs).union(fs_outputs) + fs_actual = {item.descriptor for item in self._feature_stores.values()} + fs_missing = fs_desired - fs_actual + + if not self.has_featurestore_factory: + logger.error("No feature store factory configured") + return False + + # create the feature stores we need to service request + if fs_missing: + logger.debug(f"Adding feature store(s): {fs_missing}") + for descriptor in fs_missing: + feature_store = self._featurestore_factory(descriptor) + self._feature_stores[descriptor] = feature_store + + return True + + def _validate_batch(self, batch: RequestBatch) -> bool: + """Ensure the request can be processed. 
+
+        :param batch: The batch of requests to validate
+        :returns: False if the request fails any validation checks, True otherwise
+        """
+        if batch is None or not batch.has_valid_requests:
+            return False
+
+        return self._check_feature_stores(batch)
+
+    # remove this when we are done with time measurements
+    # pylint: disable-next=too-many-statements
+    def _on_iteration(self) -> None:
+        """Executes calls to the machine learning worker implementation to complete
+        the inference pipeline."""
+        pre_batch_time = time.perf_counter()
+        try:
+            batch: RequestBatch = self._dispatcher_queue.get(timeout=0.0001)
+        except Empty:
+            return
+
+        self._perf_timer.start_timings(
+            "flush_requests", time.perf_counter() - pre_batch_time
+        )
+
+        if not self._validate_batch(batch):
+            exception_handler(
+                ValueError("An invalid batch was received"),
+                None,
+                None,
+            )
+            return
+
+        if not self._device_manager:
+            for request in batch.requests:
+                msg = (
+                    "No Device Manager found. WorkerManager._on_start() "
+                    "must be called after initialization. If possible, "
+                    "you should use `WorkerManager.execute()` instead of "
+                    "directly calling `_on_iteration()`."
+                )
+                try:
+                    self._dispatcher_queue.put(batch)
+                except Exception:
+                    msg += (
+                        "\nThe batch could not be put back in the queue "
+                        "and will not be processed."
+                    )
+                exception_handler(
+                    RuntimeError(msg),
+                    request.callback,
+                    "Error acquiring device manager",
+                )
+            return
+
+        try:
+            device_cm = self._device_manager.get_device(
+                worker=self._worker,
+                batch=batch,
+                feature_stores=self._feature_stores,
+            )
+        except Exception as exc:
+            for request in batch.requests:
+                exception_handler(
+                    exc,
+                    request.callback,
+                    "Error loading model on device or getting device.",
+                )
+            return
+        self._perf_timer.measure_time("fetch_model")
+
+        with device_cm as device:
+
+            try:
+                model_result = LoadModelResult(device.get_model(batch.model_id.key))
+            except Exception as exc:
+                for request in batch.requests:
+                    exception_handler(
+                        exc, request.callback, "Error getting model from device."
+                    )
+                return
+            self._perf_timer.measure_time("load_model")
+
+            if not batch.inputs:
+                for request in batch.requests:
+                    exception_handler(
+                        ValueError("Error batching inputs"),
+                        request.callback,
+                        None,
+                    )
+                return
+            transformed_input = batch.inputs
+
+            try:
+                execute_result = self._worker.execute(
+                    batch, model_result, transformed_input, device.name
+                )
+            except Exception as e:
+                for request in batch.requests:
+                    exception_handler(e, request.callback, "Error while executing.")
+                return
+            self._perf_timer.measure_time("execute")
+
+            try:
+                transformed_outputs = self._worker.transform_output(
+                    batch, execute_result
+                )
+            except Exception as e:
+                for request in batch.requests:
+                    exception_handler(
+                        e, request.callback, "Error while transforming the output."
+                    )
+                return
+
+            for request, transformed_output in zip(batch.requests, transformed_outputs):
+                reply = InferenceReply()
+                if request.has_output_keys:
+                    try:
+                        reply.output_keys = self._worker.place_output(
+                            request,
+                            transformed_output,
+                            self._feature_stores,
+                        )
+                    except Exception as e:
+                        exception_handler(
+                            e, request.callback, "Error while placing the output."
+ ) + continue + else: + reply.outputs = transformed_output.outputs + self._perf_timer.measure_time("assign_output") + + if not reply.has_outputs: + response = build_failure_reply("fail", "Outputs not found.") + else: + reply.status_enum = "complete" + reply.message = "Success" + + results = self._worker.prepare_outputs(reply) + response = MessageHandler.build_response( + status=reply.status_enum, + message=reply.message, + result=results, + custom_attributes=None, + ) + + self._perf_timer.measure_time("build_reply") + + serialized_resp = MessageHandler.serialize_response(response) + + self._perf_timer.measure_time("serialize_resp") + + if request.callback: + request.callback.send(serialized_resp) + if reply.has_outputs: + # send tensor data after response + for output in reply.outputs: + request.callback.send(output) + self._perf_timer.measure_time("send") + + self._perf_timer.end_timings() + + if self._perf_timer.max_length == 801: + self._perf_timer.print_timings(True) + + def _can_shutdown(self) -> bool: + """Determine if the service can be shutdown. + + :returns: True when criteria to shutdown the service are met, False otherwise + """ + # todo: determine shutdown criteria + # will we receive a completion message? + # will we let MLI mgr just kill this? + # time_diff = self._last_event - datetime.datetime.now() + # if time_diff.total_seconds() > self._cooldown: + # return True + # return False + return self._worker is None diff --git a/smartsim/_core/mli/infrastructure/environment_loader.py b/smartsim/_core/mli/infrastructure/environment_loader.py new file mode 100644 index 0000000000..5ba0fccc27 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/environment_loader.py @@ -0,0 +1,116 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import typing as t + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class EnvironmentConfigLoader: + """ + Facilitates the loading of a FeatureStore and Queue into the WorkerManager. 
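Constructing the loader with concrete factories (a sketch: `DragonCommChannel.from_descriptor` appears earlier in this patch, but treating `DragonFeatureStore` and `DragonFLIChannel` as also exposing `from_descriptor` constructors is an assumption here):

    from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
    from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
    from smartsim._core.mli.infrastructure.environment_loader import (
        EnvironmentConfigLoader,
    )
    from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
        DragonFeatureStore,
    )

    config_loader = EnvironmentConfigLoader(
        featurestore_factory=DragonFeatureStore.from_descriptor,
        callback_factory=DragonCommChannel.from_descriptor,
        queue_factory=DragonFLIChannel.from_descriptor,
    )

    backbone = config_loader.get_backbone()  # reads _SMARTSIM_INFRA_BACKBONE
    queue = config_loader.get_queue()        # reads _SMARTSIM_REQUEST_QUEUE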
+ """ + + REQUEST_QUEUE_ENV_VAR = "_SMARTSIM_REQUEST_QUEUE" + """The environment variable that holds the request queue descriptor""" + BACKBONE_ENV_VAR = "_SMARTSIM_INFRA_BACKBONE" + """The environment variable that holds the backbone descriptor""" + + def __init__( + self, + featurestore_factory: t.Callable[[str], FeatureStore], + callback_factory: t.Callable[[str], CommChannelBase], + queue_factory: t.Callable[[str], CommChannelBase], + ) -> None: + """Initialize the config loader instance with the factories necessary for + creating additional objects. + + :param featurestore_factory: A factory method that produces a feature store + given a descriptor + :param callback_factory: A factory method that produces a callback + channel given a descriptor + :param queue_factory: A factory method that produces a queue + channel given a descriptor + """ + self.queue: t.Optional[CommChannelBase] = None + """The attached incoming event queue channel""" + self.backbone: t.Optional[FeatureStore] = None + """The attached backbone feature store""" + self._featurestore_factory = featurestore_factory + """A factory method to instantiate a FeatureStore""" + self._callback_factory = callback_factory + """A factory method to instantiate a concrete CommChannelBase + for inference callbacks""" + self._queue_factory = queue_factory + """A factory method to instantiate a concrete CommChannelBase + for inference requests""" + + def get_backbone(self) -> t.Optional[FeatureStore]: + """Attach to the backbone feature store using the descriptor found in + the environment variable `_SMARTSIM_INFRA_BACKBONE`. The backbone is + a standalone, system-created feature store used to share internal + information among MLI components. + + :returns: The attached feature store via `_SMARTSIM_INFRA_BACKBONE` + """ + descriptor = os.getenv(self.BACKBONE_ENV_VAR, "") + + if not descriptor: + logger.warning("No backbone descriptor is configured") + return None + + if self._featurestore_factory is None: + logger.warning( + "No feature store factory is configured. Backbone not created." + ) + return None + + self.backbone = self._featurestore_factory(descriptor) + return self.backbone + + def get_queue(self) -> t.Optional[CommChannelBase]: + """Attach to a queue-like communication channel using the descriptor + found in the environment variable `_SMARTSIM_REQUEST_QUEUE`. + + :returns: The attached queue specified via `_SMARTSIM_REQUEST_QUEUE` + """ + descriptor = os.getenv(self.REQUEST_QUEUE_ENV_VAR, "") + + if not descriptor: + logger.warning("No queue descriptor is configured") + return None + + if self._queue_factory is None: + logger.warning("No queue factory is configured") + return None + + self.queue = self._queue_factory(descriptor) + return self.queue diff --git a/smartsim/_core/mli/infrastructure/storage/__init__.py b/smartsim/_core/mli/infrastructure/storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py b/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py new file mode 100644 index 0000000000..b12d7b11b4 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py @@ -0,0 +1,259 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import itertools
+import os
+import time
+import typing as t
+
+# pylint: disable=import-error
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+
+# isort: on
+
+from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
+    DragonFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class BackboneFeatureStore(DragonFeatureStore):
+    """A DragonFeatureStore wrapper with utility methods for accessing shared
+    information stored in the MLI backbone feature store."""
+
+    MLI_NOTIFY_CONSUMERS = "_SMARTSIM_MLI_NOTIFY_CONSUMERS"
+    """Unique key used in the backbone to locate the consumer list"""
+    MLI_REGISTRAR_CONSUMER = "_SMARTIM_MLI_REGISTRAR_CONSUMER"
+    """Unique key used in the backbone to locate the registration consumer"""
+    MLI_WORKER_QUEUE = "_SMARTSIM_REQUEST_QUEUE"
+    """Unique key used in the backbone to locate the MLI work queue"""
+    MLI_BACKBONE = "_SMARTSIM_INFRA_BACKBONE"
+    """Unique key used in the backbone to locate the backbone feature store"""
+    _CREATED_ON = "creation"
+    """Unique key used in the backbone to locate the creation date of the
+    feature store"""
+    _DEFAULT_WAIT_TIMEOUT = 1.0
+    """The default wait time (in seconds) for blocking requests to
+    the feature store"""
+
+    def __init__(
+        self,
+        storage: dragon_ddict.DDict,
+        allow_reserved_writes: bool = False,
+    ) -> None:
+        """Initialize the BackboneFeatureStore instance.
+
+        :param storage: A distributed dictionary to be used as the underlying
+        storage mechanism of the feature store
+        :param allow_reserved_writes: Whether reserved writes are allowed
+        """
+        super().__init__(storage)
+        self._enable_reserved_writes = allow_reserved_writes
+        self._wait_timeout = self._DEFAULT_WAIT_TIMEOUT
+        """The wait timeout (in seconds) applied to calls to `wait_for`"""
+
+        self._record_creation_data()
+
+    @property
+    def wait_timeout(self) -> float:
+        """Retrieve the wait timeout for this feature store. The wait timeout is
+        applied to all calls to `wait_for`.
+
+        :returns: The wait timeout (in seconds).
+        """
+        return self._wait_timeout
+
+    @wait_timeout.setter
+    def wait_timeout(self, value: float) -> None:
+        """Set the wait timeout (in seconds) for this feature store. The wait
+        timeout is applied to all calls to `wait_for`.
+
+        :param value: The new value to set
+        """
+        self._wait_timeout = value
+
+    @property
+    def notification_channels(self) -> t.Sequence[str]:
+        """Retrieve descriptors for all registered MLI notification channels.
+
+        :returns: The list of channel descriptors
+        """
+        if self.MLI_NOTIFY_CONSUMERS in self:
+            stored_consumers = self[self.MLI_NOTIFY_CONSUMERS]
+            return str(stored_consumers).split(",")
+        return []
+
+    @notification_channels.setter
+    def notification_channels(self, values: t.Sequence[str]) -> None:
+        """Set the notification channels to be sent events.
+
+        :param values: The list of channel descriptors to save
+        """
+        self[self.MLI_NOTIFY_CONSUMERS] = ",".join(
+            [str(value) for value in values if value]
+        )
+
+    @property
+    def backend_channel(self) -> t.Optional[str]:
+        """Retrieve the channel descriptor used to register event consumers.
+
+        :returns: The channel descriptor"""
+        if self.MLI_REGISTRAR_CONSUMER in self:
+            return str(self[self.MLI_REGISTRAR_CONSUMER])
+        return None
+
+    @backend_channel.setter
+    def backend_channel(self, value: str) -> None:
+        """Set the channel used to register event consumers.
+
+        :param value: The stringified channel descriptor"""
+        self[self.MLI_REGISTRAR_CONSUMER] = value
+
+    @property
+    def worker_queue(self) -> t.Optional[str]:
+        """Retrieve the channel descriptor used to send work to MLI worker managers.
+
+        :returns: The channel descriptor, if found. Otherwise, `None`"""
+        if self.MLI_WORKER_QUEUE in self:
+            return str(self[self.MLI_WORKER_QUEUE])
+        return None
+
+    @worker_queue.setter
+    def worker_queue(self, value: str) -> None:
+        """Set the channel descriptor used to send work to MLI worker managers.
+
+        :param value: The channel descriptor"""
+        self[self.MLI_WORKER_QUEUE] = value
+
+    @property
+    def creation_date(self) -> str:
+        """Return the creation date for the backbone feature store.
+
+        :returns: The string-formatted date when the feature store was created"""
+        return str(self[self._CREATED_ON])
+
+    def _record_creation_data(self) -> None:
+        """Write the creation timestamp to the feature store."""
+        if self._CREATED_ON not in self:
+            if not self._allow_reserved_writes:
+                logger.warning(
+                    "Recorded creation from a write-protected backbone instance"
+                )
+            self[self._CREATED_ON] = str(time.time())
+
+        os.environ[self.MLI_BACKBONE] = self.descriptor
+
+    @classmethod
+    def from_writable_descriptor(
+        cls,
+        descriptor: str,
+    ) -> "BackboneFeatureStore":
+        """A factory method that creates an instance from a descriptor string.
+
+        :param descriptor: The descriptor that uniquely identifies the resource
+        :returns: An attached BackboneFeatureStore
+        :raises SmartSimError: If attachment to the backbone feature store fails
+        """
+        try:
+            return BackboneFeatureStore(dragon_ddict.DDict.attach(descriptor), True)
+        except Exception as ex:
+            raise SmartSimError(
+                f"Error creating backbone feature store: {descriptor}"
+            ) from ex
+
+    def _check_wait_timeout(
+        self, start_time: float, timeout: float, indicators: t.Dict[str, bool]
+    ) -> None:
+        """Perform timeout verification.
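Because the backbone is a flat key-value store, the consumer list round-trips through a single comma-joined string, as the `notification_channels` property pair above does. The same round trip in isolation:

```python
# Empty descriptors are dropped on write, mirroring the setter above
consumers = ["channel-a", "channel-b", ""]
stored = ",".join(str(c) for c in consumers if c)  # "channel-a,channel-b"
restored = stored.split(",") if stored else []
assert restored == ["channel-a", "channel-b"]
```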
+ + :param start_time: the start time to use for elapsed calculation + :param timeout: the timeout (in seconds) + :param indicators: latest retrieval status for requested keys + :raises SmartSimError: If the timeout elapses before all values are + retrieved + """ + elapsed = time.time() - start_time + if timeout and elapsed > timeout: + raise SmartSimError( + f"Backbone {self.descriptor=} timeout after {elapsed} " + f"seconds retrieving keys: {indicators}" + ) + + def wait_for( + self, keys: t.List[str], timeout: float = _DEFAULT_WAIT_TIMEOUT + ) -> t.Dict[str, t.Union[str, bytes, None]]: + """Perform a blocking wait until all specified keys have been found + in the backbone. + + :param keys: The required collection of keys to retrieve + :param timeout: The maximum wait time in seconds + :returns: Dictionary containing the keys and values requested + :raises SmartSimError: If the timeout elapses without retrieving + all requested keys + """ + if timeout < 0: + timeout = self._DEFAULT_WAIT_TIMEOUT + logger.info(f"Using default wait_for timeout: {timeout}s") + + if not keys: + return {} + + values: t.Dict[str, t.Union[str, bytes, None]] = {k: None for k in set(keys)} + is_found = {k: False for k in values.keys()} + + backoff = (0.1, 0.2, 0.4, 0.8) + backoff_iter = itertools.cycle(backoff) + start_time = time.time() + + while not all(is_found.values()): + delay = next(backoff_iter) + + for key in [k for k, v in is_found.items() if not v]: + try: + values[key] = self[key] + is_found[key] = True + except Exception: + if delay == backoff[-1]: + logger.debug(f"Re-attempting `{key}` retrieval in {delay}s") + + if all(is_found.values()): + logger.debug(f"wait_for({keys}) retrieved all keys") + continue + + self._check_wait_timeout(start_time, timeout, is_found) + time.sleep(delay) + + return values + + def get_env(self) -> t.Dict[str, str]: + """Returns a dictionary populated with environment variables necessary to + connect a process to the existing backbone instance. + + :returns: The dictionary populated with env vars + """ + return {self.MLI_BACKBONE: self.descriptor} diff --git a/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py b/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py new file mode 100644 index 0000000000..24f2221c87 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py @@ -0,0 +1,126 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
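The `wait_for` implementation above retries missing keys with a cycling backoff rather than a fixed sleep. A self-contained sketch of the same pattern, where the `lookup` callable stands in for a backbone read:

```python
import itertools
import time


def wait_for_keys(lookup, keys, timeout=1.0):
    """Poll `lookup` until every key resolves or the timeout elapses."""
    backoff = itertools.cycle((0.1, 0.2, 0.4, 0.8))
    found, start = {}, time.time()
    while len(found) < len(keys):
        for key in (k for k in keys if k not in found):
            try:
                found[key] = lookup(key)
            except KeyError:
                pass  # not written yet; retry on the next cycle
        if len(found) < len(keys):
            if time.time() - start > timeout:
                raise TimeoutError(f"missing keys: {set(keys) - set(found)}")
            time.sleep(next(backoff))
    return found


store = {"worker-queue": "desc-123"}
assert wait_for_keys(store.__getitem__, ["worker-queue"]) == store
```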
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+# pylint: disable=import-error
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+
+# isort: on
+
+from smartsim._core.mli.infrastructure.storage.dragon_util import (
+    ddict_to_descriptor,
+    descriptor_to_ddict,
+)
+from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore
+from smartsim.error import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class DragonFeatureStore(FeatureStore):
+    """A feature store backed by a dragon distributed dictionary."""
+
+    def __init__(self, storage: "dragon_ddict.DDict") -> None:
+        """Initialize the DragonFeatureStore instance.
+
+        :param storage: A distributed dictionary to be used as the underlying
+        storage mechanism of the feature store"""
+        if storage is None:
+            raise ValueError(
+                "Storage is required when instantiating a DragonFeatureStore."
+            )
+
+        descriptor = ""
+        if isinstance(storage, dragon_ddict.DDict):
+            descriptor = ddict_to_descriptor(storage)
+
+        super().__init__(descriptor)
+        self._storage: t.Dict[str, t.Union[str, bytes]] = storage
+        """The underlying storage mechanism of the DragonFeatureStore; a
+        distributed, in-memory key-value store"""
+
+    def _get(self, key: str) -> t.Union[str, bytes]:
+        """Retrieve a value from the underlying storage mechanism.
+
+        :param key: The unique key that identifies the resource
+        :returns: The value identified by the key
+        :raises KeyError: If the key has not been used to store a value
+        """
+        try:
+            return self._storage[key]
+        except dragon_ddict.DDictError as e:
+            raise KeyError(f"Key not found in FeatureStore: {key}") from e
+
+    def _set(self, key: str, value: t.Union[str, bytes]) -> None:
+        """Store a value into the underlying storage mechanism.
+
+        :param key: The unique key that identifies the resource
+        :param value: The value to store
+        """
+        self._storage[key] = value
+
+    def _contains(self, key: str) -> bool:
+        """Determine if the storage mechanism contains a given key.
+
+        :param key: The unique key that identifies the resource
+        :returns: True if the key is defined, False otherwise
+        """
+        return key in self._storage
+
+    def pop(self, key: str) -> t.Union[str, bytes, None]:
+        """Remove the value from the dictionary and return the value.
+
+        :param key: Dictionary key to retrieve
+        :returns: The value held at the key if it exists, otherwise `None`
+        """
+        try:
+            return self._storage.pop(key)
+        except dragon_ddict.DDictError:
+            return None
+
+    @classmethod
+    def from_descriptor(
+        cls,
+        descriptor: str,
+    ) -> "DragonFeatureStore":
+        """A factory method that creates an instance from a descriptor string.
+ + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached DragonFeatureStore + :raises SmartSimError: If attachment to DragonFeatureStore fails + """ + try: + logger.debug(f"Attaching to FeatureStore with descriptor: {descriptor}") + storage = descriptor_to_ddict(descriptor) + return cls(storage) + except Exception as ex: + raise SmartSimError( + f"Error creating dragon feature store from descriptor: {descriptor}" + ) from ex diff --git a/smartsim/_core/mli/infrastructure/storage/dragon_util.py b/smartsim/_core/mli/infrastructure/storage/dragon_util.py new file mode 100644 index 0000000000..50d15664c0 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/storage/dragon_util.py @@ -0,0 +1,101 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# pylint: disable=import-error +# isort: off +import dragon.data.ddict.ddict as dragon_ddict + +# isort: on + +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +def ddict_to_descriptor(ddict: dragon_ddict.DDict) -> str: + """Convert a DDict to a descriptor string. + + :param ddict: The dragon dictionary to convert + :returns: The descriptor string + :raises ValueError: If a ddict is not provided + """ + if ddict is None: + raise ValueError("DDict is not available to create a descriptor") + + # unlike other dragon objects, the dictionary serializes to a string + # instead of bytes + return str(ddict.serialize()) + + +def descriptor_to_ddict(descriptor: str) -> dragon_ddict.DDict: + """Create and attach a new DDict instance given + the string-encoded descriptor. + + :param descriptor: The descriptor of a dictionary to attach to + :returns: The attached dragon dictionary""" + return dragon_ddict.DDict.attach(descriptor) + + +def create_ddict( + num_nodes: int, mgr_per_node: int, mem_per_node: int +) -> dragon_ddict.DDict: + """Create a distributed dragon dictionary. + + :param num_nodes: The number of distributed nodes to distribute the dictionary to. + At least one node is required. + :param mgr_per_node: The number of manager processes per node + :param mem_per_node: The amount of memory (in megabytes) to allocate per node. 
Total
+        memory available will be calculated as `num_nodes * mem_per_node`
+
+    :returns: The instantiated dragon dictionary
+    :raises ValueError: If invalid num_nodes is supplied
+    :raises ValueError: If invalid mem_per_node is supplied
+    :raises ValueError: If invalid mgr_per_node is supplied
+    """
+    if num_nodes < 1:
+        raise ValueError("A dragon dictionary must have at least 1 node")
+
+    if mgr_per_node < 1:
+        raise ValueError("A dragon dictionary requires at least 1 manager per node")
+
+    if mem_per_node < dragon_ddict.DDICT_MIN_SIZE:
+        raise ValueError(
+            "A dragon dictionary requires at least "
+            f"{dragon_ddict.DDICT_MIN_SIZE / 1024} MB"
+        )
+
+    mem_total = num_nodes * mem_per_node
+
+    logger.debug(
+        f"Creating dragon dictionary with {num_nodes} nodes, {mem_total} MB memory"
+    )
+
+    distributed_dict = dragon_ddict.DDict(num_nodes, mgr_per_node, total_mem=mem_total)
+    logger.debug(
+        "Successfully created dragon dictionary with "
+        f"{num_nodes} nodes, {mem_total} MB total memory"
+    )
+    return distributed_dict
diff --git a/smartsim/_core/mli/infrastructure/storage/feature_store.py b/smartsim/_core/mli/infrastructure/storage/feature_store.py
new file mode 100644
index 0000000000..ebca07ed4e
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/storage/feature_store.py
@@ -0,0 +1,224 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import enum
+import typing as t
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class ReservedKeys(str, enum.Enum):
+    """Contains constants used to identify all featurestore keys that
+    may not be used by users.
Avoids overwriting system data."""
+
+    MLI_NOTIFY_CONSUMERS = "_SMARTSIM_MLI_NOTIFY_CONSUMERS"
+    """Storage location for the list of registered consumers that will receive
+    events from an EventBroadcaster"""
+
+    MLI_REGISTRAR_CONSUMER = "_SMARTIM_MLI_REGISTRAR_CONSUMER"
+    """Storage location for the channel used to send messages directly to
+    the MLI backend"""
+
+    MLI_WORKER_QUEUE = "_SMARTSIM_REQUEST_QUEUE"
+    """Storage location for the channel used to send work requests
+    to the available worker managers"""
+
+    @classmethod
+    def contains(cls, value: str) -> bool:
+        """Check whether a string matches one of the reserved keys.
+
+        :param value: The string to check
+        :returns: True if the value matches a reserved key, False otherwise
+        """
+        try:
+            cls(value)
+        except ValueError:
+            return False
+
+        return True
+
+
+@dataclass(frozen=True)
+class TensorKey:
+    """A key, descriptor pair enabling retrieval of an item from a feature store."""
+
+    key: str
+    """The unique key of an item in a feature store"""
+    descriptor: str
+    """The unique identifier of the feature store containing the key"""
+
+    def __post_init__(self) -> None:
+        """Ensure the key and descriptor have at least one character.
+
+        :raises ValueError: If key or descriptor are empty strings
+        """
+        if len(self.key) < 1:
+            raise ValueError("Key must have at least one character.")
+        if len(self.descriptor) < 1:
+            raise ValueError("Descriptor must have at least one character.")
+
+
+@dataclass(frozen=True)
+class ModelKey:
+    """A key, descriptor pair enabling retrieval of an item from a feature store."""
+
+    key: str
+    """The unique key of an item in a feature store"""
+    descriptor: str
+    """The unique identifier of the feature store containing the key"""
+
+    def __post_init__(self) -> None:
+        """Ensure the key and descriptor have at least one character.
+
+        :raises ValueError: If key or descriptor are empty strings
+        """
+        if len(self.key) < 1:
+            raise ValueError("Key must have at least one character.")
+        if len(self.descriptor) < 1:
+            raise ValueError("Descriptor must have at least one character.")
+
+
+class FeatureStore(ABC):
+    """Abstract base class providing the common interface for retrieving
+    values from a feature store implementation."""
+
+    def __init__(self, descriptor: str, allow_reserved_writes: bool = False) -> None:
+        """Initialize the feature store.
+
+        :param descriptor: The stringified version of a storage descriptor
+        :param allow_reserved_writes: Override the default behavior of blocking
+        writes to reserved keys
+        """
+        self._enable_reserved_writes = allow_reserved_writes
+        """Flag used to ensure that any keys written by the system to a feature store
+        are not overwritten by user code. Disabled by default. Subclasses must set the
+        value intentionally."""
+        self._descriptor = descriptor
+        """Stringified version of the unique ID enabling a client to connect
+        to the feature store"""
+
+    def _check_reserved(self, key: str) -> None:
+        """A utility method used to verify access to write to a reserved key
+        in the FeatureStore. Used by subclasses in __setitem__ implementations.
+
+        :param key: A key to compare to the reserved keys
+        :raises SmartSimError: If the key is reserved
+        """
+        if not self._enable_reserved_writes and ReservedKeys.contains(key):
+            raise SmartSimError(
+                "Use of reserved key denied. "
+                "Unable to overwrite system configuration"
+            )
+
+    def __getitem__(self, key: str) -> t.Union[str, bytes]:
+        """Retrieve an item using key.
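Assuming the patched `smartsim` package is importable, the reserved-key check and the eager key validation above behave as follows (the key and descriptor values here are illustrative):

```python
from smartsim._core.mli.infrastructure.storage.feature_store import (
    ReservedKeys,
    TensorKey,
)

# Reserved keys are rejected for user writes; arbitrary keys are not
assert ReservedKeys.contains("_SMARTSIM_REQUEST_QUEUE")
assert not ReservedKeys.contains("my-user-tensor")

# Keys validate eagerly at construction time
TensorKey(key="tensor-1", descriptor="store-abc")
try:
    TensorKey(key="", descriptor="store-abc")
except ValueError:
    pass  # "Key must have at least one character."
```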
+ + :param key: Unique key of an item to retrieve from the feature store + :returns: An item in the FeatureStore + :raises SmartSimError: If retrieving fails + """ + try: + return self._get(key) + except KeyError: + raise + except Exception as ex: + # note: explicitly avoid round-trip to check for key existence + raise SmartSimError( + f"Could not get value for existing key {key}, error:\n{ex}" + ) from ex + + def __setitem__(self, key: str, value: t.Union[str, bytes]) -> None: + """Assign a value using key. + + :param key: Unique key of an item to set in the feature store + :param value: Value to persist in the feature store + """ + self._check_reserved(key) + self._set(key, value) + + def __contains__(self, key: str) -> bool: + """Membership operator to test for a key existing within the feature store. + + :param key: Unique key of an item to retrieve from the feature store + :returns: `True` if the key is found, `False` otherwise + """ + return self._contains(key) + + @abstractmethod + def _get(self, key: str) -> t.Union[str, bytes]: + """Retrieve a value from the underlying storage mechanism. + + :param key: The unique key that identifies the resource + :returns: The value identified by the key + :raises KeyError: If the key has not been used to store a value + """ + + @abstractmethod + def _set(self, key: str, value: t.Union[str, bytes]) -> None: + """Store a value into the underlying storage mechanism. + + :param key: The unique key that identifies the resource + :param value: The value to store + """ + + @abstractmethod + def _contains(self, key: str) -> bool: + """Determine if the storage mechanism contains a given key. + + :param key: The unique key that identifies the resource + :returns: `True` if the key is defined, `False` otherwise + """ + + @property + def _allow_reserved_writes(self) -> bool: + """Return the boolean flag indicating if writing to reserved keys is + enabled for this feature store. + + :returns: `True` if enabled, `False` otherwise + """ + return self._enable_reserved_writes + + @_allow_reserved_writes.setter + def _allow_reserved_writes(self, value: bool) -> None: + """Modify the boolean flag indicating if writing to reserved keys is + enabled for this feature store. + + :param value: The new value to set for the flag + """ + self._enable_reserved_writes = value + + @property + def descriptor(self) -> str: + """Unique identifier enabling a client to connect to the feature store. + + :returns: A descriptor encoded as a string + """ + return self._descriptor diff --git a/smartsim/_core/mli/infrastructure/worker/__init__.py b/smartsim/_core/mli/infrastructure/worker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/infrastructure/worker/torch_worker.py b/smartsim/_core/mli/infrastructure/worker/torch_worker.py new file mode 100644 index 0000000000..64e94e5eb6 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/worker/torch_worker.py @@ -0,0 +1,276 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
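A concrete store only implements `_get`, `_set`, and `_contains`; the base class contributes the mapping dunders and the reserved-key guard. A minimal dict-backed sketch (hypothetical and in-memory, unlike the dragon-backed stores in this patch):

```python
import typing as t

from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore


class DictFeatureStore(FeatureStore):
    """Illustrative in-memory feature store."""

    def __init__(self) -> None:
        super().__init__(descriptor="local-dict")
        self._data: t.Dict[str, t.Union[str, bytes]] = {}

    def _get(self, key: str) -> t.Union[str, bytes]:
        return self._data[key]  # KeyError propagates, as the ABC documents

    def _set(self, key: str, value: t.Union[str, bytes]) -> None:
        self._data[key] = value

    def _contains(self, key: str) -> bool:
        return key in self._data


store = DictFeatureStore()
store["weights"] = b"\x00\x01"
assert "weights" in store and store["weights"] == b"\x00\x01"
```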
Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import io
+
+import numpy as np
+import torch
+
+# pylint: disable=import-error
+from dragon.managed_memory import MemoryAlloc, MemoryPool
+
+from .....error import SmartSimError
+from .....log import get_logger
+from ...mli_schemas.tensor import tensor_capnp
+from .worker import (
+    ExecuteResult,
+    FetchInputResult,
+    FetchModelResult,
+    LoadModelResult,
+    MachineLearningWorkerBase,
+    RequestBatch,
+    TransformInputResult,
+    TransformOutputResult,
+)
+
+# pylint: enable=import-error
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(4)
+logger = get_logger(__name__)
+
+
+class TorchWorker(MachineLearningWorkerBase):
+    """A worker that executes a PyTorch model."""
+
+    @staticmethod
+    def load_model(
+        batch: RequestBatch, fetch_result: FetchModelResult, device: str
+    ) -> LoadModelResult:
+        """Given the raw bytes of an ML model that were fetched, ensure
+        it is loaded into device memory.
+
+        :param batch: The batch of requests that triggered the pipeline
+        :param fetch_result: The result of a fetch-model operation
+        :param device: The device on which the model must be placed
+        :returns: LoadModelResult wrapping the model loaded for the request
+        :raises ValueError: If model reference object is not found
+        :raises RuntimeError: If loading and evaluating the model failed
+        """
+        if fetch_result.model_bytes:
+            model_bytes = fetch_result.model_bytes
+        elif batch.raw_model and batch.raw_model.data:
+            model_bytes = batch.raw_model.data
+        else:
+            raise ValueError("Unable to load model without reference object")
+
+        device_to_torch = {"cpu": "cpu", "gpu": "cuda"}
+        for old, new in device_to_torch.items():
+            device = device.replace(old, new)
+
+        buffer = io.BytesIO(initial_bytes=model_bytes)
+        try:
+            with torch.no_grad():
+                model = torch.jit.load(buffer, map_location=device)  # type: ignore
+                model.eval()
+        except Exception as e:
+            raise RuntimeError(
+                "Failed to load and evaluate the model: "
+                f"Model key {batch.model_id.key}, Device {device}"
+            ) from e
+        result = LoadModelResult(model)
+        return result
+
+    @staticmethod
+    def transform_input(
+        batch: RequestBatch,
+        fetch_results: list[FetchInputResult],
+        mem_pool: MemoryPool,
+    ) -> TransformInputResult:
+        """Given a collection of data, perform a transformation on the data and put
+        the raw tensor data on a MemoryPool allocation.
+
+        :param batch: The batch of requests that triggered the pipeline
+        :param fetch_results: Raw outputs from fetching inputs out of a feature store
+        :param mem_pool: The memory pool used to access batched input tensors
+        :returns: The transformed inputs wrapped in a TransformInputResult
+        :raises ValueError: If tensors cannot be reconstructed
+        :raises IndexError: If index out of range
+        """
+        results: list[torch.Tensor] = []
+        total_samples = 0
+        slices: list[slice] = []
+
+        all_dims: list[list[int]] = []
+        all_dtypes: list[str] = []
+        if fetch_results[0].meta is None:
+            raise ValueError("Cannot reconstruct tensor without meta information")
+        # Traverse inputs to get total number of samples and compute slices
+        # Assumption: first dimension is samples, all tensors in the same input
+        # have same number of samples
+        # thus we only look at the first tensor for each input
+        for res_idx, fetch_result in enumerate(fetch_results):
+            if fetch_result.meta is None or any(
+                item_meta is None for item_meta in fetch_result.meta
+            ):
+                raise ValueError("Cannot reconstruct tensor without meta information")
+            first_tensor_desc: tensor_capnp.TensorDescriptor = fetch_result.meta[0]
+            num_samples = first_tensor_desc.dimensions[0]
+            slices.append(slice(total_samples, total_samples + num_samples))
+            total_samples = total_samples + num_samples
+
+            if res_idx == len(fetch_results) - 1:
+                # For each tensor in the last input, get remaining dimensions
+                # Assumptions: all inputs have the same number of tensors and
+                # last N-1 dimensions match across inputs for corresponding tensors
+                # thus: resulting array will be of size (num_samples, all_other_dims)
+                for item_meta in fetch_result.meta:
+                    tensor_desc: tensor_capnp.TensorDescriptor = item_meta
+                    tensor_dims = list(tensor_desc.dimensions)
+                    all_dims.append([total_samples, *tensor_dims[1:]])
+                    all_dtypes.append(str(tensor_desc.dataType))
+
+        for result_tensor_idx, (dims, dtype) in enumerate(zip(all_dims, all_dtypes)):
+            itemsize = np.empty((1), dtype=dtype).itemsize
+            alloc_size = int(np.prod(dims) * itemsize)
+            mem_alloc = mem_pool.alloc(alloc_size)
+            mem_view = mem_alloc.get_memview()
+            try:
+                mem_view[:alloc_size] = b"".join(
+                    [
+                        fetch_result.inputs[result_tensor_idx]
+                        for fetch_result in fetch_results
+                    ]
+                )
+            except IndexError as e:
+                raise IndexError(
+                    "Error accessing elements in fetch_result.inputs "
+                    f"with index {result_tensor_idx}"
+                ) from e
+
+            results.append(mem_alloc.serialize())
+
+        return TransformInputResult(results, slices, all_dims, all_dtypes)
+
+    # pylint: disable-next=unused-argument
+    @staticmethod
+    def execute(
+        batch: RequestBatch,
+        load_result: LoadModelResult,
+        transform_result: TransformInputResult,
+        device: str,
+    ) -> ExecuteResult:
+        """Execute an ML model on inputs transformed for use by the model.
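The slice bookkeeping above (the first dimension is samples, and per-request sample counts become slices into one concatenated batch) can be shown with plain numpy:

```python
import numpy as np

# Two requests contributing 2 and 4 samples with matching trailing dims
request_inputs = [np.ones((2, 3)), np.ones((4, 3))]

slices, total = [], 0
for arr in request_inputs:
    num_samples = arr.shape[0]
    slices.append(slice(total, total + num_samples))
    total += num_samples

batched = np.concatenate(request_inputs, axis=0)  # shape (6, 3)
assert [batched[s].shape[0] for s in slices] == [2, 4]
```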
+ + :param batch: The batch of requests that triggered the pipeline + :param load_result: The result of loading the model onto device memory + :param transform_result: The result of transforming inputs for model consumption + :param device: The device on which the model will be executed + :returns: The result of inference wrapped in an ExecuteResult + :raises SmartSimError: If model is not loaded + :raises IndexError: If memory slicing is out of range + :raises ValueError: If tensor creation fails or is unable to evaluate the model + """ + if not load_result.model: + raise SmartSimError("Model must be loaded to execute") + device_to_torch = {"cpu": "cpu", "gpu": "cuda"} + for old, new in device_to_torch.items(): + device = device.replace(old, new) + + tensors = [] + mem_allocs = [] + for transformed, dims, dtype in zip( + transform_result.transformed, transform_result.dims, transform_result.dtypes + ): + mem_alloc = MemoryAlloc.attach(transformed) + mem_allocs.append(mem_alloc) + itemsize = np.empty((1), dtype=dtype).itemsize + try: + tensors.append( + torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[0 : np.prod(dims) * itemsize], + dtype=dtype, + ).reshape(dims) + ) + ) + except IndexError as e: + raise IndexError("Error during memory slicing") from e + except Exception as e: + raise ValueError("Error during tensor creation") from e + + model: torch.nn.Module = load_result.model + try: + with torch.no_grad(): + model.eval() + results = [ + model( + *[ + tensor.to(device, non_blocking=True).detach() + for tensor in tensors + ] + ) + ] + except Exception as e: + raise ValueError( + f"Error while evaluating the model: Model {batch.model_id.key}" + ) from e + + transform_result.transformed = [] + + execute_result = ExecuteResult(results, transform_result.slices) + for mem_alloc in mem_allocs: + mem_alloc.free() + return execute_result + + @staticmethod + def transform_output( + batch: RequestBatch, + execute_result: ExecuteResult, + ) -> list[TransformOutputResult]: + """Given inference results, perform transformations required to + transmit results to the requestor. + + :param batch: The batch of requests that triggered the pipeline + :param execute_result: The result of inference wrapped in an ExecuteResult + :returns: A list of transformed outputs + :raises IndexError: If indexing is out of range + :raises ValueError: If transforming output fails + """ + transformed_list: list[TransformOutputResult] = [] + cpu_predictions = [ + prediction.cpu() for prediction in execute_result.predictions + ] + for result_slice in execute_result.slices: + transformed = [] + for cpu_item in cpu_predictions: + try: + transformed.append(cpu_item[result_slice].numpy().tobytes()) + + # todo: need the shape from latest schemas added here. + transformed_list.append( + TransformOutputResult(transformed, None, "c", "float32") + ) # fixme + except IndexError as e: + raise IndexError( + f"Error accessing elements: result_slice {result_slice}" + ) from e + except Exception as e: + raise ValueError("Error transforming output") from e + + execute_result.predictions = [] + + return transformed_list diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py new file mode 100644 index 0000000000..9556b8e438 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -0,0 +1,646 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
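Both `load_model` and `execute` normalize SmartSim device names to torch names by substring replacement, which preserves any device ordinal:

```python
device_to_torch = {"cpu": "cpu", "gpu": "cuda"}


def to_torch_device(device: str) -> str:
    for old, new in device_to_torch.items():
        device = device.replace(old, new)
    return device


assert to_torch_device("gpu") == "cuda"
assert to_torch_device("gpu:1") == "cuda:1"
assert to_torch_device("cpu") == "cpu"
```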
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# pylint: disable=import-error +from dragon.managed_memory import MemoryPool + +# isort: off +# isort: on + +import typing as t +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from .....error import SmartSimError +from .....log import get_logger +from ...comm.channel.channel import CommChannelBase +from ...message_handler import MessageHandler +from ...mli_schemas.model.model_capnp import Model +from ..storage.feature_store import FeatureStore, ModelKey, TensorKey + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + from smartsim._core.mli.mli_schemas.tensor.tensor_capnp import TensorDescriptor + +logger = get_logger(__name__) + +# Placeholder +ModelIdentifier = ModelKey + + +class InferenceRequest: + """Internal representation of an inference request from a client.""" + + def __init__( + self, + model_key: t.Optional[ModelKey] = None, + callback: t.Optional[CommChannelBase] = None, + raw_inputs: t.Optional[t.List[bytes]] = None, + input_keys: t.Optional[t.List[TensorKey]] = None, + input_meta: t.Optional[t.List[t.Any]] = None, + output_keys: t.Optional[t.List[TensorKey]] = None, + raw_model: t.Optional[Model] = None, + batch_size: int = 0, + ): + """Initialize the InferenceRequest. 
+
+        :param model_key: A tuple containing a (key, descriptor) pair
+        :param callback: The channel used for notification of inference completion
+        :param raw_inputs: Raw bytes of tensor inputs
+        :param input_keys: A list of tuples containing a (key, descriptor) pair
+        :param input_meta: Metadata about the input data
+        :param output_keys: A list of tuples containing a (key, descriptor) pair
+        :param raw_model: Raw bytes of an ML model
+        :param batch_size: The batch size to apply when batching
+        """
+        self.model_key = model_key
+        """A tuple containing a (key, descriptor) pair"""
+        self.raw_model = raw_model
+        """Raw bytes of an ML model"""
+        self.callback = callback
+        """The channel used for notification of inference completion"""
+        self.raw_inputs = raw_inputs or []
+        """Raw bytes of tensor inputs"""
+        self.input_keys = input_keys or []
+        """A list of tuples containing a (key, descriptor) pair"""
+        self.input_meta = input_meta or []
+        """Metadata about the input data"""
+        self.output_keys = output_keys or []
+        """A list of tuples containing a (key, descriptor) pair"""
+        self.batch_size = batch_size
+        """The batch size to apply when batching"""
+
+    @property
+    def has_raw_model(self) -> bool:
+        """Check if the InferenceRequest contains a raw_model.
+
+        :returns: True if raw_model is not None, False otherwise
+        """
+        return self.raw_model is not None
+
+    @property
+    def has_model_key(self) -> bool:
+        """Check if the InferenceRequest contains a model_key.
+
+        :returns: True if model_key is not None, False otherwise
+        """
+        return self.model_key is not None
+
+    @property
+    def has_raw_inputs(self) -> bool:
+        """Check if the InferenceRequest contains raw_inputs.
+
+        :returns: True if raw_inputs is not None and is not an empty list,
+        False otherwise
+        """
+        return self.raw_inputs is not None and bool(self.raw_inputs)
+
+    @property
+    def has_input_keys(self) -> bool:
+        """Check if the InferenceRequest contains input_keys.
+
+        :returns: True if input_keys is not None and is not an empty list,
+        False otherwise
+        """
+        return self.input_keys is not None and bool(self.input_keys)
+
+    @property
+    def has_output_keys(self) -> bool:
+        """Check if the InferenceRequest contains output_keys.
+
+        :returns: True if output_keys is not None and is not an empty list,
+        False otherwise
+        """
+        return self.output_keys is not None and bool(self.output_keys)
+
+    @property
+    def has_input_meta(self) -> bool:
+        """Check if the InferenceRequest contains input_meta.
+
+        :returns: True if input_meta is not None and is not an empty list,
+        False otherwise
+        """
+        return self.input_meta is not None and bool(self.input_meta)
+
+
+class InferenceReply:
+    """Internal representation of the reply to a client request for inference."""
+
+    def __init__(
+        self,
+        outputs: t.Optional[t.Collection[t.Any]] = None,
+        output_keys: t.Optional[t.Collection[TensorKey]] = None,
+        status_enum: "Status" = "running",
+        message: str = "In progress",
+    ) -> None:
+        """Initialize the InferenceReply.
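Replies start as in-progress and are finalized by the worker manager according to whether outputs were produced. A stand-alone sketch of that lifecycle (a stand-in class, not the real `InferenceReply`):

```python
class ReplySketch:
    """Stand-in mirroring InferenceReply's status fields."""

    def __init__(self) -> None:
        self.outputs = []
        self.status_enum, self.message = "running", "In progress"

    @property
    def has_outputs(self) -> bool:
        return bool(self.outputs)


reply = ReplySketch()
reply.outputs = [b"tensor-bytes"]
if reply.has_outputs:
    reply.status_enum, reply.message = "complete", "Success"
else:
    reply.status_enum, reply.message = "fail", "Outputs not found."
assert reply.status_enum == "complete"
```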
+ + :param outputs: List of output data + :param output_keys: List of keys used for output data + :param status_enum: Status of the reply + :param message: Status message that corresponds with the status enum + """ + self.outputs: t.Collection[t.Any] = outputs or [] + """List of output data""" + self.output_keys: t.Collection[t.Optional[TensorKey]] = output_keys or [] + """List of keys used for output data""" + self.status_enum = status_enum + """Status of the reply""" + self.message = message + """Status message that corresponds with the status enum""" + + @property + def has_outputs(self) -> bool: + """Check if the InferenceReply contains outputs. + + :returns: True if outputs is not None and is not an empty list, + False otherwise + """ + return self.outputs is not None and bool(self.outputs) + + @property + def has_output_keys(self) -> bool: + """Check if the InferenceReply contains output_keys. + + :returns: True if output_keys is not None and is not an empty list, + False otherwise + """ + return self.output_keys is not None and bool(self.output_keys) + + +class LoadModelResult: + """A wrapper around a loaded model.""" + + def __init__(self, model: t.Any) -> None: + """Initialize the LoadModelResult. + + :param model: The loaded model + """ + self.model = model + """The loaded model (e.g. a TensorFlow, PyTorch, ONNX, etc. model)""" + + +class TransformInputResult: + """A wrapper around a transformed batch of input tensors""" + + def __init__( + self, + result: t.Any, + slices: list[slice], + dims: list[list[int]], + dtypes: list[str], + ) -> None: + """Initialize the TransformInputResult. + + :param result: List of Dragon MemoryAlloc objects on which + the tensors are stored + :param slices: The slices that represent which portion of the + input tensors belongs to which request + :param dims: Dimension of the transformed tensors + :param dtypes: Data type of transformed tensors + """ + self.transformed = result + """List of Dragon MemoryAlloc objects on which the tensors are stored""" + self.slices = slices + """Each slice represents which portion of the input tensors belongs to + which request""" + self.dims = dims + """Dimension of the transformed tensors""" + self.dtypes = dtypes + """Data type of transformed tensors""" + + +class ExecuteResult: + """A wrapper around inference results.""" + + def __init__(self, result: t.Any, slices: list[slice]) -> None: + """Initialize the ExecuteResult. + + :param result: Result of the execution + :param slices: The slices that represent which portion of the input + tensors belongs to which request + """ + self.predictions = result + """Result of the execution""" + self.slices = slices + """The slices that represent which portion of the input + tensors belongs to which request""" + + +class FetchInputResult: + """A wrapper around fetched inputs.""" + + def __init__(self, result: t.List[bytes], meta: t.Optional[t.List[t.Any]]) -> None: + """Initialize the FetchInputResult. + + :param result: List of input tensor bytes + :param meta: List of metadata that corresponds with the inputs + """ + self.inputs = result + """List of input tensor bytes""" + self.meta = meta + """List of metadata that corresponds with the inputs""" + + +class TransformOutputResult: + """A wrapper around inference results transformed for transmission.""" + + def __init__( + self, result: t.Any, shape: t.Optional[t.List[int]], order: str, dtype: str + ) -> None: + """Initialize the TransformOutputResult. 
+ + :param result: Transformed output results + :param shape: Shape of output results + :param order: Order of output results + :param dtype: Datatype of output results + """ + self.outputs = result + """Transformed output results""" + self.shape = shape + """Shape of output results""" + self.order = order + """Order of output results""" + self.dtype = dtype + """Datatype of output results""" + + +class CreateInputBatchResult: + """A wrapper around inputs batched into a single request.""" + + def __init__(self, result: t.Any) -> None: + """Initialize the CreateInputBatchResult. + + :param result: Inputs batched into a single request + """ + self.batch = result + """Inputs batched into a single request""" + + +class FetchModelResult: + """A wrapper around raw fetched models.""" + + def __init__(self, result: bytes) -> None: + """Initialize the FetchModelResult. + + :param result: The raw fetched model + """ + self.model_bytes: bytes = result + """The raw fetched model""" + + +@dataclass +class RequestBatch: + """A batch of aggregated inference requests.""" + + requests: list[InferenceRequest] + """List of InferenceRequests in the batch""" + inputs: t.Optional[TransformInputResult] + """Transformed batch of input tensors""" + model_id: "ModelIdentifier" + """Model (key, descriptor) tuple""" + + @property + def has_valid_requests(self) -> bool: + """Returns whether the batch contains at least one request. + + :returns: True if at least one request is available + """ + return len(self.requests) > 0 + + @property + def has_raw_model(self) -> bool: + """Returns whether the batch has a raw model. + + :returns: True if the batch has a raw model + """ + return self.raw_model is not None + + @property + def raw_model(self) -> t.Optional[t.Any]: + """Returns the raw model to use to execute for this batch + if it is available. + + :returns: A model if available, otherwise None""" + if self.has_valid_requests: + return self.requests[0].raw_model + return None + + @property + def input_keys(self) -> t.List[TensorKey]: + """All input keys available in this batch's requests. + + :returns: All input keys belonging to requests in this batch""" + keys = [] + for request in self.requests: + keys.extend(request.input_keys) + + return keys + + @property + def output_keys(self) -> t.List[TensorKey]: + """All output keys available in this batch's requests. + + :returns: All output keys belonging to requests in this batch""" + keys = [] + for request in self.requests: + keys.extend(request.output_keys) + + return keys + + +class MachineLearningWorkerCore: + """Basic functionality of ML worker that is shared across all worker types.""" + + @staticmethod + def deserialize_message( + data_blob: bytes, + callback_factory: t.Callable[[str], CommChannelBase], + ) -> InferenceRequest: + """Deserialize a message from a byte stream into an InferenceRequest. 
+ + :param data_blob: The byte stream to deserialize + :param callback_factory: A factory method that can create an instance + of the desired concrete comm channel type + :returns: The raw input message deserialized into an InferenceRequest + """ + request = MessageHandler.deserialize_request(data_blob) + model_key: t.Optional[ModelKey] = None + model_bytes: t.Optional[Model] = None + + if request.model.which() == "key": + model_key = ModelKey( + key=request.model.key.key, + descriptor=request.model.key.descriptor, + ) + elif request.model.which() == "data": + model_bytes = request.model.data + + callback_key = request.replyChannel.descriptor + comm_channel = callback_factory(callback_key) + input_keys: t.Optional[t.List[TensorKey]] = None + input_bytes: t.Optional[t.List[bytes]] = None + output_keys: t.Optional[t.List[TensorKey]] = None + input_meta: t.Optional[t.List[TensorDescriptor]] = None + + if request.input.which() == "keys": + input_keys = [ + TensorKey(key=value.key, descriptor=value.descriptor) + for value in request.input.keys + ] + elif request.input.which() == "descriptors": + input_meta = request.input.descriptors # type: ignore + + if request.output: + output_keys = [ + TensorKey(key=value.key, descriptor=value.descriptor) + for value in request.output + ] + + inference_request = InferenceRequest( + model_key=model_key, + callback=comm_channel, + raw_inputs=input_bytes, + input_meta=input_meta, + input_keys=input_keys, + output_keys=output_keys, + raw_model=model_bytes, + batch_size=0, + ) + return inference_request + + @staticmethod + def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]: + """Assemble the output information based on whether the output + information will be in the form of TensorKeys or TensorDescriptors. + + :param reply: The reply that the output belongs to + :returns: The list of prepared outputs, depending on the output + information needed in the reply + """ + prepared_outputs: t.List[t.Any] = [] + if reply.has_output_keys: + for value in reply.output_keys: + if not value: + continue + msg_key = MessageHandler.build_tensor_key(value.key, value.descriptor) + prepared_outputs.append(msg_key) + elif reply.has_outputs: + for _ in reply.outputs: + msg_tensor_desc = MessageHandler.build_tensor_descriptor( + "c", + "float32", + [1], + ) + prepared_outputs.append(msg_tensor_desc) + return prepared_outputs + + @staticmethod + def fetch_model( + batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] + ) -> FetchModelResult: + """Given a resource key, retrieve the raw model from a feature store. 
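`prepare_outputs` returns one of two shapes: tensor-key references when results were persisted to a feature store, or inline tensor descriptors when raw tensors will follow the reply. Its branching, reduced to a stand-in sketch (the tuple encoding here is illustrative, not the capnp message layout):

```python
def prepare_outputs_sketch(output_keys, outputs):
    """Stand-in for MachineLearningWorkerCore.prepare_outputs."""
    if output_keys:
        return [("key", key) for key in output_keys if key]
    # No keys: describe each tensor inline so it can be sent after the reply
    return [("descriptor", ("c", "float32", [1])) for _ in outputs]


assert prepare_outputs_sketch(["out-1"], []) == [("key", "out-1")]
assert prepare_outputs_sketch([], [b"t0", b"t1"]) == [
    ("descriptor", ("c", "float32", [1])),
    ("descriptor", ("c", "float32", [1])),
]
```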
+
+        :param batch: The batch of requests that triggered the pipeline
+        :param feature_stores: Available feature stores used for persistence
+        :returns: Raw bytes of the model
+        :raises SmartSimError: If neither a key nor a model is provided, or the
+        model cannot be retrieved from the feature store
+        :raises ValueError: If a feature store is not available and a raw
+        model is not provided
+        """
+        # All requests in the same batch share the model
+        if batch.raw_model:
+            return FetchModelResult(batch.raw_model.data)
+
+        if not feature_stores:
+            raise ValueError("Feature store is required for model retrieval")
+
+        if batch.model_id is None:
+            raise SmartSimError(
+                "Key must be provided to retrieve model from feature store"
+            )
+
+        key, fsd = batch.model_id.key, batch.model_id.descriptor
+
+        try:
+            feature_store = feature_stores[fsd]
+            raw_bytes: bytes = t.cast(bytes, feature_store[key])
+            return FetchModelResult(raw_bytes)
+        except (FileNotFoundError, KeyError) as ex:
+            logger.exception(ex)
+            raise SmartSimError(f"Model could not be retrieved with key {key}") from ex
+
+    @staticmethod
+    def fetch_inputs(
+        batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore]
+    ) -> t.List[FetchInputResult]:
+        """Given a collection of ResourceKeys, identify the physical location
+        and input metadata.
+
+        :param batch: The batch of requests that triggered the pipeline
+        :param feature_stores: Available feature stores used for persistence
+        :returns: The fetched input
+        :raises ValueError: If neither an input key nor an input tensor is provided
+        :raises SmartSimError: If a tensor for a given key cannot be retrieved
+        """
+        fetch_results = []
+        for request in batch.requests:
+            if request.raw_inputs:
+                fetch_results.append(
+                    FetchInputResult(request.raw_inputs, request.input_meta)
+                )
+                continue
+
+            if not feature_stores:
+                raise ValueError("No input and no feature store provided")
+
+            if request.has_input_keys:
+                data: t.List[bytes] = []
+
+                for fs_key in request.input_keys:
+                    try:
+                        feature_store = feature_stores[fs_key.descriptor]
+                        tensor_bytes = t.cast(bytes, feature_store[fs_key.key])
+                        data.append(tensor_bytes)
+                    except KeyError as ex:
+                        logger.exception(ex)
+                        raise SmartSimError(
+                            f"Tensor could not be retrieved with key {fs_key.key}"
+                        ) from ex
+                fetch_results.append(
+                    FetchInputResult(data, meta=None)
+                )  # fixme: need to get both tensor and descriptor
+                continue
+
+            raise ValueError("No input source")
+
+        return fetch_results
+
+    @staticmethod
+    def place_output(
+        request: InferenceRequest,
+        transform_result: TransformOutputResult,
+        feature_stores: t.Dict[str, FeatureStore],
+    ) -> t.Collection[t.Optional[TensorKey]]:
+        """Given a collection of data, make it available as a shared resource in the
+        feature store.
+
+        :param request: The request that triggered the pipeline
+        :param transform_result: Transformed version of the inference result
+        :param feature_stores: Available feature stores used for persistence
+        :returns: A collection of keys that were placed in the feature store
+        :raises ValueError: If a feature store is not provided
+        """
+        if not feature_stores:
+            raise ValueError("Feature store is required for output persistence")
+
+        keys: t.List[t.Optional[TensorKey]] = []
+        # need to decide how to get back to original sub-batch inputs so they can be
+        # accurately placed, datum might need to include this.
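`fetch_model` and `fetch_inputs` both resolve a `(key, descriptor)` pair in two steps: the descriptor selects an attached feature store, then the key selects a value inside it. With plain dictionaries standing in for attached stores:

```python
feature_stores = {
    "fs-A": {"tensor-1": b"aaa"},
    "fs-B": {"tensor-2": b"bbb"},
}


def fetch(key: str, descriptor: str) -> bytes:
    store = feature_stores[descriptor]  # descriptor -> attached store
    return store[key]                   # key -> stored bytes


assert fetch("tensor-1", "fs-A") == b"aaa"
assert fetch("tensor-2", "fs-B") == b"bbb"
```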
+
+        # Consider parallelizing all PUT feature_store operations
+        for fs_key, v in zip(request.output_keys, transform_result.outputs):
+            feature_store = feature_stores[fs_key.descriptor]
+            feature_store[fs_key.key] = v
+            keys.append(fs_key)
+
+        return keys
+
+
+class MachineLearningWorkerBase(MachineLearningWorkerCore, ABC):
+    """Abstract base class providing the contract for a machine learning
+    worker implementation."""
+
+    @staticmethod
+    @abstractmethod
+    def load_model(
+        batch: RequestBatch, fetch_result: FetchModelResult, device: str
+    ) -> LoadModelResult:
+        """Given the fetched raw bytes of an ML model, ensure the model
+        is loaded into device memory.
+
+        :param batch: The batch of requests that triggered the pipeline
+        :param fetch_result: The result of a fetch-model operation; contains
+        the raw bytes of the ML model
+        :param device: The device on which the model must be placed
+        :returns: LoadModelResult wrapping the model loaded for the request
+        :raises ValueError: If the model reference object is not found
+        :raises RuntimeError: If loading and evaluating the model failed
+        """
+
+    @staticmethod
+    @abstractmethod
+    def transform_input(
+        batch: RequestBatch,
+        fetch_results: list[FetchInputResult],
+        mem_pool: MemoryPool,
+    ) -> TransformInputResult:
+        """Given a collection of data, perform a transformation on the data and put
+        the raw tensor data on a MemoryPool allocation.
+
+        :param batch: The batch of requests that triggered the pipeline
+        :param fetch_results: Raw outputs from fetching inputs out of a feature store
+        :param mem_pool: The memory pool used to access batched input tensors
+        :returns: The transformed inputs wrapped in a TransformInputResult
+        :raises ValueError: If tensors cannot be reconstructed
+        :raises IndexError: If an index is out of range
+        """
+
+    @staticmethod
+    @abstractmethod
+    def execute(
+        batch: RequestBatch,
+        load_result: LoadModelResult,
+        transform_result: TransformInputResult,
+        device: str,
+    ) -> ExecuteResult:
+        """Execute an ML model on inputs transformed for use by the model.
+
+        :param batch: The batch of requests that triggered the pipeline
+        :param load_result: The result of loading the model onto device memory
+        :param transform_result: The result of transforming inputs for model consumption
+        :param device: The device on which the model will be executed
+        :returns: The result of inference wrapped in an ExecuteResult
+        :raises SmartSimError: If the model is not loaded
+        :raises IndexError: If memory slicing is out of range
+        :raises ValueError: If tensor creation fails or the model cannot be evaluated
+        """
+
+    @staticmethod
+    @abstractmethod
+    def transform_output(
+        batch: RequestBatch, execute_result: ExecuteResult
+    ) -> t.List[TransformOutputResult]:
+        """Given inference results, perform transformations required to
+        transmit results to the requestor.
+
+        :param batch: The batch of requests that triggered the pipeline
+        :param execute_result: The result of inference wrapped in an ExecuteResult
+        :returns: A list of transformed outputs
+        :raises IndexError: If indexing is out of range
+        :raises ValueError: If transforming the output fails
+        """
diff --git a/smartsim/_core/mli/message_handler.py b/smartsim/_core/mli/message_handler.py
new file mode 100644
index 0000000000..e3d46a7ab3
--- /dev/null
+++ b/smartsim/_core/mli/message_handler.py
@@ -0,0 +1,602 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import typing as t + +from .mli_schemas.data import data_references_capnp +from .mli_schemas.model import model_capnp +from .mli_schemas.request import request_capnp +from .mli_schemas.request.request_attributes import request_attributes_capnp +from .mli_schemas.response import response_capnp +from .mli_schemas.response.response_attributes import response_attributes_capnp +from .mli_schemas.tensor import tensor_capnp + + +class MessageHandler: + """Utility methods for transforming capnproto messages to and from + internal representations. + """ + + @staticmethod + def build_tensor_descriptor( + order: "tensor_capnp.Order", + data_type: "tensor_capnp.NumericalType", + dimensions: t.List[int], + ) -> tensor_capnp.TensorDescriptor: + """ + Builds a TensorDescriptor message using the provided + order, data type, and dimensions. + + :param order: Order of the tensor, such as row-major (c) or column-major (f) + :param data_type: Data type of the tensor + :param dimensions: Dimensions of the tensor + :returns: The TensorDescriptor + :raises ValueError: If building fails + """ + try: + description = tensor_capnp.TensorDescriptor.new_message() + description.order = order + description.dataType = data_type + description.dimensions = dimensions + except Exception as e: + raise ValueError("Error building tensor descriptor.") from e + + return description + + @staticmethod + def build_output_tensor_descriptor( + order: "tensor_capnp.Order", + keys: t.List["data_references_capnp.TensorKey"], + data_type: "tensor_capnp.ReturnNumericalType", + dimensions: t.List[int], + ) -> tensor_capnp.OutputDescriptor: + """ + Builds an OutputDescriptor message using the provided + order, data type, and dimensions. 
+
+        :param order: Order of the tensor, such as row-major (c) or column-major (f)
+        :param keys: List of TensorKey to apply the transform descriptor to
+        :param data_type: Transform data type of the tensor
+        :param dimensions: Transform dimensions of the tensor
+        :returns: The OutputDescriptor
+        :raises ValueError: If building fails
+        """
+        try:
+            description = tensor_capnp.OutputDescriptor.new_message()
+            description.order = order
+            description.optionalKeys = keys
+            description.optionalDatatype = data_type
+            description.optionalDimension = dimensions
+        except Exception as e:
+            raise ValueError("Error building output tensor descriptor.") from e
+
+        return description
+
+    @staticmethod
+    def build_tensor_key(key: str, descriptor: str) -> data_references_capnp.TensorKey:
+        """
+        Builds a new TensorKey message with the provided key.
+
+        :param key: String to set the TensorKey
+        :param descriptor: A descriptor identifying the feature store
+        containing the key
+        :returns: The TensorKey
+        :raises ValueError: If building fails
+        """
+        try:
+            tensor_key = data_references_capnp.TensorKey.new_message()
+            tensor_key.key = key
+            tensor_key.descriptor = descriptor
+        except Exception as e:
+            raise ValueError("Error building tensor key.") from e
+        return tensor_key
+
+    @staticmethod
+    def build_model(data: bytes, name: str, version: str) -> model_capnp.Model:
+        """
+        Builds a new Model message with the provided data, name, and version.
+
+        :param data: Model data
+        :param name: Model name
+        :param version: Model version
+        :returns: The Model
+        :raises ValueError: If building fails
+        """
+        try:
+            model = model_capnp.Model.new_message()
+            model.data = data
+            model.name = name
+            model.version = version
+        except Exception as e:
+            raise ValueError("Error building model.") from e
+        return model
+
+    @staticmethod
+    def build_model_key(key: str, descriptor: str) -> data_references_capnp.ModelKey:
+        """
+        Builds a new ModelKey message with the provided key.
+
+        :param key: String to set the ModelKey
+        :param descriptor: A descriptor identifying the feature store
+        containing the key
+        :returns: The ModelKey
+        :raises ValueError: If building fails
+        """
+        try:
+            model_key = data_references_capnp.ModelKey.new_message()
+            model_key.key = key
+            model_key.descriptor = descriptor
+        except Exception as e:
+            raise ValueError("Error building model key.") from e
+        return model_key
+
+    @staticmethod
+    def build_torch_request_attributes(
+        tensor_type: "request_attributes_capnp.TorchTensorType",
+    ) -> request_attributes_capnp.TorchRequestAttributes:
+        """
+        Builds a new TorchRequestAttributes message with the provided tensor type.
+
+        :param tensor_type: Type of the tensor passed in
+        :returns: The TorchRequestAttributes
+        :raises ValueError: If building fails
+        """
+        try:
+            attributes = request_attributes_capnp.TorchRequestAttributes.new_message()
+            attributes.tensorType = tensor_type
+        except Exception as e:
+            raise ValueError("Error building Torch request attributes.") from e
+        return attributes
+
+    @staticmethod
+    def build_tf_request_attributes(
+        name: str, tensor_type: "request_attributes_capnp.TFTensorType"
+    ) -> request_attributes_capnp.TensorFlowRequestAttributes:
+        """
+        Builds a new TensorFlowRequestAttributes message with
+        the provided name and tensor type.
+
+        :param name: Name of the tensor
+        :param tensor_type: Type of the tensor passed in
+        :returns: The TensorFlowRequestAttributes
+        :raises ValueError: If building fails
+        """
+        try:
+            attributes = (
+                request_attributes_capnp.TensorFlowRequestAttributes.new_message()
+            )
+            attributes.name = name
+            attributes.tensorType = tensor_type
+        except Exception as e:
+            raise ValueError("Error building TensorFlow request attributes.") from e
+        return attributes
+
+    @staticmethod
+    def build_torch_response_attributes() -> (
+        response_attributes_capnp.TorchResponseAttributes
+    ):
+        """
+        Builds a new TorchResponseAttributes message.
+
+        :returns: The TorchResponseAttributes
+        """
+        return response_attributes_capnp.TorchResponseAttributes.new_message()
+
+    @staticmethod
+    def build_tf_response_attributes() -> (
+        response_attributes_capnp.TensorFlowResponseAttributes
+    ):
+        """
+        Builds a new TensorFlowResponseAttributes message.
+
+        :returns: The TensorFlowResponseAttributes
+        """
+        return response_attributes_capnp.TensorFlowResponseAttributes.new_message()
+
+    @staticmethod
+    def _assign_model(
+        request: request_capnp.Request,
+        model: t.Union[data_references_capnp.ModelKey, model_capnp.Model],
+    ) -> None:
+        """
+        Assigns a model to the supplied request.
+
+        :param request: Request being built
+        :param model: Model to be assigned
+        :raises ValueError: If building fails
+        """
+        try:
+            class_name = model.schema.node.displayName.split(":")[-1]  # type: ignore
+            if class_name == "Model":
+                request.model.data = model  # type: ignore
+            elif class_name == "ModelKey":
+                request.model.key = model  # type: ignore
+            else:
+                raise ValueError("""Invalid model class name.
+                Expected 'Model' or 'ModelKey'.""")
+        except Exception as e:
+            raise ValueError("Error building model portion of request.") from e
+
+    @staticmethod
+    def _assign_reply_channel(
+        request: request_capnp.Request, reply_channel: str
+    ) -> None:
+        """
+        Assigns a reply channel to the supplied request.
+
+        :param request: Request being built
+        :param reply_channel: Reply channel to be assigned
+        :raises ValueError: If building fails
+        """
+        try:
+            request.replyChannel.descriptor = reply_channel
+        except Exception as e:
+            raise ValueError("Error building reply channel portion of request.") from e
+
+    @staticmethod
+    def _assign_inputs(
+        request: request_capnp.Request,
+        inputs: t.Union[
+            t.List[data_references_capnp.TensorKey],
+            t.List[tensor_capnp.TensorDescriptor],
+        ],
+    ) -> None:
+        """
+        Assigns inputs to the supplied request.
+
+        :param request: Request being built
+        :param inputs: Inputs to be assigned
+        :raises ValueError: If building fails
+        """
+        try:
+            if inputs:
+                display_name = inputs[0].schema.node.displayName  # type: ignore
+                input_class_name = display_name.split(":")[-1]
+                if input_class_name == "TensorDescriptor":
+                    request.input.descriptors = inputs  # type: ignore
+                elif input_class_name == "TensorKey":
+                    request.input.keys = inputs  # type: ignore
+                else:
+                    raise ValueError("""Invalid input class name. Expected
+                    'TensorDescriptor' or 'TensorKey'.""")
+        except Exception as e:
+            raise ValueError("Error building inputs portion of request.") from e
+
+    @staticmethod
+    def _assign_outputs(
+        request: request_capnp.Request,
+        outputs: t.List[data_references_capnp.TensorKey],
+    ) -> None:
+        """
+        Assigns outputs to the supplied request.
+ + :param request: Request being built + :param outputs: Outputs to be assigned + :raises ValueError: If building fails + """ + try: + request.output = outputs + + except Exception as e: + raise ValueError("Error building outputs portion of request.") from e + + @staticmethod + def _assign_output_descriptors( + request: request_capnp.Request, + output_descriptors: t.List[tensor_capnp.OutputDescriptor], + ) -> None: + """ + Assigns a list of output tensor descriptors to the supplied request. + + :param request: Request being built + :param output_descriptors: Output descriptors to be assigned + :raises ValueError: If building fails + """ + try: + request.outputDescriptors = output_descriptors + except Exception as e: + raise ValueError( + "Error building the output descriptors portion of request." + ) from e + + @staticmethod + def _assign_custom_request_attributes( + request: request_capnp.Request, + custom_attrs: t.Union[ + request_attributes_capnp.TorchRequestAttributes, + request_attributes_capnp.TensorFlowRequestAttributes, + None, + ], + ) -> None: + """ + Assigns request attributes to the supplied request. + + :param request: Request being built + :param custom_attrs: Custom attributes to be assigned + :raises ValueError: If building fails + """ + try: + if custom_attrs is None: + request.customAttributes.none = custom_attrs + else: + custom_attribute_class_name = ( + custom_attrs.schema.node.displayName.split(":")[-1] # type: ignore + ) + if custom_attribute_class_name == "TorchRequestAttributes": + request.customAttributes.torch = custom_attrs # type: ignore + elif custom_attribute_class_name == "TensorFlowRequestAttributes": + request.customAttributes.tf = custom_attrs # type: ignore + else: + raise ValueError("""Invalid custom attribute class name. + Expected 'TensorFlowRequestAttributes' or + 'TorchRequestAttributes'.""") + except Exception as e: + raise ValueError( + "Error building custom attributes portion of request." + ) from e + + @staticmethod + def build_request( + reply_channel: str, + model: t.Union[data_references_capnp.ModelKey, model_capnp.Model], + inputs: t.Union[ + t.List[data_references_capnp.TensorKey], + t.List[tensor_capnp.TensorDescriptor], + ], + outputs: t.List[data_references_capnp.TensorKey], + output_descriptors: t.List[tensor_capnp.OutputDescriptor], + custom_attributes: t.Union[ + request_attributes_capnp.TorchRequestAttributes, + request_attributes_capnp.TensorFlowRequestAttributes, + None, + ], + ) -> request_capnp.RequestBuilder: + """ + Builds the request message. + + :param reply_channel: Reply channel to be assigned to request + :param model: Model to be assigned to request + :param inputs: Inputs to be assigned to request + :param outputs: Outputs to be assigned to request + :param output_descriptors: Output descriptors to be assigned to request + :param custom_attributes: Custom attributes to be assigned to request + :returns: The Request + """ + request = request_capnp.Request.new_message() + MessageHandler._assign_reply_channel(request, reply_channel) + MessageHandler._assign_model(request, model) + MessageHandler._assign_inputs(request, inputs) + MessageHandler._assign_outputs(request, outputs) + MessageHandler._assign_output_descriptors(request, output_descriptors) + MessageHandler._assign_custom_request_attributes(request, custom_attributes) + return request + + @staticmethod + def serialize_request(request: request_capnp.RequestBuilder) -> bytes: + """ + Serializes a built request message. 
+
+        :param request: Request to be serialized
+        :returns: Serialized request bytes
+        :raises ValueError: If serialization fails
+        """
+        display_name = request.schema.node.displayName  # type: ignore
+        class_name = display_name.split(":")[-1]
+        if class_name != "Request":
+            raise ValueError(
+                "Error serializing the request. Value passed in is not "
+                f"a request: {class_name}"
+            )
+        try:
+            return request.to_bytes()
+        except Exception as e:
+            raise ValueError("Error serializing the request") from e
+
+    @staticmethod
+    def deserialize_request(request_bytes: bytes) -> request_capnp.Request:
+        """
+        Deserializes a serialized request message.
+
+        :param request_bytes: Bytes to be deserialized into a request
+        :returns: Deserialized request
+        :raises ValueError: If deserialization fails
+        """
+        try:
+            bytes_message = request_capnp.Request.from_bytes(
+                request_bytes, traversal_limit_in_words=2**63
+            )
+
+            with bytes_message as message:
+                return message
+        except Exception as e:
+            raise ValueError("Error deserializing the request") from e
+
+    @staticmethod
+    def _assign_status(
+        response: response_capnp.Response, status: "response_capnp.Status"
+    ) -> None:
+        """
+        Assigns a status to the supplied response.
+
+        :param response: Response being built
+        :param status: Status to be assigned
+        :raises ValueError: If building fails
+        """
+        try:
+            response.status = status
+        except Exception as e:
+            raise ValueError("Error assigning status to response.") from e
+
+    @staticmethod
+    def _assign_message(response: response_capnp.Response, message: str) -> None:
+        """
+        Assigns a message to the supplied response.
+
+        :param response: Response being built
+        :param message: Message to be assigned
+        :raises ValueError: If building fails
+        """
+        try:
+            response.message = message
+        except Exception as e:
+            raise ValueError("Error assigning message to response.") from e
+
+    @staticmethod
+    def _assign_result(
+        response: response_capnp.Response,
+        result: t.Union[
+            t.List[tensor_capnp.TensorDescriptor],
+            t.List[data_references_capnp.TensorKey],
+            None,
+        ],
+    ) -> None:
+        """
+        Assigns a result to the supplied response.
+
+        :param response: Response being built
+        :param result: Result to be assigned
+        :raises ValueError: If building fails
+        """
+        try:
+            if result:
+                first_result = result[0]
+                display_name = first_result.schema.node.displayName  # type: ignore
+                result_class_name = display_name.split(":")[-1]
+                if result_class_name == "TensorDescriptor":
+                    response.result.descriptors = result  # type: ignore
+                elif result_class_name == "TensorKey":
+                    response.result.keys = result  # type: ignore
+                else:
+                    raise ValueError("""Invalid result class name.
+                    Expected 'TensorDescriptor' or 'TensorKey'.""")
+        except Exception as e:
+            raise ValueError("Error assigning result to response.") from e
+
+    @staticmethod
+    def _assign_custom_response_attributes(
+        response: response_capnp.Response,
+        custom_attrs: t.Union[
+            response_attributes_capnp.TorchResponseAttributes,
+            response_attributes_capnp.TensorFlowResponseAttributes,
+            None,
+        ],
+    ) -> None:
+        """
+        Assigns custom attributes to the supplied response.
+ + :param response: Response being built + :param custom_attrs: Custom attributes to be assigned + :raises ValueError: If building fails + """ + try: + if custom_attrs is None: + response.customAttributes.none = custom_attrs + else: + custom_attribute_class_name = ( + custom_attrs.schema.node.displayName.split(":")[-1] # type: ignore + ) + if custom_attribute_class_name == "TorchResponseAttributes": + response.customAttributes.torch = custom_attrs # type: ignore + elif custom_attribute_class_name == "TensorFlowResponseAttributes": + response.customAttributes.tf = custom_attrs # type: ignore + else: + raise ValueError("""Invalid custom attribute class name. + Expected 'TensorFlowResponseAttributes' or + 'TorchResponseAttributes'.""") + except Exception as e: + raise ValueError("Error assigning custom attributes to response.") from e + + @staticmethod + def build_response( + status: "response_capnp.Status", + message: str, + result: t.Union[ + t.List[tensor_capnp.TensorDescriptor], + t.List[data_references_capnp.TensorKey], + None, + ], + custom_attributes: t.Union[ + response_attributes_capnp.TorchResponseAttributes, + response_attributes_capnp.TensorFlowResponseAttributes, + None, + ], + ) -> response_capnp.ResponseBuilder: + """ + Builds the response message. + + :param status: Status to be assigned to response + :param message: Message to be assigned to response + :param result: Result to be assigned to response + :param custom_attributes: Custom attributes to be assigned to response + :returns: The Response + """ + response = response_capnp.Response.new_message() + MessageHandler._assign_status(response, status) + MessageHandler._assign_message(response, message) + MessageHandler._assign_result(response, result) + MessageHandler._assign_custom_response_attributes(response, custom_attributes) + return response + + @staticmethod + def serialize_response(response: response_capnp.ResponseBuilder) -> bytes: + """ + Serializes a built response message. + + :param response: Response to be serialized + :returns: Serialized response bytes + :raises ValueError: If serialization fails + """ + display_name = response.schema.node.displayName # type: ignore + class_name = display_name.split(":")[-1] + if class_name != "Response": + raise ValueError( + "Error serializing the response. Value passed in is not " + f"a response: {class_name}" + ) + try: + return response.to_bytes() + except Exception as e: + raise ValueError("Error serializing the response") from e + + @staticmethod + def deserialize_response(response_bytes: bytes) -> response_capnp.Response: + """ + Deserializes a serialized response message. + + :param response_bytes: Bytes to be deserialized into a response + :returns: Deserialized response + :raises ValueError: If deserialization fails + """ + try: + bytes_message = response_capnp.Response.from_bytes( + response_bytes, traversal_limit_in_words=2**63 + ) + + with bytes_message as message: + return message + + except Exception as e: + raise ValueError("Error deserializing the response") from e diff --git a/smartsim/_core/mli/mli_schemas/data/data_references.capnp b/smartsim/_core/mli/mli_schemas/data/data_references.capnp new file mode 100644 index 0000000000..65293be7b2 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/data/data_references.capnp @@ -0,0 +1,37 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+ +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@0x8ca69fd1aacb6668; + +struct ModelKey { + key @0 :Text; + descriptor @1 :Text; +} + +struct TensorKey { + key @0 :Text; + descriptor @1 :Text; +} diff --git a/smartsim/_core/mli/mli_schemas/data/data_references_capnp.py b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.py new file mode 100644 index 0000000000..099d10c438 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.py @@ -0,0 +1,41 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
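+#
+# Note: the loader below relies on pycapnp's dynamic schema loading;
+# capnp.load() parses data_references.capnp at import time and exposes its
+# ModelKey and TensorKey structs as Python classes.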
+ +"""This is an automatically generated stub for `data_references.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "data_references.capnp")) +ModelKey = capnp.load(module_file).ModelKey +ModelKeyBuilder = ModelKey +ModelKeyReader = ModelKey +TensorKey = capnp.load(module_file).TensorKey +TensorKeyBuilder = TensorKey +TensorKeyReader = TensorKey diff --git a/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi new file mode 100644 index 0000000000..a5e318a556 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi @@ -0,0 +1,107 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `data_references.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator + +class ModelKey: + key: str + descriptor: str + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[ModelKeyReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> ModelKeyReader: ... + @staticmethod + def new_message() -> ModelKeyBuilder: ... + def to_dict(self) -> dict: ... + +class ModelKeyReader(ModelKey): + def as_builder(self) -> ModelKeyBuilder: ... + +class ModelKeyBuilder(ModelKey): + @staticmethod + def from_dict(dictionary: dict) -> ModelKeyBuilder: ... + def copy(self) -> ModelKeyBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> ModelKeyReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... 
+ +class TensorKey: + key: str + descriptor: str + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TensorKeyReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TensorKeyReader: ... + @staticmethod + def new_message() -> TensorKeyBuilder: ... + def to_dict(self) -> dict: ... + +class TensorKeyReader(TensorKey): + def as_builder(self) -> TensorKeyBuilder: ... + +class TensorKeyBuilder(TensorKey): + @staticmethod + def from_dict(dictionary: dict) -> TensorKeyBuilder: ... + def copy(self) -> TensorKeyBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TensorKeyReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/mli/mli_schemas/model/__init__.py b/smartsim/_core/mli/mli_schemas/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/mli_schemas/model/model.capnp b/smartsim/_core/mli/mli_schemas/model/model.capnp new file mode 100644 index 0000000000..fc9ed73663 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/model/model.capnp @@ -0,0 +1,33 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@0xaefb9301e14ba4bd; + +struct Model { + data @0 :Data; + name @1 :Text; + version @2 :Text; +} diff --git a/smartsim/_core/mli/mli_schemas/model/model_capnp.py b/smartsim/_core/mli/mli_schemas/model/model_capnp.py new file mode 100644 index 0000000000..be2c276c23 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/model/model_capnp.py @@ -0,0 +1,38 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `model.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "model.capnp")) +Model = capnp.load(module_file).Model +ModelBuilder = Model +ModelReader = Model diff --git a/smartsim/_core/mli/mli_schemas/model/model_capnp.pyi b/smartsim/_core/mli/mli_schemas/model/model_capnp.pyi new file mode 100644 index 0000000000..6ca53a3579 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/model/model_capnp.pyi @@ -0,0 +1,72 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
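+#
+# Note: ModelBuilder and ModelReader below alias the same dynamic class at
+# runtime (see model_capnp.py); the stub declares them separately only so
+# static type checkers can distinguish builders from readers.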
+ +"""This is an automatically generated stub for `model.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator + +class Model: + data: bytes + name: str + version: str + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[ModelReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> ModelReader: ... + @staticmethod + def new_message() -> ModelBuilder: ... + def to_dict(self) -> dict: ... + +class ModelReader(Model): + def as_builder(self) -> ModelBuilder: ... + +class ModelBuilder(Model): + @staticmethod + def from_dict(dictionary: dict) -> ModelBuilder: ... + def copy(self) -> ModelBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> ModelReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/mli/mli_schemas/request/request.capnp b/smartsim/_core/mli/mli_schemas/request/request.capnp new file mode 100644 index 0000000000..26d9542d9f --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request.capnp @@ -0,0 +1,55 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +@0xa27f0152c7bb299e; + +using Tensors = import "../tensor/tensor.capnp"; +using RequestAttributes = import "request_attributes/request_attributes.capnp"; +using DataRef = import "../data/data_references.capnp"; +using Models = import "../model/model.capnp"; + +struct ChannelDescriptor { + descriptor @0 :Text; +} + +struct Request { + replyChannel @0 :ChannelDescriptor; + model :union { + key @1 :DataRef.ModelKey; + data @2 :Models.Model; + } + input :union { + keys @3 :List(DataRef.TensorKey); + descriptors @4 :List(Tensors.TensorDescriptor); + } + output @5 :List(DataRef.TensorKey); + outputDescriptors @6 :List(Tensors.OutputDescriptor); + customAttributes :union { + torch @7 :RequestAttributes.TorchRequestAttributes; + tf @8 :RequestAttributes.TensorFlowRequestAttributes; + none @9 :Void; + } +} diff --git a/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp new file mode 100644 index 0000000000..f0a319f0a3 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp @@ -0,0 +1,49 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@0xdd14d8ba5c06743f; + +enum TorchTensorType { + nested @0; # ragged + sparse @1; + tensor @2; # "normal" tensor +} + +enum TFTensorType { + ragged @0; + sparse @1; + variable @2; + constant @3; +} + +struct TorchRequestAttributes { + tensorType @0 :TorchTensorType; +} + +struct TensorFlowRequestAttributes { + name @0 :Text; + tensorType @1 :TFTensorType; +} diff --git a/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py new file mode 100644 index 0000000000..8969f38457 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py @@ -0,0 +1,41 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `request_attributes.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "request_attributes.capnp")) +TorchRequestAttributes = capnp.load(module_file).TorchRequestAttributes +TorchRequestAttributesBuilder = TorchRequestAttributes +TorchRequestAttributesReader = TorchRequestAttributes +TensorFlowRequestAttributes = capnp.load(module_file).TensorFlowRequestAttributes +TensorFlowRequestAttributesBuilder = TensorFlowRequestAttributes +TensorFlowRequestAttributesReader = TensorFlowRequestAttributes diff --git a/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi new file mode 100644 index 0000000000..c474de4b4f --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi @@ -0,0 +1,109 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `request_attributes.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator, Literal + +TorchTensorType = Literal["nested", "sparse", "tensor"] +TFTensorType = Literal["ragged", "sparse", "variable", "constant"] + +class TorchRequestAttributes: + tensorType: TorchTensorType + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TorchRequestAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TorchRequestAttributesReader: ... + @staticmethod + def new_message() -> TorchRequestAttributesBuilder: ... + def to_dict(self) -> dict: ... + +class TorchRequestAttributesReader(TorchRequestAttributes): + def as_builder(self) -> TorchRequestAttributesBuilder: ... + +class TorchRequestAttributesBuilder(TorchRequestAttributes): + @staticmethod + def from_dict(dictionary: dict) -> TorchRequestAttributesBuilder: ... + def copy(self) -> TorchRequestAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TorchRequestAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + +class TensorFlowRequestAttributes: + name: str + tensorType: TFTensorType + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TensorFlowRequestAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TensorFlowRequestAttributesReader: ... + @staticmethod + def new_message() -> TensorFlowRequestAttributesBuilder: ... + def to_dict(self) -> dict: ... + +class TensorFlowRequestAttributesReader(TensorFlowRequestAttributes): + def as_builder(self) -> TensorFlowRequestAttributesBuilder: ... + +class TensorFlowRequestAttributesBuilder(TensorFlowRequestAttributes): + @staticmethod + def from_dict(dictionary: dict) -> TensorFlowRequestAttributesBuilder: ... + def copy(self) -> TensorFlowRequestAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TensorFlowRequestAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... 
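+
+# Illustrative sketch (not part of the stub): in practice these attribute
+# messages are created through MessageHandler rather than built by hand.
+# The tensor name below is a hypothetical placeholder:
+#
+#     from smartsim._core.mli.message_handler import MessageHandler
+#
+#     torch_attrs = MessageHandler.build_torch_request_attributes("sparse")
+#     tf_attrs = MessageHandler.build_tf_request_attributes(
+#         name="logits", tensor_type="ragged"
+#     )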
diff --git a/smartsim/_core/mli/mli_schemas/request/request_capnp.py b/smartsim/_core/mli/mli_schemas/request/request_capnp.py new file mode 100644 index 0000000000..90b8ce194e --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request_capnp.py @@ -0,0 +1,41 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `request.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "request.capnp")) +ChannelDescriptor = capnp.load(module_file).ChannelDescriptor +ChannelDescriptorBuilder = ChannelDescriptor +ChannelDescriptorReader = ChannelDescriptor +Request = capnp.load(module_file).Request +RequestBuilder = Request +RequestReader = Request diff --git a/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi b/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi new file mode 100644 index 0000000000..2aab80b1d0 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi @@ -0,0 +1,319 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `request.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator, Literal, Sequence, overload + +from ..data.data_references_capnp import ( + ModelKey, + ModelKeyBuilder, + ModelKeyReader, + TensorKey, + TensorKeyBuilder, + TensorKeyReader, +) +from ..model.model_capnp import Model, ModelBuilder, ModelReader +from ..tensor.tensor_capnp import ( + OutputDescriptor, + OutputDescriptorBuilder, + OutputDescriptorReader, + TensorDescriptor, + TensorDescriptorBuilder, + TensorDescriptorReader, +) +from .request_attributes.request_attributes_capnp import ( + TensorFlowRequestAttributes, + TensorFlowRequestAttributesBuilder, + TensorFlowRequestAttributesReader, + TorchRequestAttributes, + TorchRequestAttributesBuilder, + TorchRequestAttributesReader, +) + +class ChannelDescriptor: + descriptor: str + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[ChannelDescriptorReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> ChannelDescriptorReader: ... + @staticmethod + def new_message() -> ChannelDescriptorBuilder: ... + def to_dict(self) -> dict: ... + +class ChannelDescriptorReader(ChannelDescriptor): + def as_builder(self) -> ChannelDescriptorBuilder: ... + +class ChannelDescriptorBuilder(ChannelDescriptor): + @staticmethod + def from_dict(dictionary: dict) -> ChannelDescriptorBuilder: ... + def copy(self) -> ChannelDescriptorBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> ChannelDescriptorReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + +class Request: + class Model: + key: ModelKey | ModelKeyBuilder | ModelKeyReader + data: Model | ModelBuilder | ModelReader + def which(self) -> Literal["key", "data"]: ... + @overload + def init(self, name: Literal["key"]) -> ModelKey: ... + @overload + def init(self, name: Literal["data"]) -> Model: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[Request.ModelReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Request.ModelReader: ... + @staticmethod + def new_message() -> Request.ModelBuilder: ... + def to_dict(self) -> dict: ... + + class ModelReader(Request.Model): + key: ModelKeyReader + data: ModelReader + def as_builder(self) -> Request.ModelBuilder: ... 
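+
+    # Illustrative note: consumers branch on the union tag before reading;
+    # for example, deserialize_message() in worker.py does
+    #     if request.model.which() == "key": ...
+    #     elif request.model.which() == "data": ...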
+ + class ModelBuilder(Request.Model): + key: ModelKey | ModelKeyBuilder | ModelKeyReader + data: Model | ModelBuilder | ModelReader + @staticmethod + def from_dict(dictionary: dict) -> Request.ModelBuilder: ... + def copy(self) -> Request.ModelBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> Request.ModelReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + + class Input: + keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] + def which(self) -> Literal["keys", "descriptors"]: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[Request.InputReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Request.InputReader: ... + @staticmethod + def new_message() -> Request.InputBuilder: ... + def to_dict(self) -> dict: ... + + class InputReader(Request.Input): + keys: Sequence[TensorKeyReader] + descriptors: Sequence[TensorDescriptorReader] + def as_builder(self) -> Request.InputBuilder: ... + + class InputBuilder(Request.Input): + keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] + @staticmethod + def from_dict(dictionary: dict) -> Request.InputBuilder: ... + def copy(self) -> Request.InputBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> Request.InputReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + + class CustomAttributes: + torch: ( + TorchRequestAttributes + | TorchRequestAttributesBuilder + | TorchRequestAttributesReader + ) + tf: ( + TensorFlowRequestAttributes + | TensorFlowRequestAttributesBuilder + | TensorFlowRequestAttributesReader + ) + none: None + def which(self) -> Literal["torch", "tf", "none"]: ... + @overload + def init(self, name: Literal["torch"]) -> TorchRequestAttributes: ... + @overload + def init(self, name: Literal["tf"]) -> TensorFlowRequestAttributes: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[Request.CustomAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Request.CustomAttributesReader: ... + @staticmethod + def new_message() -> Request.CustomAttributesBuilder: ... + def to_dict(self) -> dict: ... + + class CustomAttributesReader(Request.CustomAttributes): + torch: TorchRequestAttributesReader + tf: TensorFlowRequestAttributesReader + def as_builder(self) -> Request.CustomAttributesBuilder: ... 
+ + class CustomAttributesBuilder(Request.CustomAttributes): + torch: ( + TorchRequestAttributes + | TorchRequestAttributesBuilder + | TorchRequestAttributesReader + ) + tf: ( + TensorFlowRequestAttributes + | TensorFlowRequestAttributesBuilder + | TensorFlowRequestAttributesReader + ) + @staticmethod + def from_dict(dictionary: dict) -> Request.CustomAttributesBuilder: ... + def copy(self) -> Request.CustomAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> Request.CustomAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + replyChannel: ChannelDescriptor | ChannelDescriptorBuilder | ChannelDescriptorReader + model: Request.Model | Request.ModelBuilder | Request.ModelReader + input: Request.Input | Request.InputBuilder | Request.InputReader + output: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + outputDescriptors: Sequence[ + OutputDescriptor | OutputDescriptorBuilder | OutputDescriptorReader + ] + customAttributes: ( + Request.CustomAttributes + | Request.CustomAttributesBuilder + | Request.CustomAttributesReader + ) + @overload + def init(self, name: Literal["replyChannel"]) -> ChannelDescriptor: ... + @overload + def init(self, name: Literal["model"]) -> Model: ... + @overload + def init(self, name: Literal["input"]) -> Input: ... + @overload + def init(self, name: Literal["customAttributes"]) -> CustomAttributes: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[RequestReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> RequestReader: ... + @staticmethod + def new_message() -> RequestBuilder: ... + def to_dict(self) -> dict: ... + +class RequestReader(Request): + replyChannel: ChannelDescriptorReader + model: Request.ModelReader + input: Request.InputReader + output: Sequence[TensorKeyReader] + outputDescriptors: Sequence[OutputDescriptorReader] + customAttributes: Request.CustomAttributesReader + def as_builder(self) -> RequestBuilder: ... + +class RequestBuilder(Request): + replyChannel: ChannelDescriptor | ChannelDescriptorBuilder | ChannelDescriptorReader + model: Request.Model | Request.ModelBuilder | Request.ModelReader + input: Request.Input | Request.InputBuilder | Request.InputReader + output: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + outputDescriptors: Sequence[ + OutputDescriptor | OutputDescriptorBuilder | OutputDescriptorReader + ] + customAttributes: ( + Request.CustomAttributes + | Request.CustomAttributesBuilder + | Request.CustomAttributesReader + ) + @staticmethod + def from_dict(dictionary: dict) -> RequestBuilder: ... + def copy(self) -> RequestBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> RequestReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... 
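Taken together with the `request_capnp.py` loader, these stubs type pycapnp's dynamic message API. A minimal, illustrative round trip follows; it assumes pycapnp and the bundled `.capnp` schema files are importable, and the descriptor value is a placeholder, not a real channel:

    from smartsim._core.mli.mli_schemas.request import request_capnp

    req = request_capnp.Request.new_message()  # typed as RequestBuilder
    req.replyChannel.descriptor = "reply-channel-descriptor"
    payload = req.to_bytes()

    # from_bytes is a context manager that yields a RequestReader
    with request_capnp.Request.from_bytes(payload) as reader:
        assert reader.replyChannel.descriptor == "reply-channel-descriptor"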
diff --git a/smartsim/_core/mli/mli_schemas/response/response.capnp b/smartsim/_core/mli/mli_schemas/response/response.capnp new file mode 100644 index 0000000000..7194524cd0 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response.capnp @@ -0,0 +1,52 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@0xa05dcb4444780705; + +using Tensors = import "../tensor/tensor.capnp"; +using ResponseAttributes = import "response_attributes/response_attributes.capnp"; +using DataRef = import "../data/data_references.capnp"; + +enum Status { + complete @0; + fail @1; + timeout @2; + running @3; +} + +struct Response { + status @0 :Status; + message @1 :Text; + result :union { + keys @2 :List(DataRef.TensorKey); + descriptors @3 :List(Tensors.TensorDescriptor); + } + customAttributes :union { + torch @4 :ResponseAttributes.TorchResponseAttributes; + tf @5 :ResponseAttributes.TensorFlowResponseAttributes; + none @6 :Void; + } +} diff --git a/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp new file mode 100644 index 0000000000..b4dcf18e88 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp @@ -0,0 +1,33 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@0xee59c60fccbb1bf9; + +struct TorchResponseAttributes { +} + +struct TensorFlowResponseAttributes { +} diff --git a/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py new file mode 100644 index 0000000000..4839334d52 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py @@ -0,0 +1,41 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
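At runtime, the loader below makes the `*Builder`/`*Reader` names plain aliases of the single class pycapnp generates dynamically; only the `.pyi` stubs distinguish builder and reader views for type checkers. An illustrative consequence:

    from smartsim._core.mli.mli_schemas.response.response_attributes import (
        response_attributes_capnp as ra,
    )

    # identical objects at runtime; distinct types only for mypy
    assert ra.TorchResponseAttributesBuilder is ra.TorchResponseAttributes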
+ +"""This is an automatically generated stub for `response_attributes.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "response_attributes.capnp")) +TorchResponseAttributes = capnp.load(module_file).TorchResponseAttributes +TorchResponseAttributesBuilder = TorchResponseAttributes +TorchResponseAttributesReader = TorchResponseAttributes +TensorFlowResponseAttributes = capnp.load(module_file).TensorFlowResponseAttributes +TensorFlowResponseAttributesBuilder = TensorFlowResponseAttributes +TensorFlowResponseAttributesReader = TensorFlowResponseAttributes diff --git a/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi new file mode 100644 index 0000000000..f40688d74a --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi @@ -0,0 +1,103 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `response_attributes.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator + +class TorchResponseAttributes: + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TorchResponseAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TorchResponseAttributesReader: ... + @staticmethod + def new_message() -> TorchResponseAttributesBuilder: ... + def to_dict(self) -> dict: ... + +class TorchResponseAttributesReader(TorchResponseAttributes): + def as_builder(self) -> TorchResponseAttributesBuilder: ... + +class TorchResponseAttributesBuilder(TorchResponseAttributes): + @staticmethod + def from_dict(dictionary: dict) -> TorchResponseAttributesBuilder: ... 
+ def copy(self) -> TorchResponseAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TorchResponseAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + +class TensorFlowResponseAttributes: + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TensorFlowResponseAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TensorFlowResponseAttributesReader: ... + @staticmethod + def new_message() -> TensorFlowResponseAttributesBuilder: ... + def to_dict(self) -> dict: ... + +class TensorFlowResponseAttributesReader(TensorFlowResponseAttributes): + def as_builder(self) -> TensorFlowResponseAttributesBuilder: ... + +class TensorFlowResponseAttributesBuilder(TensorFlowResponseAttributes): + @staticmethod + def from_dict(dictionary: dict) -> TensorFlowResponseAttributesBuilder: ... + def copy(self) -> TensorFlowResponseAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TensorFlowResponseAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/mli/mli_schemas/response/response_capnp.py b/smartsim/_core/mli/mli_schemas/response/response_capnp.py new file mode 100644 index 0000000000..eaa3451045 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response_capnp.py @@ -0,0 +1,38 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +"""This is an automatically generated stub for `response.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "response.capnp")) +Response = capnp.load(module_file).Response +ResponseBuilder = Response +ResponseReader = Response diff --git a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi new file mode 100644 index 0000000000..6b4c50fd05 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi @@ -0,0 +1,212 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `response.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator, Literal, Sequence, overload + +from ..data.data_references_capnp import TensorKey, TensorKeyBuilder, TensorKeyReader +from ..tensor.tensor_capnp import ( + TensorDescriptor, + TensorDescriptorBuilder, + TensorDescriptorReader, +) +from .response_attributes.response_attributes_capnp import ( + TensorFlowResponseAttributes, + TensorFlowResponseAttributesBuilder, + TensorFlowResponseAttributesReader, + TorchResponseAttributes, + TorchResponseAttributesBuilder, + TorchResponseAttributesReader, +) + +Status = Literal["complete", "fail", "timeout", "running"] + +class Response: + class Result: + keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] + def which(self) -> Literal["keys", "descriptors"]: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[Response.ResultReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Response.ResultReader: ... + @staticmethod + def new_message() -> Response.ResultBuilder: ... 
+ def to_dict(self) -> dict: ... + + class ResultReader(Response.Result): + keys: Sequence[TensorKeyReader] + descriptors: Sequence[TensorDescriptorReader] + def as_builder(self) -> Response.ResultBuilder: ... + + class ResultBuilder(Response.Result): + keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] + @staticmethod + def from_dict(dictionary: dict) -> Response.ResultBuilder: ... + def copy(self) -> Response.ResultBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> Response.ResultReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + + class CustomAttributes: + torch: ( + TorchResponseAttributes + | TorchResponseAttributesBuilder + | TorchResponseAttributesReader + ) + tf: ( + TensorFlowResponseAttributes + | TensorFlowResponseAttributesBuilder + | TensorFlowResponseAttributesReader + ) + none: None + def which(self) -> Literal["torch", "tf", "none"]: ... + @overload + def init(self, name: Literal["torch"]) -> TorchResponseAttributes: ... + @overload + def init(self, name: Literal["tf"]) -> TensorFlowResponseAttributes: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[Response.CustomAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Response.CustomAttributesReader: ... + @staticmethod + def new_message() -> Response.CustomAttributesBuilder: ... + def to_dict(self) -> dict: ... + + class CustomAttributesReader(Response.CustomAttributes): + torch: TorchResponseAttributesReader + tf: TensorFlowResponseAttributesReader + def as_builder(self) -> Response.CustomAttributesBuilder: ... + + class CustomAttributesBuilder(Response.CustomAttributes): + torch: ( + TorchResponseAttributes + | TorchResponseAttributesBuilder + | TorchResponseAttributesReader + ) + tf: ( + TensorFlowResponseAttributes + | TensorFlowResponseAttributesBuilder + | TensorFlowResponseAttributesReader + ) + @staticmethod + def from_dict(dictionary: dict) -> Response.CustomAttributesBuilder: ... + def copy(self) -> Response.CustomAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> Response.CustomAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + status: Status + message: str + result: Response.Result | Response.ResultBuilder | Response.ResultReader + customAttributes: ( + Response.CustomAttributes + | Response.CustomAttributesBuilder + | Response.CustomAttributesReader + ) + @overload + def init(self, name: Literal["result"]) -> Result: ... + @overload + def init(self, name: Literal["customAttributes"]) -> CustomAttributes: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[ResponseReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> ResponseReader: ... 
+ @staticmethod + def new_message() -> ResponseBuilder: ... + def to_dict(self) -> dict: ... + +class ResponseReader(Response): + result: Response.ResultReader + customAttributes: Response.CustomAttributesReader + def as_builder(self) -> ResponseBuilder: ... + +class ResponseBuilder(Response): + result: Response.Result | Response.ResultBuilder | Response.ResultReader + customAttributes: ( + Response.CustomAttributes + | Response.CustomAttributesBuilder + | Response.CustomAttributesReader + ) + @staticmethod + def from_dict(dictionary: dict) -> ResponseBuilder: ... + def copy(self) -> ResponseBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> ResponseReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp b/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp new file mode 100644 index 0000000000..4b2218b166 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp @@ -0,0 +1,75 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
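As with the request stubs, here is a hedged sketch of exercising the `Response` message above through its generated module; the status and message values are placeholders, and `to_dict()` (declared in the stubs) is used to read the enum back as a string:

    from smartsim._core.mli.mli_schemas.response import response_capnp

    resp = response_capnp.Response.new_message()  # typed as ResponseBuilder
    resp.status = "complete"                      # one of the Status literals
    resp.message = "inference finished"
    resp.customAttributes.none = None             # select the empty union branch
    data = resp.to_bytes()

    with response_capnp.Response.from_bytes(data) as reader:
        assert reader.to_dict()["status"] == "complete"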
+ +@0x9a0aeb2e04838fb1; + +using DataRef = import "../data/data_references.capnp"; + +enum Order { + c @0; # row major (contiguous layout) + f @1; # column major (fortran contiguous layout) +} + +enum NumericalType { + int8 @0; + int16 @1; + int32 @2; + int64 @3; + uInt8 @4; + uInt16 @5; + uInt32 @6; + uInt64 @7; + float32 @8; + float64 @9; +} + +enum ReturnNumericalType { + int8 @0; + int16 @1; + int32 @2; + int64 @3; + uInt8 @4; + uInt16 @5; + uInt32 @6; + uInt64 @7; + float32 @8; + float64 @9; + none @10; + auto @11; +} + +struct TensorDescriptor { + dimensions @0 :List(Int32); + order @1 :Order; + dataType @2 :NumericalType; +} + +struct OutputDescriptor { + order @0 :Order; + optionalKeys @1 :List(DataRef.TensorKey); + optionalDimension @2 :List(Int32); + optionalDatatype @3 :ReturnNumericalType; +} diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py new file mode 100644 index 0000000000..8c9d6c9029 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py @@ -0,0 +1,41 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `tensor.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "tensor.capnp")) +TensorDescriptor = capnp.load(module_file).TensorDescriptor +TensorDescriptorBuilder = TensorDescriptor +TensorDescriptorReader = TensorDescriptor +OutputDescriptor = capnp.load(module_file).OutputDescriptor +OutputDescriptorBuilder = OutputDescriptor +OutputDescriptorReader = OutputDescriptor diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi new file mode 100644 index 0000000000..b55f26b452 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi @@ -0,0 +1,142 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+ +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `tensor.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator, Literal, Sequence + +from ..data.data_references_capnp import TensorKey, TensorKeyBuilder, TensorKeyReader + +Order = Literal["c", "f"] +NumericalType = Literal[ + "int8", + "int16", + "int32", + "int64", + "uInt8", + "uInt16", + "uInt32", + "uInt64", + "float32", + "float64", +] +ReturnNumericalType = Literal[ + "int8", + "int16", + "int32", + "int64", + "uInt8", + "uInt16", + "uInt32", + "uInt64", + "float32", + "float64", + "none", + "auto", +] + +class TensorDescriptor: + dimensions: Sequence[int] + order: Order + dataType: NumericalType + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TensorDescriptorReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TensorDescriptorReader: ... + @staticmethod + def new_message() -> TensorDescriptorBuilder: ... + def to_dict(self) -> dict: ... + +class TensorDescriptorReader(TensorDescriptor): + def as_builder(self) -> TensorDescriptorBuilder: ... + +class TensorDescriptorBuilder(TensorDescriptor): + @staticmethod + def from_dict(dictionary: dict) -> TensorDescriptorBuilder: ... + def copy(self) -> TensorDescriptorBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TensorDescriptorReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + +class OutputDescriptor: + order: Order + optionalKeys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + optionalDimension: Sequence[int] + optionalDatatype: ReturnNumericalType + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[OutputDescriptorReader]: ... 
+ @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> OutputDescriptorReader: ... + @staticmethod + def new_message() -> OutputDescriptorBuilder: ... + def to_dict(self) -> dict: ... + +class OutputDescriptorReader(OutputDescriptor): + optionalKeys: Sequence[TensorKeyReader] + def as_builder(self) -> OutputDescriptorBuilder: ... + +class OutputDescriptorBuilder(OutputDescriptor): + optionalKeys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + @staticmethod + def from_dict(dictionary: dict) -> OutputDescriptorBuilder: ... + def copy(self) -> OutputDescriptorBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> OutputDescriptorReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/schemas/utils.py b/smartsim/_core/schemas/utils.py index 9cb36bcf57..905fe8955c 100644 --- a/smartsim/_core/schemas/utils.py +++ b/smartsim/_core/schemas/utils.py @@ -48,7 +48,7 @@ class _Message(t.Generic[_SchemaT]): delimiter: str = pydantic.Field(min_length=1, default=_DEFAULT_MSG_DELIM) def __str__(self) -> str: - return self.delimiter.join((self.header, self.payload.json())) + return self.delimiter.join((self.header, self.payload.model_dump_json())) @classmethod def from_str( @@ -58,7 +58,7 @@ def from_str( delimiter: str = _DEFAULT_MSG_DELIM, ) -> "_Message[_SchemaT]": header, payload = str_.split(delimiter, 1) - return cls(payload_type.parse_raw(payload), header, delimiter) + return cls(payload_type.model_validate_json(payload), header, delimiter) class SchemaRegistry(t.Generic[_SchemaT]): diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index b17be763b4..bf5838928e 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -220,7 +220,7 @@ def _installed(base_path: Path, backend: str) -> bool: """ backend_key = f"redisai_{backend}" backend_path = base_path / backend_key / f"{backend_key}.so" - backend_so = Path(os.environ.get("RAI_PATH", backend_path)).resolve() + backend_so = Path(os.environ.get("SMARTSIM_RAI_LIB", backend_path)).resolve() return backend_so.is_file() diff --git a/smartsim/_core/utils/timings.py b/smartsim/_core/utils/timings.py new file mode 100644 index 0000000000..f99950739e --- /dev/null +++ b/smartsim/_core/utils/timings.py @@ -0,0 +1,175 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import time +import typing as t +from collections import OrderedDict + +import numpy as np + +from ...log import get_logger + +logger = get_logger("PerfTimer") + + +class PerfTimer: + def __init__( + self, + filename: str = "timings", + prefix: str = "", + timing_on: bool = True, + debug: bool = False, + ): + self._start: t.Optional[float] = None + self._interm: t.Optional[float] = None + self._timings: OrderedDict[str, list[t.Union[float, int, str]]] = OrderedDict() + self._timing_on = timing_on + self._filename = filename + self._prefix = prefix + self._debug = debug + + def _add_label_to_timings(self, label: str) -> None: + if label not in self._timings: + self._timings[label] = [] + + @staticmethod + def _format_number(number: t.Union[float, int]) -> str: + """Formats the input value with a fixed precision appropriate for logging""" + return f"{number:0.4e}" + + def start_timings( + self, + first_label: t.Optional[str] = None, + first_value: t.Optional[t.Union[float, int]] = None, + ) -> None: + """Start a recording session, optionally recording an initial labeled value + + :param first_label: a label for an event that will be manually prepended + to the timing information before starting timers + :param first_value: a value for an event that will be manually prepended + to the timing information before starting timers""" + if self._timing_on: + if first_label is not None and first_value is not None: + mod_label = self._make_label(first_label) + value = self._format_number(first_value) + self._log(f"Started timing: {first_label}: {value}") + self._add_label_to_timings(mod_label) + self._timings[mod_label].append(value) + self._start = time.perf_counter() + self._interm = time.perf_counter() + + def end_timings(self) -> None: + """Record a timing event and clear the last checkpoint""" + if self._timing_on and self._start is not None: + mod_label = self._make_label("total_time") + self._add_label_to_timings(mod_label) + delta = self._format_number(time.perf_counter() - self._start) + self._timings[mod_label].append(delta) + self._log(f"Finished timing: {mod_label}: {delta}") + self._interm = None + + def _make_label(self, label: str) -> str: + """Return a label formatted with the current label prefix + + :param label: the original label + :returns: the adjusted label value""" + return self._prefix + label + + def _get_delta(self) -> float: + """Calculates the offset from the last intermediate checkpoint time + + :returns: the number of seconds elapsed""" + if self._interm is None: + return 0 + return time.perf_counter() - self._interm + + def get_last(self, label: str) -> str: + """Return the last timing value collected for the given label in + the format `{label}: {value}`.
If no timing value has been collected + with the label, returns `Not measured yet`""" + mod_label = self._make_label(label) + if mod_label in self._timings: + value = self._timings[mod_label][-1] + if value: + return f"{label}: {value}" + + return "Not measured yet" + + def measure_time(self, label: str) -> None: + """Record a new time event if timing is enabled + + :param label: the label to record a timing event for""" + if self._timing_on and self._interm is not None: + mod_label = self._make_label(label) + self._add_label_to_timings(mod_label) + delta = self._format_number(self._get_delta()) + self._timings[mod_label].append(delta) + self._log(f"{mod_label}: {delta}") + self._interm = time.perf_counter() + + def _log(self, msg: str) -> None: + """Conditionally logs a message when the debug flag is enabled + + :param msg: the message to be logged""" + if self._debug: + logger.info(msg) + + @property + def max_length(self) -> int: + """Returns the number of records contained in the largest timing set""" + if len(self._timings) == 0: + return 0 + return max(len(value) for value in self._timings.values()) + + def print_timings(self, to_file: bool = False) -> None: + """Print timing information to standard output. If `to_file` + is `True`, also write results to a file. + + :param to_file: If `True`, also saves timing information + to the file `{prefix}{filename}.npy` + """ + print(" ".join(self._timings.keys())) + try: + value_array = np.array(list(self._timings.values()), dtype=float) + except Exception as e: + logger.exception(e) + return + value_array = np.transpose(value_array) + if self._debug: + for i in range(value_array.shape[0]): + print(" ".join(self._format_number(value) for value in value_array[i])) + if to_file: + np.save(self._prefix + self._filename + ".npy", value_array) + + @property + def is_active(self) -> bool: + """Return `True` if timer is recording, `False` otherwise""" + return self._timing_on + + @is_active.setter + def is_active(self, active: bool) -> None: + """Set to `True` to record timing information, `False` otherwise""" + self._timing_on = active diff --git a/smartsim/database/orchestrator.py b/smartsim/database/orchestrator.py index e2549891af..e5e99c8932 100644 --- a/smartsim/database/orchestrator.py +++ b/smartsim/database/orchestrator.py @@ -265,8 +265,8 @@ def __init__( raise SSConfigError( "SmartSim not installed with pre-built extensions (Redis)\n" "Use the `smart` cli tool to install needed extensions\n" - "or set REDIS_PATH and REDIS_CLI_PATH in your environment\n" - "See documentation for more information" + "or set SMARTSIM_REDIS_SERVER_EXE and SMARTSIM_REDIS_CLI_EXE " + "in your environment\nSee documentation for more information" ) from e if self.launcher != "local": diff --git a/smartsim/experiment.py b/smartsim/experiment.py index 607a90ae16..9a14eecdc8 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -178,7 +178,7 @@ def __init__( def _set_dragon_server_path(self) -> None: """Set path for dragon server through environment variables""" if not "SMARTSIM_DRAGON_SERVER_PATH" in environ: - environ["SMARTSIM_DRAGON_SERVER_PATH_EXP"] = osp.join( + environ["_SMARTSIM_DRAGON_SERVER_PATH_EXP"] = osp.join( self.exp_path, CONFIG.dragon_default_subdir ) diff --git a/smartsim/log.py b/smartsim/log.py index 3d6c0860ee..c8fed9329f 100644 --- a/smartsim/log.py +++ b/smartsim/log.py @@ -252,16 +252,21 @@ def filter(self, record: logging.LogRecord) -> bool: return record.levelno <= level_no -def log_to_file(filename: str, log_level:
str = "debug") -> None: +def log_to_file( + filename: str, log_level: str = "debug", logger: t.Optional[logging.Logger] = None +) -> None: """Installs a second filestream handler to the root logger, allowing subsequent logging calls to be sent to filename. - :param filename: the name of the desired log file. - :param log_level: as defined in get_logger. Can be specified + :param filename: The name of the desired log file. + :param log_level: As defined in get_logger. Can be specified to allow the file to store more or less verbose logging information. + :param logger: If supplied, a logger to add the file stream logging + behavior to. By default, a new logger is instantiated. """ - logger = logging.getLogger("SmartSim") + if logger is None: + logger = logging.getLogger("SmartSim") stream = open( # pylint: disable=consider-using-with filename, "w+", encoding="utf-8" ) diff --git a/smartsim/settings/dragonRunSettings.py b/smartsim/settings/dragonRunSettings.py index 69a91547e7..15e5855448 100644 --- a/smartsim/settings/dragonRunSettings.py +++ b/smartsim/settings/dragonRunSettings.py @@ -95,6 +95,26 @@ def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: self.run_args["node-feature"] = ",".join(feature_list) + @override + def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + """Specify the hostlist for this job + + :param host_list: hosts to launch on + :raises ValueError: if an empty host list is supplied + """ + if not host_list: + raise ValueError("empty hostlist provided") + + if isinstance(host_list, str): + host_list = host_list.replace(" ", "").split(",") + + # strip out all whitespace-only values + cleaned_list = [host.strip() for host in host_list if host and host.strip()] + if not len(cleaned_list) == len(host_list): + raise ValueError(f"invalid names found in hostlist: {host_list}") + + self.run_args["host-list"] = ",".join(cleaned_list) + def set_cpu_affinity(self, devices: t.List[int]) -> None: """Set the CPU affinity for this job diff --git a/tests/dragon/__init__.py b/tests/dragon/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dragon/channel.py b/tests/dragon/channel.py new file mode 100644 index 0000000000..4c46359c2d --- /dev/null +++ b/tests/dragon/channel.py @@ -0,0 +1,125 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import pathlib +import threading +import typing as t + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class FileSystemCommChannel(CommChannelBase): + """Passes messages by writing to a file""" + + def __init__(self, key: pathlib.Path) -> None: + """Initialize the FileSystemCommChannel instance. + + :param key: a path to the file backing the channel + """ + self._lock = threading.RLock() + + super().__init__(key.as_posix()) + self._file_path = key + + if not self._file_path.parent.exists(): + self._file_path.parent.mkdir(parents=True) + + self._file_path.touch() + + def send(self, value: bytes, timeout: float = 0) -> None: + """Send a message through the underlying communication channel. + + :param value: The value to send + :param timeout: maximum time to wait (in seconds) for messages to send + """ + with self._lock: + # write as text so we can add newlines as delimiters + with open(self._file_path, "a") as fp: + encoded_value = base64.b64encode(value).decode("utf-8") + fp.write(f"{encoded_value}\n") + logger.debug(f"FileSystemCommChannel {self._file_path} sent message") + + def recv(self, timeout: float = 0) -> t.List[bytes]: + """Receives message(s) through the underlying communication channel. + + :param timeout: maximum time to wait (in seconds) for messages to arrive + :returns: the received message + :raises SmartSimError: if the descriptor points to a missing file + """ + with self._lock: + messages: t.List[bytes] = [] + if not self._file_path.exists(): + raise SmartSimError("Empty channel") + + # read as text so we can split on newlines + with open(self._file_path, "r") as fp: + lines = fp.readlines() + + if lines: + line = lines.pop(0) + event_bytes = base64.b64decode(line.encode("utf-8")) + messages.append(event_bytes) + + self.clear() + + # remove the first message only, write remainder back... + if len(lines) > 0: + with open(self._file_path, "w") as fp: + fp.writelines(lines) + + logger.debug( + f"FileSystemCommChannel {self._file_path} received message" + ) + + return messages + + def clear(self) -> None: + """Create an empty file for events.""" + if self._file_path.exists(): + self._file_path.unlink() + self._file_path.touch() + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "FileSystemCommChannel": + """A factory method that creates an instance from a descriptor string.
+ + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemCommChannel + """ + try: + path = pathlib.Path(descriptor) + return FileSystemCommChannel(path) + except: + logger.warning(f"failed to create fs comm channel: {descriptor}") + raise diff --git a/tests/dragon/conftest.py b/tests/dragon/conftest.py new file mode 100644 index 0000000000..d542700175 --- /dev/null +++ b/tests/dragon/conftest.py @@ -0,0 +1,129 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import os +import pathlib +import socket +import subprocess +import sys +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +# isort: off +import dragon.data.ddict.ddict as dragon_ddict +import dragon.infrastructure.policy as dragon_policy +import dragon.infrastructure.process_desc as dragon_process_desc +import dragon.native.process as dragon_process + +from dragon.fli import FLInterface + +# isort: on + +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.storage import dragon_util +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +@pytest.fixture(scope="module") +def the_storage() -> dragon_ddict.DDict: + """Fixture to instantiate a dragon distributed dictionary.""" + return dragon_util.create_ddict(1, 2, 32 * 1024**2) + + +@pytest.fixture(scope="module") +def the_worker_channel() -> DragonFLIChannel: + """Fixture to create a valid descriptor for a worker channel + that can be attached to.""" + channel_ = create_local() + fli_ = FLInterface(main_ch=channel_, manager_ch=None) + comm_channel = DragonFLIChannel(fli_) + return comm_channel + + +@pytest.fixture(scope="module") +def the_backbone( + the_storage: t.Any, the_worker_channel: DragonFLIChannel +) -> BackboneFeatureStore: + """Fixture to create a distributed dragon dictionary and wrap it + in a BackboneFeatureStore. 
+ + :param the_storage: The dragon storage engine to use + :param the_worker_channel: Pre-configured worker channel + """ + + backbone = BackboneFeatureStore(the_storage, allow_reserved_writes=True) + backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] = the_worker_channel.descriptor + + return backbone + + +@pytest.fixture(scope="module") +def backbone_descriptor(the_backbone: BackboneFeatureStore) -> str: + # create a shared backbone featurestore + return the_backbone.descriptor + + +def function_as_dragon_proc( + entrypoint_fn: t.Callable[[t.Any], None], + args: t.List[t.Any], + cpu_affinity: t.List[int], + gpu_affinity: t.List[int], +) -> dragon_process.Process: + """Execute a function as an independent dragon process. + + :param entrypoint_fn: The function to execute + :param args: The arguments for the entrypoint function + :param cpu_affinity: The cpu affinity for the process + :param gpu_affinity: The gpu affinity for the process + :returns: The dragon process handle + """ + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + return dragon_process.Process( + target=entrypoint_fn, + args=args, + cwd=os.getcwd(), + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) diff --git a/tests/dragon/feature_store.py b/tests/dragon/feature_store.py new file mode 100644 index 0000000000..d06b0b334e --- /dev/null +++ b/tests/dragon/feature_store.py @@ -0,0 +1,152 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
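For context, a sketch of how `function_as_dragon_proc` above might drive a test. It assumes a running Dragon runtime; `start()` and `join()` are the standard `dragon.native.process.Process` controls, and the entrypoint and affinity lists are placeholders:

    def _entrypoint() -> None:
        print("hello from a dragon process")

    proc = function_as_dragon_proc(_entrypoint, args=[], cpu_affinity=[0], gpu_affinity=[])
    proc.start()  # launch under the host-pinned placement policy
    proc.join()   # block until the process exits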
+ +import pathlib +import typing as t + +import smartsim.error as sse +from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class MemoryFeatureStore(FeatureStore): + """A feature store with values persisted only in local memory""" + + def __init__( + self, storage: t.Optional[t.Dict[str, t.Union[str, bytes]]] = None + ) -> None: + """Initialize the MemoryFeatureStore instance""" + super().__init__("in-memory-fs") + if storage is None: + storage = {"_": "abc"} + self._storage = storage + + def _get(self, key: str) -> t.Union[str, bytes]: + """Retrieve a value from the underlying storage mechanism + + :param key: The unique key that identifies the resource + :returns: the value identified by the key + :raises KeyError: if the key has not been used to store a value""" + return self._storage[key] + + def _set(self, key: str, value: t.Union[str, bytes]) -> None: + """Store a value into the underlying storage mechanism + + :param key: The unique key that identifies the resource + :param value: The value to store""" + self._storage[key] = value + + def _contains(self, key: str) -> bool: + """Determine if the storage mechanism contains a given key + + :param key: The unique key that identifies the resource + :returns: True if the key is defined, False otherwise""" + return key in self._storage + + +class FileSystemFeatureStore(FeatureStore): + """Alternative feature store implementation for testing. Stores all + data on the file system""" + + def __init__(self, storage_dir: t.Union[pathlib.Path, str]) -> None: + """Initialize the FileSystemFeatureStore instance + + :param storage_dir: root directory to store all data relative to""" + if isinstance(storage_dir, str): + storage_dir = pathlib.Path(storage_dir) + self._storage_dir = storage_dir + super().__init__(storage_dir.as_posix()) + + def _get(self, key: str) -> t.Union[str, bytes]: + """Retrieve a value from the underlying storage mechanism + + :param key: The unique key that identifies the resource + :returns: the value identified by the key + :raises SmartSimError: if the key is not found in the feature store""" + path = self._key_path(key) + if not path.exists(): + raise sse.SmartSimError(f"{path} not found in feature store") + return path.read_bytes() + + def _set(self, key: str, value: t.Union[str, bytes]) -> None: + """Store a value into the underlying storage mechanism + + :param key: The unique key that identifies the resource + :param value: The value to store""" + path = self._key_path(key, create=True) + if isinstance(value, str): + value = value.encode("utf-8") + path.write_bytes(value) + + def _contains(self, key: str) -> bool: + """Determine if the storage mechanism contains a given key + + :param key: The unique key that identifies the resource + :returns: True if the key is defined, False otherwise""" + path = self._key_path(key) + return path.exists() + + def _key_path(self, key: str, create: bool = False) -> pathlib.Path: + """Given a key, return a path that is optionally combined with a base + directory used by the FileSystemFeatureStore.
+
+class FileSystemFeatureStore(FeatureStore):
+    """Alternative feature store implementation for testing. Stores all
+    data on the file system"""
+
+    def __init__(self, storage_dir: t.Union[pathlib.Path, str]) -> None:
+        """Initialize the FileSystemFeatureStore instance
+
+        :param storage_dir: Root directory to store all data relative to"""
+        if isinstance(storage_dir, str):
+            storage_dir = pathlib.Path(storage_dir)
+        self._storage_dir = storage_dir
+        super().__init__(storage_dir.as_posix())
+
+    def _get(self, key: str) -> t.Union[str, bytes]:
+        """Retrieve a value from the underlying storage mechanism
+
+        :param key: The unique key that identifies the resource
+        :returns: The value identified by the key
+        :raises SmartSimError: If the key has not been used to store a value"""
+        path = self._key_path(key)
+        if not path.exists():
+            raise sse.SmartSimError(f"{path} not found in feature store")
+        return path.read_bytes()
+
+    def _set(self, key: str, value: t.Union[str, bytes]) -> None:
+        """Store a value in the underlying storage mechanism
+
+        :param key: The unique key that identifies the resource
+        :param value: The value to store"""
+        path = self._key_path(key, create=True)
+        if isinstance(value, str):
+            value = value.encode("utf-8")
+        path.write_bytes(value)
+
+    def _contains(self, key: str) -> bool:
+        """Determine if the storage mechanism contains a given key
+
+        :param key: The unique key that identifies the resource
+        :returns: True if the key is defined, False otherwise"""
+        path = self._key_path(key)
+        return path.exists()
+
+    def _key_path(self, key: str, create: bool = False) -> pathlib.Path:
+        """Given a key, return a path that is optionally combined with a base
+        directory used by the FileSystemFeatureStore.
+
+        :param key: Unique key of an item to retrieve from the feature store
+        :param create: Create the parent directory of the path if needed
+        :returns: The path to the file backing the given key"""
+        value = pathlib.Path(key)
+
+        if self._storage_dir is not None:
+            value = self._storage_dir / key
+
+        if create:
+            value.parent.mkdir(parents=True, exist_ok=True)
+
+        return value
+
+    @classmethod
+    def from_descriptor(
+        cls,
+        descriptor: str,
+    ) -> "FileSystemFeatureStore":
+        """A factory method that creates an instance from a descriptor string
+
+        :param descriptor: The descriptor that uniquely identifies the resource
+        :returns: An attached FileSystemFeatureStore
+        :raises ValueError: If the descriptor does not name a directory"""
+        try:
+            path = pathlib.Path(descriptor)
+            path.mkdir(parents=True, exist_ok=True)
+            if not path.is_dir():
+                raise ValueError("FileSystemFeatureStore requires a directory path")
+            return FileSystemFeatureStore(path)
+        except Exception:
+            logger.error(f"Error while creating FileSystemFeatureStore: {descriptor}")
+            raise
diff --git a/tests/dragon/test_core_machine_learning_worker.py b/tests/dragon/test_core_machine_learning_worker.py
new file mode 100644
index 0000000000..e9c356b4e0
--- /dev/null
+++ b/tests/dragon/test_core_machine_learning_worker.py
@@ -0,0 +1,377 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
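+
+# A minimal sketch of the fetch-model round trip the tests below exercise,
+# assuming the ModelKey/InferenceRequest/RequestBatch signatures used in this
+# module. Relies on the imports declared below (resolved lazily at call time);
+# illustrative only, never invoked by the suite.
+def _example_fetch_model_roundtrip(feature_store, model_bytes: bytes) -> bytes:
+    fsd = feature_store.descriptor
+    feature_store["example-model"] = model_bytes
+    model_key = ModelKey(key="example-model", descriptor=fsd)
+    request = InferenceRequest(model_key=model_key)
+    batch = RequestBatch([request], None, model_key)
+    result = MachineLearningWorkerCore.fetch_model(batch, {fsd: feature_store})
+    return result.model_bytes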
+
+import pathlib
+import time
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+import torch
+
+import smartsim.error as sse
+from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey, TensorKey
+from smartsim._core.mli.infrastructure.worker.worker import (
+    InferenceRequest,
+    MachineLearningWorkerCore,
+    RequestBatch,
+    TransformInputResult,
+    TransformOutputResult,
+)
+from smartsim._core.utils import installed_redisai_backends
+
+from .feature_store import FileSystemFeatureStore, MemoryFeatureStore
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+# retrieved from pytest fixtures
+is_dragon = (
+    pytest.test_launcher == "dragon" if hasattr(pytest, "test_launcher") else False
+)
+torch_available = "torch" in installed_redisai_backends()
+
+
+@pytest.fixture
+def persist_torch_model(test_dir: str) -> pathlib.Path:
+    ts_start = time.time_ns()
+    print("Starting model file creation...")
+    test_path = pathlib.Path(test_dir)
+    model_path = test_path / "basic.pt"
+
+    model = torch.nn.Linear(2, 1)
+    torch.save(model, model_path)
+    ts_end = time.time_ns()
+
+    ts_elapsed = (ts_end - ts_start) / 1_000_000_000
+    print(f"Model file creation took {ts_elapsed} seconds")
+    return model_path
+
+
+@pytest.fixture
+def persist_torch_tensor(test_dir: str) -> pathlib.Path:
+    ts_start = time.time_ns()
+    print("Starting tensor file creation...")
+    test_path = pathlib.Path(test_dir)
+    file_path = test_path / "tensor.pt"
+
+    tensor = torch.randn((100, 100, 2))
+    torch.save(tensor, file_path)
+    ts_end = time.time_ns()
+
+    ts_elapsed = (ts_end - ts_start) / 1_000_000_000
+    print(f"Tensor file creation took {ts_elapsed} seconds")
+    return file_path
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_model_disk(persist_torch_model: pathlib.Path, test_dir: str) -> None:
+    """Verify that the ML worker successfully retrieves a model
+    when given a valid (file system) key"""
+    worker = MachineLearningWorkerCore
+    key = str(persist_torch_model)
+    feature_store = FileSystemFeatureStore(test_dir)
+    fsd = feature_store.descriptor
+    feature_store[str(persist_torch_model)] = persist_torch_model.read_bytes()
+
+    model_key = ModelKey(key=key, descriptor=fsd)
+    request = InferenceRequest(model_key=model_key)
+    batch = RequestBatch([request], None, model_key)
+
+    fetch_result = worker.fetch_model(batch, {fsd: feature_store})
+    assert fetch_result.model_bytes
+    assert fetch_result.model_bytes == persist_torch_model.read_bytes()
+
+
+def test_fetch_model_disk_missing() -> None:
+    """Verify that the ML worker fails to retrieve a model
+    when given an invalid (file system) key"""
+    worker = MachineLearningWorkerCore
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+
+    key = "/path/that/doesnt/exist"
+
+    model_key = ModelKey(key=key, descriptor=fsd)
+    request = InferenceRequest(model_key=model_key)
+    batch = RequestBatch([request], None, model_key)
+
+    with pytest.raises(sse.SmartSimError) as ex:
+        worker.fetch_model(batch, {fsd: feature_store})
+
+    # ensure the error message includes key-identifying information
+    assert key in ex.value.args[0]
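+
+# Note: the tests in this module bind `worker` to the MachineLearningWorkerCore
+# class itself rather than an instance; the fetch_* hooks are staticmethods, so
+# a class-level call behaves identically. A minimal sketch of that pattern,
+# reusing the names imported above; illustrative only, never invoked.
+def _example_class_level_fetch(feature_store, key: str) -> None:
+    fsd = feature_store.descriptor
+    model_key = ModelKey(key=key, descriptor=fsd)
+    batch = RequestBatch([InferenceRequest(model_key=model_key)], None, model_key)
+    # no instance is needed; equivalent to calling through an instance
+    MachineLearningWorkerCore.fetch_model(batch, {fsd: feature_store})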
store + key = "test-model" + + # put model bytes into the feature store + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + feature_store[key] = persist_torch_model.read_bytes() + + model_key = ModelKey(key=key, descriptor=feature_store.descriptor) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_model(batch, {fsd: feature_store}) + assert fetch_result.model_bytes + assert fetch_result.model_bytes == persist_torch_model.read_bytes() + + +def test_fetch_model_feature_store_missing() -> None: + """Verify that the ML worker fails to retrieves a model + when given an invalid (feature store) key""" + worker = MachineLearningWorkerCore + + key = "some-key" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + model_key = ModelKey(key=key, descriptor=feature_store.descriptor) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) + + # todo: consider that raising this exception shows impl. replace... + with pytest.raises(sse.SmartSimError) as ex: + worker.fetch_model(batch, {fsd: feature_store}) + + # ensure the error message includes key-identifying information + assert key in ex.value.args[0] + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_model_memory(persist_torch_model: pathlib.Path) -> None: + """Verify that the ML worker successfully retrieves a model + when given a valid (file system) key""" + worker = MachineLearningWorkerCore + + key = "test-model" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + feature_store[key] = persist_torch_model.read_bytes() + + model_key = ModelKey(key=key, descriptor=feature_store.descriptor) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_model(batch, {fsd: feature_store}) + assert fetch_result.model_bytes + assert fetch_result.model_bytes == persist_torch_model.read_bytes() + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_input_disk(persist_torch_tensor: pathlib.Path) -> None: + """Verify that the ML worker successfully retrieves a tensor/input + when given a valid (file system) key""" + tensor_name = str(persist_torch_tensor) + + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)]) + + model_key = ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + worker = MachineLearningWorkerCore + + feature_store[tensor_name] = persist_torch_tensor.read_bytes() + + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + assert fetch_result[0].inputs is not None + + +def test_fetch_input_disk_missing() -> None: + """Verify that the ML worker fails to retrieves a tensor/input + when given an invalid (file system) key""" + worker = MachineLearningWorkerCore + + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + key = "/path/that/doesnt/exist" + + request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) + + model_key = ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + with pytest.raises(sse.SmartSimError) as ex: + worker.fetch_inputs(batch, {fsd: feature_store}) + + # ensure the error message includes key-identifying information + assert key[0] 
+
+def test_fetch_input_disk_missing() -> None:
+    """Verify that the ML worker fails to retrieve a tensor/input
+    when given an invalid (file system) key"""
+    worker = MachineLearningWorkerCore
+
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+    key = "/path/that/doesnt/exist"
+
+    request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)])
+
+    model_key = ModelKey(key="test-model", descriptor=fsd)
+    batch = RequestBatch([request], None, model_key)
+
+    with pytest.raises(sse.SmartSimError) as ex:
+        worker.fetch_inputs(batch, {fsd: feature_store})
+
+    # ensure the error message includes key-identifying information
+    assert key in ex.value.args[0]
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_input_feature_store(persist_torch_tensor: pathlib.Path) -> None:
+    """Verify that the ML worker successfully retrieves a tensor/input
+    when given a valid (feature store) key"""
+    worker = MachineLearningWorkerCore
+
+    tensor_name = "test-tensor"
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+
+    request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)])
+
+    # put tensor bytes into the feature store
+    feature_store[tensor_name] = persist_torch_tensor.read_bytes()
+
+    model_key = ModelKey(key="test-model", descriptor=fsd)
+    batch = RequestBatch([request], None, model_key)
+
+    fetch_result = worker.fetch_inputs(batch, {fsd: feature_store})
+    assert fetch_result[0].inputs
+    assert (
+        list(fetch_result[0].inputs)[0][:10] == persist_torch_tensor.read_bytes()[:10]
+    )
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_multi_input_feature_store(persist_torch_tensor: pathlib.Path) -> None:
+    """Verify that the ML worker successfully retrieves multiple tensors/inputs
+    when given a valid collection of (feature store) keys"""
+    worker = MachineLearningWorkerCore
+
+    tensor_name = "test-tensor"
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+
+    # put tensor bytes into the feature store
+    body1 = persist_torch_tensor.read_bytes()
+    feature_store[tensor_name + "1"] = body1
+
+    body2 = b"abcdefghijklmnopqrstuvwxyz"
+    feature_store[tensor_name + "2"] = body2
+
+    body3 = b"mnopqrstuvwxyzabcdefghijkl"
+    feature_store[tensor_name + "3"] = body3
+
+    request = InferenceRequest(
+        input_keys=[
+            TensorKey(key=tensor_name + "1", descriptor=fsd),
+            TensorKey(key=tensor_name + "2", descriptor=fsd),
+            TensorKey(key=tensor_name + "3", descriptor=fsd),
+        ]
+    )
+
+    model_key = ModelKey(key="test-model", descriptor=fsd)
+    batch = RequestBatch([request], None, model_key)
+
+    fetch_result = worker.fetch_inputs(batch, {fsd: feature_store})
+
+    raw_bytes = list(fetch_result[0].inputs)
+    assert raw_bytes
+    assert raw_bytes[0][:10] == persist_torch_tensor.read_bytes()[:10]
+    assert raw_bytes[1][:10] == body2[:10]
+    assert raw_bytes[2][:10] == body3[:10]
+
+
+def test_fetch_input_feature_store_missing() -> None:
+    """Verify that the ML worker fails to retrieve a tensor/input
+    when given an invalid (feature store) key"""
+    worker = MachineLearningWorkerCore
+
+    key = "bad-key"
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+    request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)])
+
+    model_key = ModelKey(key="test-model", descriptor=fsd)
+    batch = RequestBatch([request], None, model_key)
+
+    with pytest.raises(sse.SmartSimError) as ex:
+        worker.fetch_inputs(batch, {fsd: feature_store})
+
+    # ensure the error message includes key-identifying information
+    assert key in ex.value.args[0]
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_input_memory(persist_torch_tensor: pathlib.Path) -> None:
+    """Verify that the ML worker successfully retrieves a tensor/input
+    when given a valid (in-memory) key"""
+    worker = MachineLearningWorkerCore
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+
+    key = "test-model"
+    feature_store[key] = persist_torch_tensor.read_bytes()
+    request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)])
+
+    model_key = 
ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + assert fetch_result[0].inputs is not None + + +def test_place_outputs() -> None: + """Verify outputs are shared using the feature store""" + worker = MachineLearningWorkerCore + + key_name = "test-model" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + # create a key to retrieve from the feature store + keys = [ + TensorKey(key=key_name + "1", descriptor=fsd), + TensorKey(key=key_name + "2", descriptor=fsd), + TensorKey(key=key_name + "3", descriptor=fsd), + ] + data = [b"abcdef", b"ghijkl", b"mnopqr"] + + for fsk, v in zip(keys, data): + feature_store[fsk.key] = v + + request = InferenceRequest(output_keys=keys) + transform_result = TransformOutputResult(data, [1], "c", "float32") + + worker.place_output(request, transform_result, {fsd: feature_store}) + + for i in range(3): + assert feature_store[keys[i].key] == data[i] + + +@pytest.mark.parametrize( + "key, descriptor", + [ + pytest.param("", "desc", id="invalid key"), + pytest.param("key", "", id="invalid descriptor"), + ], +) +def test_invalid_tensorkey(key, descriptor) -> None: + with pytest.raises(ValueError): + fsk = TensorKey(key, descriptor) diff --git a/tests/dragon/test_device_manager.py b/tests/dragon/test_device_manager.py new file mode 100644 index 0000000000..d270e921cb --- /dev/null +++ b/tests/dragon/test_device_manager.py @@ -0,0 +1,186 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.control.device_manager import ( + DeviceManager, + WorkerDevice, +) +from smartsim._core.mli.infrastructure.storage.feature_store import ( + FeatureStore, + ModelKey, + TensorKey, +) +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, + TransformInputResult, + TransformOutputResult, +) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +class MockWorker(MachineLearningWorkerBase): + @staticmethod + def fetch_model( + batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] + ) -> FetchModelResult: + if batch.has_raw_model: + return FetchModelResult(batch.raw_model) + return FetchModelResult(b"fetched_model") + + @staticmethod + def load_model( + batch: RequestBatch, fetch_result: FetchModelResult, device: str + ) -> LoadModelResult: + return LoadModelResult(fetch_result.model_bytes) + + @staticmethod + def transform_input( + batch: RequestBatch, + fetch_results: list[FetchInputResult], + mem_pool: "MemoryPool", + ) -> TransformInputResult: + return TransformInputResult(b"result", [slice(0, 1)], [[1, 2]], ["float32"]) + + @staticmethod + def execute( + batch: RequestBatch, + load_result: LoadModelResult, + transform_result: TransformInputResult, + device: str, + ) -> ExecuteResult: + return ExecuteResult(b"result", [slice(0, 1)]) + + @staticmethod + def transform_output( + batch: RequestBatch, execute_result: ExecuteResult + ) -> t.List[TransformOutputResult]: + return [TransformOutputResult(b"result", None, "c", "float32")] + + +def test_worker_device(): + worker_device = WorkerDevice("gpu:0") + assert worker_device.name == "gpu:0" + + model_key = "my_model_key" + model = b"the model" + + worker_device.add_model(model_key, model) + + assert model_key in worker_device + assert worker_device.get_model(model_key) == model + worker_device.remove_model(model_key) + + assert model_key not in worker_device + + +def test_device_manager_model_in_request(): + + worker_device = WorkerDevice("gpu:0") + device_manager = DeviceManager(worker_device) + + worker = MockWorker() + + tensor_key = TensorKey(key="key", descriptor="desc") + output_key = TensorKey(key="key", descriptor="desc") + model_key = ModelKey(key="model key", descriptor="desc") + + request = InferenceRequest( + model_key=model_key, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"raw model", + batch_size=0, + ) + + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_key, + ) + + with device_manager.get_device( + worker=worker, batch=request_batch, feature_stores={} + ) as returned_device: + + assert returned_device == worker_device + assert worker_device.get_model(model_key.key) == b"raw model" + + assert model_key.key not in worker_device + + +def test_device_manager_model_key(): + + worker_device = WorkerDevice("gpu:0") + device_manager = DeviceManager(worker_device) + + worker = MockWorker() + + tensor_key = TensorKey(key="key", descriptor="desc") + output_key = TensorKey(key="key", descriptor="desc") + model_key = ModelKey(key="model key", descriptor="desc") + + request = InferenceRequest( + model_key=model_key, + callback=None, + raw_inputs=None, + 
input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=None, + batch_size=0, + ) + + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_key, + ) + + with device_manager.get_device( + worker=worker, batch=request_batch, feature_stores={} + ) as returned_device: + + assert returned_device == worker_device + assert worker_device.get_model(model_key.key) == b"fetched_model" + + assert model_key.key in worker_device diff --git a/tests/dragon/test_dragon_backend.py b/tests/dragon/test_dragon_backend.py new file mode 100644 index 0000000000..0e64c358df --- /dev/null +++ b/tests/dragon/test_dragon_backend.py @@ -0,0 +1,308 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import time +import uuid + +import pytest + +dragon = pytest.importorskip("dragon") + + +from smartsim._core.launcher.dragon.dragonBackend import DragonBackend +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.infrastructure.comm.event import ( + OnCreateConsumer, + OnShutdownRequested, +) +from smartsim._core.mli.infrastructure.control.listener import ( + ConsumerRegistrationListener, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.log import get_logger + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +logger = get_logger(__name__) + + +@pytest.fixture(scope="module") +def the_backend() -> DragonBackend: + return DragonBackend(pid=9999) + + +@pytest.mark.skip("Test is unreliable on build agent and may hang. 
TODO: Fix") +def test_dragonbackend_start_listener(the_backend: DragonBackend): + """Verify the background process listening to consumer registration events + is up and processing messages as expected.""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor + + with pytest.raises(KeyError) as ex: + # we expect the value of the consumer to be empty until + # the listener start-up completes. + backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + + assert "not found" in ex.value.args[0] + + drg_process = the_backend.start_event_listener(cpu_affinity=[], gpu_affinity=[]) + + # # confirm there is a process still running + logger.info(f"Dragon process started: {drg_process}") + assert drg_process is not None, "Backend was unable to start event listener" + assert drg_process.puid != 0, "Process unique ID is empty" + assert drg_process.returncode is None, "Listener terminated early" + + # wait for the event listener to come up + try: + config = backbone.wait_for( + [BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], timeout=30 + ) + # verify result was in the returned configuration map + assert config[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + except Exception: + raise KeyError( + f"Unable to locate {BackboneFeatureStore.MLI_REGISTRAR_CONSUMER}" + "in the backbone" + ) + + # wait_for ensures the normal retrieval will now work, error-free + descriptor = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + assert descriptor is not None + + # register a new listener channel + comm_channel = DragonCommChannel.from_descriptor(descriptor) + mock_descriptor = str(uuid.uuid4()) + event = OnCreateConsumer("test_dragonbackend_start_listener", mock_descriptor, []) + + event_bytes = bytes(event) + comm_channel.send(event_bytes) + + subscriber_list = [] + + # Give the channel time to write the message and the listener time to handle it + for i in range(20): + time.sleep(1) + # Retrieve the subscriber list from the backbone and verify it is updated + if subscriber_list := backbone.notification_channels: + logger.debug(f"The subscriber list was populated after {i} iterations") + break + + assert mock_descriptor in subscriber_list + + # now send a shutdown message to terminate the listener + return_code = drg_process.returncode + + # clean up if the OnShutdownRequested wasn't properly handled + if return_code is None and drg_process.is_alive: + drg_process.kill() + drg_process.join() + + +def test_dragonbackend_backend_consumer(the_backend: DragonBackend): + """Verify the listener background process updates the appropriate + value in the backbone.""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + assert backbone._allow_reserved_writes + + # create listener with `as_service=False` to perform a single loop iteration + listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=False) + + logger.debug(f"backbone loaded? {listener._backbone}") + logger.debug(f"listener created? 
{listener}") + + try: + # call the service execute method directly to trigger + # the entire service lifecycle + listener.execute() + + consumer_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + logger.debug(f"MLI_REGISTRAR_CONSUMER: {consumer_desc}") + + assert consumer_desc + except Exception as ex: + logger.info("") + finally: + listener._on_shutdown() + + +def test_dragonbackend_event_handled(the_backend: DragonBackend): + """Verify the event listener process updates the appropriate + value in the backbone when an event is received and again on shutdown. + """ + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + # create the listener to be tested + listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=False) + + assert listener._backbone, "The listener is not attached to a backbone" + + try: + # set up the listener but don't let the service event loop start + listener._create_eventing() # listener.execute() + + # grab the channel descriptor so we can simulate registrations + channel_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + comm_channel = DragonCommChannel.from_descriptor(channel_desc) + + num_events = 5 + events = [] + for i in range(num_events): + # register some mock consumers using the backend channel + event = OnCreateConsumer( + "test_dragonbackend_event_handled", + f"mock-consumer-descriptor-{uuid.uuid4()}", + [], + ) + event_bytes = bytes(event) + comm_channel.send(event_bytes) + events.append(event) + + # run few iterations of the event loop in case it takes a few cycles to write + for _ in range(20): + listener._on_iteration() + # Grab the value that should be getting updated + notify_consumers = set(backbone.notification_channels) + if len(notify_consumers) == len(events): + logger.info(f"Retrieved all consumers after {i} listen cycles") + break + + # ... 
and confirm that all the mock consumer descriptors are registered + assert set([e.descriptor for e in events]) == set(notify_consumers) + logger.info(f"Number of registered consumers: {len(notify_consumers)}") + + except Exception as ex: + logger.exception(f"test_dragonbackend_event_handled - exception occurred: {ex}") + assert False + finally: + # shutdown should unregister a registration listener + listener._on_shutdown() + + for i in range(10): + if BackboneFeatureStore.MLI_REGISTRAR_CONSUMER not in backbone: + logger.debug(f"The listener was removed after {i} iterations") + channel_desc = None + break + + # we should see that there is no listener registered + assert not channel_desc, "Listener shutdown failed to clean up the backbone" + + +def test_dragonbackend_shutdown_event(the_backend: DragonBackend): + """Verify the background process shuts down when it receives a + shutdown request.""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=True) + + # set up the listener but don't let the listener loop start + listener._create_eventing() # listener.execute() + + # grab the channel descriptor so we can publish to it + channel_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + comm_channel = DragonCommChannel.from_descriptor(channel_desc) + + assert listener._consumer.listening, "Listener isn't ready to listen" + + # send a shutdown request... + event = OnShutdownRequested("test_dragonbackend_shutdown_event") + event_bytes = bytes(event) + comm_channel.send(event_bytes, 0.1) + + # execute should encounter the shutdown and exit + listener.execute() + + # ...and confirm the listener is now cancelled + assert not listener._consumer.listening + + +@pytest.mark.parametrize("health_check_frequency", [10, 20]) +def test_dragonbackend_shutdown_on_health_check( + the_backend: DragonBackend, + health_check_frequency: float, +): + """Verify that the event listener automatically shuts down when + a new listener is registered in its place. 
+ + :param health_check_frequency: The expected frequency of service health check + invocations""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + listener = ConsumerRegistrationListener( + backbone, + 1.0, + 1.0, + as_service=True, # allow service to run long enough to health check + health_check_frequency=health_check_frequency, + ) + + # set up the listener but don't let the listener loop start + listener._create_eventing() # listener.execute() + assert listener._consumer.listening, "Listener wasn't ready to listen" + + # Replace the consumer descriptor in the backbone to trigger + # an automatic shutdown + backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = str(uuid.uuid4()) + + # set the last health check manually to verify the duration + start_at = time.time() + listener._last_health_check = time.time() + + # run execute to let the service trigger health checks + listener.execute() + elapsed = time.time() - start_at + + # confirm the frequency of the health check was honored + assert elapsed >= health_check_frequency + + # ...and confirm the listener is now cancelled + assert ( + not listener._consumer.listening + ), "Listener was not automatically shutdown by the health check" diff --git a/tests/dragon/test_dragon_ddict_utils.py b/tests/dragon/test_dragon_ddict_utils.py new file mode 100644 index 0000000000..c8bf687ef1 --- /dev/null +++ b/tests/dragon/test_dragon_ddict_utils.py @@ -0,0 +1,117 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
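+
+# A minimal sketch of the create/descriptor round trip validated below,
+# assuming the dragon_util helpers imported in this module (resolved lazily at
+# call time). Illustrative only; never invoked by the suite.
+def _example_ddict_roundtrip() -> None:
+    ddict = dragon_util.create_ddict(1, 1, 3 * 1024**2)  # 3MB is the minimum
+    descriptor = dragon_util.ddict_to_descriptor(ddict)
+    attached = dragon_util.descriptor_to_ddict(descriptor)
+    assert dragon_util.ddict_to_descriptor(attached) == descriptor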
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+
+# isort: on
+
+from smartsim._core.mli.infrastructure.storage import dragon_util
+from smartsim.log import get_logger
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+logger = get_logger(__name__)
+
+
+@pytest.mark.parametrize(
+    "num_nodes, num_managers, mem_per_node",
+    [
+        pytest.param(1, 1, 3 * 1024**2, id="3MB, Bare minimum allocation"),
+        pytest.param(2, 2, 128 * 1024**2, id="128 MB allocation, 2 nodes, 2 mgr"),
+        pytest.param(2, 1, 512 * 1024**2, id="512 MB allocation, 2 nodes, 1 mgr"),
+    ],
+)
+def test_dragon_storage_util_create_ddict(
+    num_nodes: int,
+    num_managers: int,
+    mem_per_node: int,
+):
+    """Verify that a dragon dictionary is successfully created.
+
+    :param num_nodes: Number of ddict nodes to attempt to create
+    :param num_managers: Number of managers per node to request
+    :param mem_per_node: Memory to allocate per node
+    """
+    ddict = dragon_util.create_ddict(num_nodes, num_managers, mem_per_node)
+
+    assert ddict is not None
+
+
+@pytest.mark.parametrize(
+    "num_nodes, num_managers, mem_per_node",
+    [
+        pytest.param(-1, 1, 3 * 1024**2, id="Negative Node Count"),
+        pytest.param(0, 1, 3 * 1024**2, id="Invalid Node Count"),
+        pytest.param(1, -1, 3 * 1024**2, id="Negative Mgr Count"),
+        pytest.param(1, 0, 3 * 1024**2, id="Invalid Mgr Count"),
+        pytest.param(1, 1, -3 * 1024**2, id="Negative Mem Per Node"),
+        pytest.param(1, 1, (3 * 1024**2) - 1, id="Invalid Mem Per Node"),
+        pytest.param(1, 1, 0 * 1024**2, id="No Mem Per Node"),
+    ],
+)
+def test_dragon_storage_util_create_ddict_validators(
+    num_nodes: int,
+    num_managers: int,
+    mem_per_node: int,
+):
+    """Verify that invalid dragon dictionary configurations are rejected.
+
+    :param num_nodes: Number of ddict nodes to attempt to create
+    :param num_managers: Number of managers per node to request
+    :param mem_per_node: Memory to allocate per node
+    """
+    with pytest.raises(ValueError):
+        dragon_util.create_ddict(num_nodes, num_managers, mem_per_node)
+
+
+def test_dragon_storage_util_get_ddict_descriptor(the_storage: dragon_ddict.DDict):
+    """Verify that a descriptor is created.
+
+    :param the_storage: A pre-allocated ddict
+    """
+    value = dragon_util.ddict_to_descriptor(the_storage)
+
+    assert isinstance(value, str)
+    assert len(value) > 0
+
+
+def test_dragon_storage_util_get_ddict_from_descriptor(
+    the_storage: dragon_ddict.DDict,
+):
+    """Verify that a ddict is created from a descriptor.
+
+    :param the_storage: A pre-allocated ddict
+    """
+    descriptor = dragon_util.ddict_to_descriptor(the_storage)
+
+    value = dragon_util.descriptor_to_ddict(descriptor)
+
+    assert value is not None
+    assert isinstance(value, dragon_ddict.DDict)
+    assert dragon_util.ddict_to_descriptor(value) == descriptor
diff --git a/tests/dragon/test_environment_loader.py b/tests/dragon/test_environment_loader.py
new file mode 100644
index 0000000000..07b2a45c1c
--- /dev/null
+++ b/tests/dragon/test_environment_loader.py
@@ -0,0 +1,147 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
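+
+# A minimal sketch of the wiring the tests below exercise: the loader reads
+# serialized descriptors from environment variables and rebuilds attachments
+# through the supplied factories. Uses this module's imports (resolved lazily
+# at call time) and a pytest MonkeyPatch; illustrative only, never invoked.
+def _example_loader_wiring(monkeypatch: "pytest.MonkeyPatch", descriptor: str) -> None:
+    monkeypatch.setenv(EnvironmentConfigLoader.BACKBONE_ENV_VAR, descriptor)
+    config = EnvironmentConfigLoader(
+        featurestore_factory=DragonFeatureStore.from_descriptor,
+        callback_factory=None,
+        queue_factory=None,
+    )
+    assert config.get_backbone() is not None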
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+import dragon.data.ddict.ddict as dragon_ddict
+import dragon.utils as du
+from dragon.fli import FLInterface
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    DragonFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+@pytest.mark.parametrize(
+    "content",
+    [
+        pytest.param(b"a"),
+        pytest.param(b"new byte string"),
+    ],
+)
+def test_environment_loader_attach_fli(content: bytes, monkeypatch: pytest.MonkeyPatch):
+    """A descriptor can be stored, loaded, and reattached."""
+    chan = create_local()
+    queue = FLInterface(main_ch=chan)
+    monkeypatch.setenv(
+        EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR,
+        du.B64.bytes_to_str(queue.serialize()),
+    )
+
+    config = EnvironmentConfigLoader(
+        featurestore_factory=DragonFeatureStore.from_descriptor,
+        callback_factory=DragonCommChannel.from_descriptor,
+        queue_factory=DragonFLIChannel.from_descriptor,
+    )
+    config_queue = config.get_queue()
+
+    _ = config_queue.send(content)
+
+    old_recv = queue.recvh()
+    result, _ = old_recv.recv_bytes()
+    assert result == content
+
+
+def test_environment_loader_serialize_fli(monkeypatch: pytest.MonkeyPatch):
+    """A queue's serialized descriptor is unchanged after it is stored in
+    the environment and reattached by the loader."""
+    chan = create_local()
+    queue = FLInterface(main_ch=chan)
+    monkeypatch.setenv(
+        EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR,
+        du.B64.bytes_to_str(queue.serialize()),
+    )
+
+    config = EnvironmentConfigLoader(
+        featurestore_factory=DragonFeatureStore.from_descriptor,
+        callback_factory=DragonCommChannel.from_descriptor,
+        queue_factory=DragonFLIChannel.from_descriptor,
+    )
+    config_queue = config.get_queue()
+    assert config_queue._fli.serialize() == queue.serialize()
+
+
+def test_environment_loader_flifails(monkeypatch: pytest.MonkeyPatch):
+    """An invalid serialized descriptor fails to attach."""
+
+    monkeypatch.setenv(EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, "randomstring")
+
+    config = EnvironmentConfigLoader(
+        featurestore_factory=DragonFeatureStore.from_descriptor,
+        callback_factory=None,
+        
queue_factory=DragonFLIChannel.from_descriptor, + ) + + with pytest.raises(SmartSimError): + config.get_queue() + + +def test_environment_loader_backbone_load_dfs( + monkeypatch: pytest.MonkeyPatch, the_storage: dragon_ddict.DDict +): + """Verify the dragon feature store is loaded correctly by the + EnvironmentConfigLoader to demonstrate featurestore_factory correctness.""" + feature_store = DragonFeatureStore(the_storage) + monkeypatch.setenv( + EnvironmentConfigLoader.BACKBONE_ENV_VAR, feature_store.descriptor + ) + + config = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=None, + queue_factory=None, + ) + + print(f"calling config.get_backbone: `{feature_store.descriptor}`") + + backbone = config.get_backbone() + assert backbone is not None + + +def test_environment_variables_not_set(monkeypatch: pytest.MonkeyPatch): + """EnvironmentConfigLoader getters return None when environment + variables are not set.""" + with monkeypatch.context() as patch: + patch.setenv(EnvironmentConfigLoader.BACKBONE_ENV_VAR, "") + patch.setenv(EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, "") + + config = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonCommChannel.from_descriptor, + ) + assert config.get_backbone() is None + assert config.get_queue() is None diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py new file mode 100644 index 0000000000..aacd47b556 --- /dev/null +++ b/tests/dragon/test_error_handling.py @@ -0,0 +1,511 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
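+
+# A minimal sketch of the failure-reply path probed below: pipeline stages
+# route exceptions through exception_handler, which builds a failure reply and
+# sends it on the caller's reply channel. Uses this module's imports (resolved
+# lazily at call time); illustrative only, never invoked by the suite.
+def _example_failure_reply(reply_channel: "CommChannelBase") -> None:
+    try:
+        raise ValueError("simulated stage failure")
+    except ValueError as exc:
+        exception_handler(exc, reply_channel, "Error while executing.")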
+ +import typing as t +from unittest.mock import MagicMock + +import pytest + +dragon = pytest.importorskip("dragon") + +import multiprocessing as mp + +from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict +from dragon.fli import FLInterface +from dragon.mpbridge.queues import DragonQueue + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.infrastructure.control.request_dispatcher import ( + RequestDispatcher, +) +from smartsim._core.mli.infrastructure.control.worker_manager import ( + WorkerManager, + exception_handler, +) +from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.feature_store import ( + FeatureStore, + ModelKey, + TensorKey, +) +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, + TransformInputResult, + TransformOutputResult, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim._core.mli.mli_schemas.response.response_capnp import ResponseBuilder + +from .utils.channel import FileSystemCommChannel +from .utils.worker import IntegratedTorchWorker + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.fixture(scope="module") +def app_feature_store(the_storage) -> FeatureStore: + # create a standalone feature store to mimic a user application putting + # data into an application-owned resource (app should not access backbone) + app_fs = DragonFeatureStore(the_storage) + return app_fs + + +@pytest.fixture +def setup_worker_manager_model_bytes( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0) + + worker_manager = WorkerManager( + config_loader=config_loader, + worker_type=integrated_worker_type, + dispatcher_queue=dispatcher_task_queue, + as_service=False, + cooldown=3, + ) + + tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + + inf_request = InferenceRequest( + model_key=None, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"model", + batch_size=0, + ) + + model_id = ModelKey(key="key", descriptor=app_feature_store.descriptor) + + request_batch = RequestBatch( + [inf_request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + 
model_id=model_id, + ) + + dispatcher_task_queue.put(request_batch) + return worker_manager, integrated_worker_type + + +@pytest.fixture +def setup_worker_manager_model_key( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0) + + worker_manager = WorkerManager( + config_loader=config_loader, + worker_type=integrated_worker_type, + dispatcher_queue=dispatcher_task_queue, + as_service=False, + cooldown=3, + ) + + tensor_key = TensorKey(key="key", descriptor=app_feature_store.descriptor) + output_key = TensorKey(key="key", descriptor=app_feature_store.descriptor) + model_id = ModelKey(key="model key", descriptor=app_feature_store.descriptor) + + request = InferenceRequest( + model_key=model_id, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"model", + batch_size=0, + ) + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_id, + ) + + dispatcher_task_queue.put(request_batch) + return worker_manager, integrated_worker_type + + +@pytest.fixture +def setup_request_dispatcher_model_bytes( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=0, + batch_size=0, + config_loader=config_loader, + worker_type=integrated_worker_type, + ) + request_dispatcher._on_start() + + tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") + request = MessageHandler.build_request( + test_dir, model, [tensor_key], [output_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + + request_dispatcher._incoming_channel.send(ser_request) + + return request_dispatcher, integrated_worker_type + + +@pytest.fixture +def setup_request_dispatcher_model_key( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + 
BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=0, + batch_size=0, + config_loader=config_loader, + worker_type=integrated_worker_type, + ) + request_dispatcher._on_start() + + tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + model_key = MessageHandler.build_model_key( + key="model key", descriptor=app_feature_store.descriptor + ) + request = MessageHandler.build_request( + test_dir, model_key, [tensor_key], [output_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + + request_dispatcher._incoming_channel.send(ser_request) + + return request_dispatcher, integrated_worker_type + + +def mock_pipeline_stage( + monkeypatch: pytest.MonkeyPatch, + integrated_worker: MachineLearningWorkerBase, + stage: str, +) -> t.Callable[[t.Any], ResponseBuilder]: + def mock_stage(*args: t.Any, **kwargs: t.Any) -> None: + raise ValueError(f"Simulated error in {stage}") + + monkeypatch.setattr(integrated_worker, stage, mock_stage) + mock_reply_fn = MagicMock() + mock_response = MagicMock() + mock_response.schema.node.displayName = "Response" + mock_reply_fn.return_value = mock_response + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.error_handling.build_failure_reply", + mock_reply_fn, + ) + + mock_reply_channel = MagicMock() + mock_reply_channel.send = MagicMock() + + def mock_exception_handler( + exc: Exception, reply_channel: CommChannelBase, failure_message: str + ) -> None: + exception_handler(exc, mock_reply_channel, failure_message) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.worker_manager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.request_dispatcher.exception_handler", + mock_exception_handler, + ) + + return mock_reply_fn + + +@pytest.mark.parametrize( + "setup_worker_manager", + [ + pytest.param("setup_worker_manager_model_bytes"), + pytest.param("setup_worker_manager_model_key"), + ], +) +@pytest.mark.parametrize( + "stage, error_message", + [ + pytest.param( + "fetch_model", + "Error loading model on device or getting device.", + id="fetch model", + ), + pytest.param( + "load_model", + "Error loading model on device or getting device.", + id="load model", + ), + pytest.param("execute", "Error while executing.", id="execute"), + pytest.param( + "transform_output", + "Error while transforming the output.", + id="transform output", + ), + pytest.param( + "place_output", "Error while placing the output.", id="place output" + ), + ], +) +def test_wm_pipeline_stage_errors_handled( + request: pytest.FixtureRequest, + setup_worker_manager: str, + monkeypatch: pytest.MonkeyPatch, + stage: str, + error_message: str, +) -> None: + """Ensures that the worker manager does not crash after a failure in various pipeline stages""" + worker_manager, integrated_worker_type = request.getfixturevalue( + setup_worker_manager + ) + integrated_worker = worker_manager._worker + + worker_manager._on_start() + 
device = worker_manager._device_manager._device + mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage) + + if stage not in ["fetch_model"]: + monkeypatch.setattr( + integrated_worker, + "fetch_model", + MagicMock(return_value=FetchModelResult(b"result_bytes")), + ) + if stage not in ["fetch_model", "load_model"]: + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + device, + "get_model", + MagicMock(return_value=b"result_bytes"), + ) + if stage not in [ + "fetch_model", + "execute", + ]: + monkeypatch.setattr( + integrated_worker, + "execute", + MagicMock(return_value=ExecuteResult(b"result_bytes", [slice(0, 1)])), + ) + if stage not in [ + "fetch_model", + "execute", + "transform_output", + ]: + monkeypatch.setattr( + integrated_worker, + "transform_output", + MagicMock( + return_value=[TransformOutputResult(b"result", [], "c", "float32")] + ), + ) + + worker_manager._on_iteration() + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", error_message) + + +@pytest.mark.parametrize( + "setup_request_dispatcher", + [ + pytest.param("setup_request_dispatcher_model_bytes"), + pytest.param("setup_request_dispatcher_model_key"), + ], +) +@pytest.mark.parametrize( + "stage, error_message", + [ + pytest.param( + "fetch_inputs", + "Error fetching input.", + id="fetch input", + ), + pytest.param( + "transform_input", + "Error transforming input.", + id="transform input", + ), + ], +) +def test_dispatcher_pipeline_stage_errors_handled( + request: pytest.FixtureRequest, + setup_request_dispatcher: str, + monkeypatch: pytest.MonkeyPatch, + stage: str, + error_message: str, +) -> None: + """Ensures that the request dispatcher does not crash after a failure in various pipeline stages""" + request_dispatcher, integrated_worker_type = request.getfixturevalue( + setup_request_dispatcher + ) + integrated_worker = request_dispatcher._worker + + mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage) + + if stage not in ["fetch_inputs"]: + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=[FetchInputResult(result=[b"result"], meta=None)]), + ) + + request_dispatcher._on_iteration() + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", error_message) + + +def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensures that the worker manager does not crash after a failure in the + execute pipeline stage""" + + mock_reply_channel = MagicMock() + mock_reply_channel.send = MagicMock() + + mock_reply_fn = MagicMock() + + mock_response = MagicMock() + mock_response.schema.node.displayName = "Response" + mock_reply_fn.return_value = mock_response + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.error_handling.build_failure_reply", + mock_reply_fn, + ) + + test_exception = ValueError("Test ValueError") + exception_handler( + test_exception, mock_reply_channel, "Failure while fetching the model." 
+    )
+
+    mock_reply_fn.assert_called_once()
+    mock_reply_fn.assert_called_with("fail", "Failure while fetching the model.")
+
+
+def test_dragon_feature_store_invalid_storage():
+    """Verify that attempting to create a DragonFeatureStore without storage fails."""
+    storage = None
+
+    with pytest.raises(ValueError) as ex:
+        DragonFeatureStore(storage)
+
+    assert "storage" in ex.value.args[0].lower()
+    assert "required" in ex.value.args[0].lower()
diff --git a/tests/dragon/test_event_consumer.py b/tests/dragon/test_event_consumer.py
new file mode 100644
index 0000000000..8a241bab19
--- /dev/null
+++ b/tests/dragon/test_event_consumer.py
@@ -0,0 +1,386 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import time
+import typing as t
+from unittest import mock
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
+from smartsim._core.mli.infrastructure.comm.event import (
+    OnCreateConsumer,
+    OnShutdownRequested,
+    OnWriteFeatureStore,
+)
+from smartsim._core.mli.infrastructure.control.listener import (
+    ConsumerRegistrationListener,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+# isort: off
+from dragon import fli
+from dragon.channels import Channel
+
+# isort: on
+
+if t.TYPE_CHECKING:
+    import conftest
+
+
+# The tests in this file must run in a dragon environment
+pytestmark = pytest.mark.dragon
+
+
+def test_eventconsumer_eventpublisher_integration(
+    the_backbone: t.Any, test_dir: str
+) -> None:
+    """Verify that the publisher and consumer integrate as expected when
+    multiple publishers and consumers are sending simultaneously. This test
+    closely tracks the test in tests/dragon/test_featurestore_base.py also named
+    test_eventconsumer_eventpublisher_integration but requires dragon entities.
+ + :param the_backbone: The BackboneFeatureStore to use + :param test_dir: Automatically generated unique working + directories for individual test outputs + """ + + wmgr_channel = DragonCommChannel(create_local()) + capp_channel = DragonCommChannel(create_local()) + back_channel = DragonCommChannel(create_local()) + + wmgr_consumer_descriptor = wmgr_channel.descriptor + capp_consumer_descriptor = capp_channel.descriptor + back_consumer_descriptor = back_channel.descriptor + + # create some consumers to receive messages + wmgr_consumer = EventConsumer( + wmgr_channel, + the_backbone, + filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN], + ) + capp_consumer = EventConsumer( + capp_channel, + the_backbone, + ) + back_consumer = EventConsumer( + back_channel, + the_backbone, + filters=[OnCreateConsumer.CONSUMER_CREATED], + ) + + # create some broadcasters to publish messages + mock_worker_mgr = EventBroadcaster( + the_backbone, + channel_factory=DragonCommChannel.from_descriptor, + ) + mock_client_app = EventBroadcaster( + the_backbone, + channel_factory=DragonCommChannel.from_descriptor, + ) + + # register all of the consumers even though the OnCreateConsumer really should + # trigger its registration. event processing is tested elsewhere. + the_backbone.notification_channels = [ + wmgr_consumer_descriptor, + capp_consumer_descriptor, + back_consumer_descriptor, + ] + + # simulate worker manager sending a notification to backend that it's alive + event_1 = OnCreateConsumer( + "test_eventconsumer_eventpublisher_integration", + wmgr_consumer_descriptor, + filters=[], + ) + mock_worker_mgr.send(event_1) + + # simulate the app updating a model a few times + for key in ["key-1", "key-2", "key-1"]: + event = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", + the_backbone.descriptor, + key, + ) + mock_client_app.send(event, timeout=0.1) + + # worker manager should only get updates about feature update + wmgr_messages = wmgr_consumer.recv() + assert len(wmgr_messages) == 3 + + # the backend should only receive messages about consumer creation + back_messages = back_consumer.recv() + assert len(back_messages) == 1 + + # hypothetical app has no filters and will get all events + app_messages = capp_consumer.recv() + assert len(app_messages) == 4 + + +@pytest.mark.parametrize( + " timeout, batch_timeout, exp_err_msg", + [(-1, 1, " timeout"), (1, -1, "batch_timeout")], +) +def test_eventconsumer_invalid_timeout( + timeout: float, + batch_timeout: float, + exp_err_msg: str, + test_dir: str, + the_backbone: BackboneFeatureStore, +) -> None: + """Verify that the event consumer raises an exception + when provided an invalid request timeout. 
+
+    :param timeout: The request timeout for the event consumer recv call
+    :param batch_timeout: The batch timeout for the event consumer recv call
+    :param exp_err_msg: A unique value from the error message that should be raised
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    wmgr_channel = DragonCommChannel(create_local())
+
+    # create a consumer to receive messages
+    wmgr_consumer = EventConsumer(
+        wmgr_channel,
+        the_backbone,
+        filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+    )
+
+    # the consumer should report an error for the invalid timeout value
+    with pytest.raises(ValueError) as ex:
+        wmgr_consumer.recv(timeout=timeout, batch_timeout=batch_timeout)
+
+    assert exp_err_msg in ex.value.args[0]
+
+
+def test_eventconsumer_no_event_handler_registered(
+    the_backbone: t.Any, test_dir: str
+) -> None:
+    """Verify that a consumer discards messages received on a channel
+    when no handler is registered.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    wmgr_channel = DragonCommChannel(create_local())
+
+    # create a consumer to receive messages
+    wmgr_consumer = EventConsumer(wmgr_channel, the_backbone, event_handler=None)
+
+    # create a broadcaster to publish messages
+    mock_worker_mgr = EventBroadcaster(
+        the_backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # manually register the consumers since we don't have a backend running
+    the_backbone.notification_channels = [wmgr_channel.descriptor]
+
+    # simulate the app updating a model a few times
+    for key in ["key-1", "key-2", "key-1"]:
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_no_event_handler_registered",
+            the_backbone.descriptor,
+            key,
+        )
+        mock_worker_mgr.send(event, timeout=0.1)
+
+    # run the handler and let it discard messages
+    for _ in range(15):
+        wmgr_consumer.listen_once(0.2, 2.0)
+
+    assert wmgr_consumer.listening
+
+
+def test_eventconsumer_no_event_handler_registered_shutdown(
+    the_backbone: t.Any, test_dir: str
+) -> None:
+    """Verify that a consumer without an event handler
+    registered still honors shutdown requests.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    wmgr_channel = DragonCommChannel(create_local())
+    capp_channel = DragonCommChannel(create_local())
+
+    # create a consumer to receive messages
+    wmgr_consumer = EventConsumer(wmgr_channel, the_backbone)
+
+    # create a broadcaster to publish messages
+    mock_worker_mgr = EventBroadcaster(
+        the_backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # manually register the consumers since we don't have a backend running
+    the_backbone.notification_channels = [
+        wmgr_channel.descriptor,
+        capp_channel.descriptor,
+    ]
+
+    # simulate the app updating a model a few times
+    for key in ["key-1", "key-2", "key-1"]:
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_no_event_handler_registered_shutdown",
+            the_backbone.descriptor,
+            key,
+        )
+        mock_worker_mgr.send(event, timeout=0.1)
+
+    event = OnShutdownRequested(
+        "test_eventconsumer_no_event_handler_registered_shutdown"
+    )
+    mock_worker_mgr.send(event, timeout=0.1)
+
+    # wmgr will stop listening to messages when it is told to stop listening
+    wmgr_consumer.listen(timeout=0.1, batch_timeout=2.0)
+
+    for _ in range(15):
+        wmgr_consumer.listen_once(timeout=0.1, batch_timeout=2.0)
+
+    # confirm the messages were processed, discarded, and the shutdown was received
+    assert not wmgr_consumer.listening
+
+
+def test_eventconsumer_registration(
+    the_backbone: t.Any, test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that a consumer is correctly registered in
+    the backbone after sending a registration request. Then
+    confirm the consumer is unregistered after sending the
+    unregister request.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    with monkeypatch.context() as patch:
+        registrar = ConsumerRegistrationListener(
+            the_backbone, 1.0, 2.0, as_service=False
+        )
+
+        # NOTE: service.execute(as_service=False) will complete the service life-
+        # cycle and remove the registrar from the backbone, so mock _on_shutdown
+        disabled_shutdown = mock.MagicMock()
+        patch.setattr(registrar, "_on_shutdown", disabled_shutdown)
+
+        # initialize registrar resources
+        registrar.execute()
+
+        # create a consumer that will be registered
+        wmgr_channel = DragonCommChannel(create_local())
+        wmgr_consumer = EventConsumer(wmgr_channel, the_backbone)
+
+        registered_channels = the_backbone.notification_channels
+
+        # trigger the consumer-to-registrar handshake
+        wmgr_consumer.register()
+
+        current_registrations: t.List[str] = []
+
+        # have the registrar run a few times to pick up the msg
+        for i in range(15):
+            registrar.execute()
+            current_registrations = the_backbone.notification_channels
+            if len(current_registrations) != len(registered_channels):
+                logger.debug(f"The event was processed on iteration {i}")
+                break
+
+        # confirm the consumer is registered
+        assert wmgr_channel.descriptor in current_registrations
+
+        # copy old list so we can compare against it.
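+        # (the property re-reads the backbone on each access, so `list(...)` keeps
+        # an independent snapshot for the length comparison in the loop below)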
+        registered_channels = list(current_registrations)
+
+        # trigger the consumer removal
+        wmgr_consumer.unregister()
+
+        # have the registrar run a few times to pick up the msg
+        for i in range(15):
+            registrar.execute()
+            current_registrations = the_backbone.notification_channels
+            if len(current_registrations) != len(registered_channels):
+                logger.debug(f"The event was processed on iteration {i}")
+                break
+
+        # confirm the consumer is no longer registered
+        assert wmgr_channel.descriptor not in current_registrations
+
+
+def test_registrar_teardown(
+    the_backbone: t.Any, test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that the consumer registrar removes itself from
+    the backbone when it shuts down.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    with monkeypatch.context() as patch:
+        registrar = ConsumerRegistrationListener(
+            the_backbone, 1.0, 2.0, as_service=False
+        )
+
+        # directly initialize registrar resources to avoid service life-cycle
+        registrar._create_eventing()
+
+        # confirm the registrar is published to the backbone
+        cfg = the_backbone.wait_for([BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], 10)
+        assert BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in cfg
+
+        # execute the entire service lifecycle 1x
+        registrar.execute()
+
+        consumer_found = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in the_backbone
+
+        for i in range(15):
+            time.sleep(0.1)
+            consumer_found = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in the_backbone
+            if not consumer_found:
+                logger.debug(f"Registrar removed from the backbone on iteration {i}")
+                break
+
+        assert BackboneFeatureStore.MLI_REGISTRAR_CONSUMER not in the_backbone
diff --git a/tests/dragon/test_featurestore.py b/tests/dragon/test_featurestore.py
new file mode 100644
index 0000000000..019dcde7a0
--- /dev/null
+++ b/tests/dragon/test_featurestore.py
@@ -0,0 +1,327 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
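+
+# These tests exercise `BackboneFeatureStore.wait_for`: its early-exit paths,
+# duplicate-key handling, and polling for values written by delayed processes.
+# A typical call made below looks like (illustrative sketch only):
+#
+#     values = the_backbone.wait_for(["key-1", "key-2"], timeout=10)
+#     assert values["key-1"] == "i-am-value-1"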
+
+
+import multiprocessing as mp
+import random
+import time
+import typing as t
+import unittest.mock as mock
+import uuid
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    time as bbtime,
+)
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+# isort: off
+from dragon import fli
+from dragon.channels import Channel
+
+# isort: on
+
+if t.TYPE_CHECKING:
+    import conftest
+
+
+# The tests in this file must run in a dragon environment
+pytestmark = pytest.mark.dragon
+
+
+def test_backbone_wait_for_no_keys(
+    the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that asking the backbone to wait for a value succeeds
+    immediately and does not cause a wait to occur if the supplied key
+    list is empty.
+
+    :param the_backbone: the storage engine to use
+    """
+    # mock out the sleep call to confirm that no wait occurs
+
+    with monkeypatch.context() as ctx:
+        # all keys should be found and the timeout should never be checked.
+        ctx.setattr(bbtime, "sleep", mock.MagicMock())
+
+        values = the_backbone.wait_for([])
+        assert len(values) == 0
+
+        # confirm that no wait occurred
+        bbtime.sleep.assert_not_called()
+
+
+def test_backbone_wait_for_prepopulated(
+    the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that asking the backbone to wait for a value succeeds
+    immediately and does not cause a wait to occur if the data exists.
+
+    :param the_backbone: the storage engine to use, prepopulated with
+    the worker queue descriptor
+    """
+    # set a very low timeout to confirm that it does not wait
+
+    with monkeypatch.context() as ctx:
+        # all keys should be found and the timeout should never be checked.
+        ctx.setattr(bbtime, "sleep", mock.MagicMock())
+
+        values = the_backbone.wait_for([BackboneFeatureStore.MLI_WORKER_QUEUE], 0.1)
+
+        # confirm that wait_for with one key returns one value
+        assert len(values) == 1
+
+        # confirm that the descriptor is non-null w/some non-trivial value
+        assert len(values[BackboneFeatureStore.MLI_WORKER_QUEUE]) > 5
+
+        # confirm that no wait occurred
+        bbtime.sleep.assert_not_called()
+
+
+def test_backbone_wait_for_prepopulated_dupe(
+    the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that asking the backbone to wait for keys that are duplicated
+    results in a single value being returned for each key.
+
+    :param the_backbone: the storage engine to use, prepopulated with
+    the worker queue descriptor
+    """
+    # mock out the sleep call to confirm that no wait occurs
+
+    key1, key2 = "key-1", "key-2"
+    value1, value2 = "i-am-value-1", "i-am-value-2"
+    the_backbone[key1] = value1
+    the_backbone[key2] = value2
+
+    with monkeypatch.context() as ctx:
+        # all keys should be found and the timeout should never be checked.
+        ctx.setattr(bbtime, "sleep", mock.MagicMock())
+
+        values = the_backbone.wait_for([key1, key2, key1])  # key1 is duplicated
+
+        # confirm that the duplicated key is collapsed and each unique key
+        # returns a single value
+        assert len(values) == 2
+        assert key1 in values
+        assert key2 in values
+
+        assert values[key1] == value1
+        assert values[key2] == value2
+
+
+def set_value_after_delay(
+    descriptor: str, key: str, value: str, delay: float = 5
+) -> None:
+    """Helper function to persist a value into the backbone after a delay
+
+    :param descriptor: the backbone feature store descriptor to attach to
+    :param key: the key to write to
+    :param value: a value to write to the key
+    :param delay: amount of delay to apply before writing the key
+    """
+    time.sleep(delay)
+
+    backbone = BackboneFeatureStore.from_descriptor(descriptor)
+    backbone[key] = value
+    logger.debug(f"set_value_after_delay wrote `{value}` to backbone[`{key}`]")
+
+
+@pytest.mark.parametrize(
+    "delay",
+    [
+        pytest.param(
+            0,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            1,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            2,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            4,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            8,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+    ],
+)
+def test_backbone_wait_for_partial_prepopulated(
+    the_backbone: BackboneFeatureStore, delay: float
+) -> None:
+    """Verify that when data is not all in the backbone, the `wait_for` operation
+    continues to poll until it finds everything it needs.
+
+    :param the_backbone: the storage engine to use, prepopulated with
+    the worker queue descriptor
+    :param delay: the number of seconds the second process will wait before
+    setting the target value in the backbone feature store
+    """
+    # allow a generous timeout so the delayed writer can complete
+    wait_timeout = 10
+
+    key, value = str(uuid.uuid4()), str(random.random() * 10)
+
+    logger.debug(f"Starting process to write {key} after {delay}s")
+    p = mp.Process(
+        target=set_value_after_delay, args=(the_backbone.descriptor, key, value, delay)
+    )
+    p.start()
+
+    p2 = mp.Process(
+        target=the_backbone.wait_for,
+        args=([BackboneFeatureStore.MLI_WORKER_QUEUE, key],),
+        kwargs={"timeout": wait_timeout},
+    )
+    p2.start()
+
+    p.join()
+    p2.join()
+
+    # both values should be written at this time
+    ret_vals = the_backbone.wait_for(
+        [key, BackboneFeatureStore.MLI_WORKER_QUEUE, key], 0.1
+    )
+    # confirm that wait_for with two unique keys returns two values
+    assert len(ret_vals) == 2, "values should contain values for both awaited keys"
+
+    # confirm the pre-populated value has the correct output
+    assert (
+        ret_vals[BackboneFeatureStore.MLI_WORKER_QUEUE] == "12345"
+    )  # mock descriptor value from fixture
+
+    # confirm the writer process completed and the awaited value is correct
+    assert ret_vals[key] == value, "the awaited key should hold the delayed write"
+
+
+@pytest.mark.parametrize(
+    "num_keys",
+    [
+        pytest.param(
+            0,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            1,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            3,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            7,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            11,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+    ],
+)
+def test_backbone_wait_for_multikey(
+    the_backbone: BackboneFeatureStore,
+    num_keys: int,
+    test_dir: str,
+) -> None:
+    """Verify that asking the backbone to wait for multiple keys results
+    in that number of values being returned.
+
+    :param the_backbone: the storage engine to use, prepopulated with
+    the worker queue descriptor
+    :param num_keys: the number of extra keys to set & request in the backbone
+    """
+    # maximum delay allowed for setter processes
+    max_delay = 5
+
+    extra_keys = [str(uuid.uuid4()) for _ in range(num_keys)]
+    extra_values = [str(uuid.uuid4()) for _ in range(num_keys)]
+    extras = dict(zip(extra_keys, extra_values))
+    delays = [random.random() * max_delay for _ in range(num_keys)]
+    processes = []
+
+    for key, value, delay in zip(extra_keys, extra_values, delays):
+        assert delay < max_delay, "write delay exceeds test timeout"
+        logger.debug(f"Delaying {key} write by {delay} seconds")
+        p = mp.Process(
+            target=set_value_after_delay,
+            args=(the_backbone.descriptor, key, value, delay),
+        )
+        p.start()
+        processes.append(p)
+
+    p2 = mp.Process(
+        target=the_backbone.wait_for,
+        args=(extra_keys,),
+        kwargs={"timeout": max_delay * 2},
+    )
+    p2.start()
+    for p in processes:
+        p.join(timeout=max_delay * 2)
+    p2.join(
+        timeout=max_delay * 2
+    )  # allow the waiter the same generous timeout used for the setters
+
+    # read back without a wait to verify all values are written
+    actual_values = the_backbone.wait_for(extra_keys, timeout=0.01)
+
+    # confirm that wait_for returns all the expected values
+    assert len(actual_values) == num_keys
+
+    # confirm that each returned value is mapped to the correct key
+    for k in extras:
+        assert extras[k] == actual_values[k]
diff --git a/tests/dragon/test_featurestore_base.py b/tests/dragon/test_featurestore_base.py
new file mode 100644
index 0000000000..6daceb9061
--- /dev/null
+++ b/tests/dragon/test_featurestore_base.py
@@ -0,0 +1,844 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
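+
+# These tests cover the event plumbing against dict-backed feature stores and
+# file-system comm channels. The broadcaster/consumer pairing exercised below
+# follows this pattern (illustrative sketch only):
+#
+#     publisher = EventBroadcaster(
+#         backbone, channel_factory=FileSystemCommChannel.from_descriptor
+#     )
+#     publisher.send(OnWriteFeatureStore("source", backbone.descriptor, "key-1"))
+#     events = EventConsumer(channel, backbone).recv()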
+import pathlib
+import time
+import typing as t
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
+from smartsim._core.mli.infrastructure.comm.event import (
+    OnCreateConsumer,
+    OnWriteFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
+    DragonFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.feature_store import ReservedKeys
+from smartsim.error import SmartSimError
+
+from .channel import FileSystemCommChannel
+from .feature_store import MemoryFeatureStore
+
+if t.TYPE_CHECKING:
+    import conftest
+
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+def boom(*args, **kwargs) -> None:
+    """Helper function that blows up when used to mock up
+    some other function."""
+    raise Exception(f"you shall not pass! {args}, {kwargs}")
+
+
+def test_event_uid() -> None:
+    """Verify that all events include a unique identifier."""
+    uids: t.Set[str] = set()
+    num_iters = 1000
+
+    # generate a bunch of events and keep track of all the IDs
+    for i in range(num_iters):
+        event_a = OnCreateConsumer("test_event_uid", str(i), filters=[])
+        event_b = OnWriteFeatureStore("test_event_uid", "test_event_uid", str(i))
+
+        uids.add(event_a.uid)
+        uids.add(event_b.uid)
+
+    # verify each event created a unique ID
+    assert len(uids) == 2 * num_iters
+
+
+def test_mli_reserved_keys_conversion() -> None:
+    """Verify that checking a string against the reserved keys
+    works as expected."""
+
+    for reserved_key in ReservedKeys:
+        # iterate through all keys and verify `contains` works
+        assert ReservedKeys.contains(reserved_key.value)
+
+        # show that the enum member name, unlike the value (the actual key),
+        # is not incorrectly identified as reserved
+        assert not ReservedKeys.contains(str(reserved_key).split(".")[1])
+
+
+def test_mli_reserved_keys_writes() -> None:
+    """Verify that attempts to write to reserved keys are blocked from a
+    standard DragonFeatureStore but enabled with the BackboneFeatureStore."""
+
+    mock_storage = {}
+    dfs = DragonFeatureStore(mock_storage)
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+    other = MemoryFeatureStore(mock_storage)
+
+    expected_value = "value"
+
+    for reserved_key in ReservedKeys:
+        # we expect every reserved key to fail using DragonFeatureStore...
+        with pytest.raises(SmartSimError) as ex:
+            dfs[reserved_key] = expected_value
+
+        assert "reserved key" in ex.value.args[0]
+
+        # ...
and expect other feature stores to respect reserved keys + with pytest.raises(SmartSimError) as ex: + other[reserved_key] = expected_value + + assert "reserved key" in ex.value.args[0] + + # ...and those same keys to succeed on the backbone + backbone[reserved_key] = expected_value + actual_value = backbone[reserved_key] + assert actual_value == expected_value + + +def test_mli_consumers_read_by_key() -> None: + """Verify that the value returned from the mli consumers method is written + to the correct key and reads are allowed via standard dragon feature store.""" + + mock_storage = {} + dfs = DragonFeatureStore(mock_storage) + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + other = MemoryFeatureStore(mock_storage) + + expected_value = "value" + + # write using backbone that has permission to write reserved keys + backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] = expected_value + + # confirm read-only access to reserved keys from any FeatureStore + for fs in [dfs, backbone, other]: + assert fs[ReservedKeys.MLI_NOTIFY_CONSUMERS] == expected_value + + +def test_mli_consumers_read_by_backbone() -> None: + """Verify that the backbone reads the correct location + when using the backbone feature store API instead of mapping API.""" + + mock_storage = {} + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + expected_value = "value" + + backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] = expected_value + + # confirm reading via convenience method returns expected value + assert backbone.notification_channels[0] == expected_value + + +def test_mli_consumers_write_by_backbone() -> None: + """Verify that the backbone writes the correct location + when using the backbone feature store API instead of mapping API.""" + + mock_storage = {} + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + expected_value = ["value"] + + backbone.notification_channels = expected_value + + # confirm write using convenience method targets expected key + assert backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] == ",".join(expected_value) + + +def test_eventpublisher_broadcast_no_factory(test_dir: str) -> None: + """Verify that a broadcast operation without any registered subscribers + succeeds without raising Exceptions. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + consumer_descriptor = storage_path / "test-consumer" + + # NOTE: we're not putting any consumers into the backbone here! 
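+    # with no registered consumers, the broadcaster has nowhere to deliver and
+    # should buffer the event instead (asserted via `num_buffered` below)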
+    backbone = BackboneFeatureStore(mock_storage)
+
+    publisher = EventBroadcaster(backbone)
+    num_receivers = 0
+
+    # publishing this event without any known consumers registered should succeed
+    # but report that it didn't have anybody to send the event to
+    event = OnCreateConsumer(
+        "test_eventpublisher_broadcast_no_factory", consumer_descriptor, filters=[]
+    )
+
+    num_receivers += publisher.send(event)
+
+    # confirm no changes to the backbone occur when fetching the empty consumer key
+    key_in_features_store = ReservedKeys.MLI_NOTIFY_CONSUMERS in backbone
+    assert not key_in_features_store
+
+    # confirm that the broadcast reports no events published
+    assert num_receivers == 0
+    # confirm that the broadcast buffered the event for a later send
+    assert publisher.num_buffered == 1
+
+
+def test_eventpublisher_broadcast_to_empty_consumer_list(test_dir: str) -> None:
+    """Verify that a broadcast operation with an empty registered consumer
+    list succeeds without raising Exceptions.
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    consumer_descriptor = storage_path / "test-consumer"
+
+    # prep our backbone with a consumer list
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+    backbone.notification_channels = []
+
+    event = OnCreateConsumer(
+        "test_eventpublisher_broadcast_to_empty_consumer_list",
+        consumer_descriptor,
+        filters=[],
+    )
+    publisher = EventBroadcaster(
+        backbone, channel_factory=FileSystemCommChannel.from_descriptor
+    )
+    num_receivers = publisher.send(event)
+
+    registered_consumers = backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS]
+
+    # confirm that no consumers exist in backbone to send to
+    assert not registered_consumers
+    # confirm that the broadcast reports no events published
+    assert num_receivers == 0
+    # confirm that the broadcast buffered the event for a later send
+    assert publisher.num_buffered == 1
+
+
+def test_eventpublisher_broadcast_without_channel_factory(test_dir: str) -> None:
+    """Verify that a broadcast operation reports an error if no channel
+    factory was supplied for constructing the consumer channels.
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    consumer_descriptor = storage_path / "test-consumer"
+
+    # prep our backbone with a consumer list
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+    backbone.notification_channels = [consumer_descriptor]
+
+    event = OnCreateConsumer(
+        "test_eventpublisher_broadcast_without_channel_factory",
+        consumer_descriptor,
+        filters=[],
+    )
+    publisher = EventBroadcaster(
+        backbone,
+        # channel_factory=FileSystemCommChannel.from_descriptor  # <--- not supplied
+    )
+
+    with pytest.raises(SmartSimError) as ex:
+        publisher.send(event)
+
+    assert "factory" in ex.value.args[0]
+
+
+def test_eventpublisher_broadcast_empties_buffer(test_dir: str) -> None:
+    """Verify that a successful broadcast clears messages from the event
+    buffer when a new message is sent and consumers are registered.
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    consumer_descriptor = storage_path / "test-consumer"
+
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+    backbone.notification_channels = (consumer_descriptor,)
+
+    publisher = EventBroadcaster(
+        backbone, channel_factory=FileSystemCommChannel.from_descriptor
+    )
+
+    # mock building up some buffered events
+    num_buffered_events = 14
+    for i in range(num_buffered_events):
+        event = OnCreateConsumer(
+            "test_eventpublisher_broadcast_empties_buffer",
+            storage_path / f"test-consumer-{str(i)}",
+            [],
+        )
+        publisher._event_buffer.append(bytes(event))
+
+    event0 = OnCreateConsumer(
+        "test_eventpublisher_broadcast_empties_buffer",
+        storage_path / f"test-consumer-{str(num_buffered_events + 1)}",
+        [],
+    )
+
+    num_receivers = publisher.send(event0)
+    # 1 receiver x 15 total events == 15 events
+    assert num_receivers == num_buffered_events + 1
+
+
+@pytest.mark.parametrize(
+    "num_consumers, num_buffered, expected_num_sent",
+    [
+        pytest.param(0, 7, 0, id="0 x (7+1) - no consumers, multi-buffer"),
+        pytest.param(1, 7, 8, id="1 x (7+1) - single consumer, multi-buffer"),
+        pytest.param(2, 7, 16, id="2 x (7+1) - multi-consumer, multi-buffer"),
+        pytest.param(4, 4, 20, id="4 x (4+1) - multi-consumer, multi-buffer (odd #)"),
+        pytest.param(9, 0, 9, id="9 x (0+1) - multi-consumer, empty buffer"),
+    ],
+)
+def test_eventpublisher_broadcast_returns_total_sent(
+    test_dir: str, num_consumers: int, num_buffered: int, expected_num_sent: int
+) -> None:
+    """Verify that a successful broadcast returns the total number of events
+    sent, including buffered messages.
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    :param num_consumers: the number of consumers to mock setting up prior to send
+    :param num_buffered: the number of pre-buffered events to mock up
+    :param expected_num_sent: the expected result from calling send
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    consumers = []
+    for i in range(num_consumers):
+        consumers.append(storage_path / f"test-consumer-{i}")
+
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+    backbone.notification_channels = consumers
+
+    publisher = EventBroadcaster(
+        backbone, channel_factory=FileSystemCommChannel.from_descriptor
+    )
+
+    # mock building up some buffered events
+    for i in range(num_buffered):
+        event = OnCreateConsumer(
+            "test_eventpublisher_broadcast_returns_total_sent",
+            storage_path / f"test-consumer-{str(i)}",
+            [],
+        )
+        publisher._event_buffer.append(bytes(event))
+
+    assert publisher.num_buffered == num_buffered
+
+    # this event will trigger clearing anything already in buffer
+    event0 = OnCreateConsumer(
+        "test_eventpublisher_broadcast_returns_total_sent",
+        storage_path / f"test-consumer-{num_buffered}",
+        [],
+    )
+
+    # num_receivers should total all consumers x all events (buffered + new)
+    num_receivers = publisher.send(event0)
+
+    assert num_receivers == expected_num_sent
+
+
+def test_eventpublisher_prune_unused_consumer(test_dir: str) -> None:
+    """Verify that any unused consumers are pruned each time a new event is sent.
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    consumer_descriptor = storage_path / "test-consumer"
+
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+
+    publisher = EventBroadcaster(
+        backbone, channel_factory=FileSystemCommChannel.from_descriptor
+    )
+
+    event = OnCreateConsumer(
+        "test_eventpublisher_prune_unused_consumer",
+        consumer_descriptor,
+        filters=[],
+    )
+
+    # the only registered consumer is in the event, expect no pruning
+    backbone.notification_channels = (consumer_descriptor,)
+
+    publisher.send(event)
+    assert str(consumer_descriptor) in publisher._channel_cache
+    assert len(publisher._channel_cache) == 1
+
+    # add a new descriptor for another event...
+    consumer_descriptor2 = storage_path / "test-consumer-2"
+    # ... and remove the old descriptor from the backbone when it's looked up
+    backbone.notification_channels = (consumer_descriptor2,)
+
+    event = OnCreateConsumer(
+        "test_eventpublisher_prune_unused_consumer", consumer_descriptor2, filters=[]
+    )
+
+    publisher.send(event)
+
+    assert str(consumer_descriptor2) in publisher._channel_cache
+    assert str(consumer_descriptor) not in publisher._channel_cache
+    assert len(publisher._channel_cache) == 1
+
+    # test multi-consumer pruning by caching some extra channels
+    prune0, prune1, prune2 = "abc", "def", "ghi"
+    publisher._channel_cache[prune0] = "doesnt-matter-if-it-is-pruned"
+    publisher._channel_cache[prune1] = "doesnt-matter-if-it-is-pruned"
+    publisher._channel_cache[prune2] = "doesnt-matter-if-it-is-pruned"
+
+    # add in one of our old channels so we prune the above items, send to these
+    backbone.notification_channels = (consumer_descriptor, consumer_descriptor2)
+
+    publisher.send(event)
+
+    assert str(consumer_descriptor2) in publisher._channel_cache
+
+    # NOTE: we should NOT prune something that isn't used by this message but
+    # does appear in `backbone.notification_channels`
+    assert str(consumer_descriptor) in publisher._channel_cache
+
+    # confirm all of our items that were not in the notification channels are gone
+    for pruned in [prune0, prune1, prune2]:
+        assert pruned not in publisher._channel_cache
+
+    # confirm we have only the two expected items in the channel cache
+    assert len(publisher._channel_cache) == 2
+
+
+def test_eventpublisher_serialize_failure(
+    test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that errors during message serialization are raised to the caller.
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param monkeypatch: pytest fixture for modifying behavior of existing code + with mock implementations + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + publisher = EventBroadcaster( + backbone, channel_factory=FileSystemCommChannel.from_descriptor + ) + + with monkeypatch.context() as patch: + event = OnCreateConsumer( + "test_eventpublisher_serialize_failure", target_descriptor, filters=[] + ) + + # patch the __bytes__ implementation to cause pickling to fail during send + def bad_bytes(self) -> bytes: + return b"abc" + + # this patch causes an attribute error when event pickling is attempted + patch.setattr(event, "__bytes__", bad_bytes) + + backbone.notification_channels = (target_descriptor,) + + # send a message into the channel + with pytest.raises(AttributeError) as ex: + publisher.send(event) + + assert "serialize" in ex.value.args[0] + + +def test_eventpublisher_factory_failure( + test_dir: str, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that errors during channel construction are raised to the caller. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param monkeypatch: pytest fixture for modifying behavior of existing code + with mock implementations + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + def boom(descriptor: str) -> None: + raise Exception(f"you shall not pass! {descriptor}") + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + publisher = EventBroadcaster(backbone, channel_factory=boom) + + with monkeypatch.context() as patch: + event = OnCreateConsumer( + "test_eventpublisher_factory_failure", target_descriptor, filters=[] + ) + + backbone.notification_channels = (target_descriptor,) + + # send a message into the channel + with pytest.raises(SmartSimError) as ex: + publisher.send(event) + + assert "construct" in ex.value.args[0] + + +def test_eventpublisher_failure(test_dir: str, monkeypatch: pytest.MonkeyPatch) -> None: + """Verify that unexpected errors during message send are caught and wrapped in a + SmartSimError so they are not propagated directly to the caller. 
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    :param monkeypatch: pytest fixture for modifying behavior of existing code
+    with mock implementations
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    storage_path.mkdir(parents=True, exist_ok=True)
+
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    target_descriptor = str(storage_path / "test-consumer")
+
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+    publisher = EventBroadcaster(
+        backbone, channel_factory=FileSystemCommChannel.from_descriptor
+    )
+
+    def boom(self) -> None:
+        raise Exception("That was unexpected...")
+
+    with monkeypatch.context() as patch:
+        event = OnCreateConsumer(
+            "test_eventpublisher_failure", target_descriptor, filters=[]
+        )
+
+        # patch the _broadcast implementation to cause send to fail
+        # after the event has been pickled
+        patch.setattr(publisher, "_broadcast", boom)
+
+        backbone.notification_channels = (target_descriptor,)
+
+        # the unexpected exception raised by _broadcast must not escape
+        # directly; it should instead be wrapped in a SmartSimError
+        with pytest.raises(SmartSimError) as ex:
+            publisher.send(event)
+
+        assert "unexpected" in ex.value.args[0]
+
+
+def test_eventconsumer_receive(test_dir: str) -> None:
+    """Verify that a consumer retrieves a message from the given channel.
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    storage_path.mkdir(parents=True, exist_ok=True)
+
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    target_descriptor = str(storage_path / "test-consumer")
+
+    backbone = BackboneFeatureStore(mock_storage)
+    comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor)
+    event = OnCreateConsumer(
+        "test_eventconsumer_receive", target_descriptor, filters=[]
+    )
+
+    # simulate a sent event by writing directly to the input comm channel
+    comm_channel.send(bytes(event))
+
+    consumer = EventConsumer(comm_channel, backbone)
+
+    all_received: t.List[OnCreateConsumer] = consumer.recv()
+    assert len(all_received) == 1
+
+    # verify we received the same event that was raised
+    assert all_received[0].category == event.category
+    assert all_received[0].descriptor == event.descriptor
+
+
+@pytest.mark.parametrize("num_sent", [0, 1, 2, 4, 8, 16])
+def test_eventconsumer_receive_multi(test_dir: str, num_sent: int) -> None:
+    """Verify that a consumer retrieves multiple messages from the given channel.
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    :param num_sent: parameterized count of events to enqueue, ensuring the
+    validations are checked at multiple queue sizes
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    storage_path.mkdir(parents=True, exist_ok=True)
+
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    target_descriptor = str(storage_path / "test-consumer")
+
+    backbone = BackboneFeatureStore(mock_storage)
+    comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor)
+
+    # simulate multiple sent events by writing directly to the input comm channel
+    for _ in range(num_sent):
+        event = OnCreateConsumer(
+            "test_eventconsumer_receive_multi", target_descriptor, filters=[]
+        )
+        comm_channel.send(bytes(event))
+
+    consumer = EventConsumer(comm_channel, backbone)
+
+    all_received: t.List[OnCreateConsumer] = consumer.recv()
+    assert len(all_received) == num_sent
+
+
+def test_eventconsumer_receive_empty(test_dir: str) -> None:
+    """Verify that a consumer receiving an empty message ignores the
+    message and continues processing.
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    storage_path.mkdir(parents=True, exist_ok=True)
+
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    target_descriptor = str(storage_path / "test-consumer")
+
+    backbone = BackboneFeatureStore(mock_storage)
+    comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor)
+
+    # simulate a sent event by writing an empty message to the input comm channel
+    comm_channel.send(b"")
+
+    consumer = EventConsumer(comm_channel, backbone)
+
+    messages = consumer.recv()
+
+    # the messages array should be empty
+    assert not messages
+
+
+def test_eventconsumer_eventpublisher_integration(test_dir: str) -> None:
+    """Verify that the publisher and consumer integrate as expected when
+    multiple publishers and consumers are sending simultaneously.
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + mock_fs_descriptor = str(storage_path / f"mock-feature-store") + + wmgr_channel = FileSystemCommChannel(storage_path / "test-wmgr") + capp_channel = FileSystemCommChannel(storage_path / "test-capp") + back_channel = FileSystemCommChannel(storage_path / "test-backend") + + wmgr_consumer_descriptor = wmgr_channel.descriptor + capp_consumer_descriptor = capp_channel.descriptor + back_consumer_descriptor = back_channel.descriptor + + # create some consumers to receive messages + wmgr_consumer = EventConsumer( + wmgr_channel, + backbone, + filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN], + ) + capp_consumer = EventConsumer( + capp_channel, + backbone, + ) + back_consumer = EventConsumer( + back_channel, + backbone, + filters=[OnCreateConsumer.CONSUMER_CREATED], + ) + + # create some broadcasters to publish messages + mock_worker_mgr = EventBroadcaster( + backbone, + channel_factory=FileSystemCommChannel.from_descriptor, + ) + mock_client_app = EventBroadcaster( + backbone, + channel_factory=FileSystemCommChannel.from_descriptor, + ) + + # register all of the consumers even though the OnCreateConsumer really should + # trigger its registration. event processing is tested elsewhere. + backbone.notification_channels = [ + wmgr_consumer_descriptor, + capp_consumer_descriptor, + back_consumer_descriptor, + ] + + # simulate worker manager sending a notification to backend that it's alive + event_1 = OnCreateConsumer( + "test_eventconsumer_eventpublisher_integration", + wmgr_consumer_descriptor, + filters=[], + ) + mock_worker_mgr.send(event_1) + + # simulate the app updating a model a few times + event_2 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-1" + ) + event_3 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-2" + ) + event_4 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-1" + ) + + mock_client_app.send(event_2) + mock_client_app.send(event_3) + mock_client_app.send(event_4) + + # worker manager should only get updates about feature update + wmgr_messages = wmgr_consumer.recv() + assert len(wmgr_messages) == 3 + + # the backend should only receive messages about consumer creation + back_messages = back_consumer.recv() + assert len(back_messages) == 1 + + # hypothetical app has no filters and will get all events + app_messages = capp_consumer.recv() + assert len(app_messages) == 4 + + +@pytest.mark.parametrize("invalid_timeout", [-100.0, -1.0, 0.0]) +def test_eventconsumer_batch_timeout( + invalid_timeout: float, + test_dir: str, +) -> None: + """Verify that a consumer allows only positive, non-zero values for timeout + if it is supplied. 
+
+    :param invalid_timeout: any invalid timeout that should fail validation
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    storage_path.mkdir(parents=True, exist_ok=True)
+
+    mock_storage = {}
+    backbone = BackboneFeatureStore(mock_storage)
+
+    channel = FileSystemCommChannel(storage_path / "test-wmgr")
+
+    with pytest.raises(ValueError) as ex:
+        # attempt a receive with an invalid (non-positive) batch timeout
+        consumer = EventConsumer(
+            channel,
+            backbone,
+            filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+        )
+        consumer.recv(batch_timeout=invalid_timeout)
+
+    assert "positive" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+    "wait_timeout, exp_wait_max",
+    [
+        # aggregate the 1+1+1 into 3 on remaining parameters
+        pytest.param(1, 1 + 1 + 1, id="1s wait, 3 cycle steps"),
+        pytest.param(2, 3 + 2, id="2s wait, 4 cycle steps"),
+        pytest.param(4, 3 + 2 + 4, id="4s wait, 5 cycle steps"),
+        pytest.param(9, 3 + 2 + 4 + 8, id="9s wait, 6 cycle steps"),
+        # aggregate an entire cycle into 16
+        pytest.param(19.5, 16 + 3 + 2 + 4, id="20s wait, repeat cycle"),
+    ],
+)
+def test_backbone_wait_timeout(wait_timeout: float, exp_wait_max: float) -> None:
+    """Verify that attempts to wait for a key that never arrives time out
+    in an appropriate amount of time. Note: due to the backoff, we also verify
+    the elapsed time stays below the expected maximum for the wait cycle.
+
+    :param wait_timeout: Maximum amount of time (in seconds) to allow the backbone
+    to wait for the requested value to exist
+    :param exp_wait_max: Maximum amount of time (in seconds) to set as the upper
+    bound to allow the delays with backoff to occur
+    """
+
+    # NOTE: exp_wait_max maps to the cycled backoff of [0.1, 0.2, 0.4, 0.8]
+    # with leeway added (by allowing 1s each for the 0.1 and 0.5 steps)
+    start_time = time.time()
+
+    storage = {}
+    backbone = BackboneFeatureStore(storage)
+
+    with pytest.raises(SmartSimError) as ex:
+        backbone.wait_for(["does-not-exist"], wait_timeout)
+
+    assert "timeout" in str(ex.value.args[0]).lower()
+
+    end_time = time.time()
+    elapsed = end_time - start_time
+
+    # confirm that we waited at least the configured timeout
+    assert elapsed > wait_timeout, f"below configured timeout {wait_timeout}"
+
+    # confirm that the total wait time is aligned with the sleep cycle
+    assert elapsed < exp_wait_max, f"above expected max wait {exp_wait_max}"
diff --git a/tests/dragon/test_featurestore_integration.py b/tests/dragon/test_featurestore_integration.py
new file mode 100644
index 0000000000..23fdc55ab6
--- /dev/null
+++ b/tests/dragon/test_featurestore_integration.py
@@ -0,0 +1,213 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_util import (
+    DEFAULT_CHANNEL_BUFFER_SIZE,
+    create_local,
+)
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
+from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+
+# isort: off
+from dragon.channels import Channel
+
+# isort: on
+
+if t.TYPE_CHECKING:
+    import conftest
+
+
+# The tests in this file must run in a dragon environment
+pytestmark = pytest.mark.dragon
+
+
+@pytest.fixture(scope="module")
+def the_worker_channel() -> DragonCommChannel:
+    """Fixture that creates a worker comm channel wrapping a local
+    dragon channel so that consumers can attach to it."""
+    wmgr_channel_ = create_local()
+    wmgr_channel = DragonCommChannel(wmgr_channel_)
+    return wmgr_channel
+
+
+@pytest.mark.parametrize(
+    "num_events, batch_timeout, max_batches_expected",
+    [
+        pytest.param(1, 1.0, 2, id="under 1s timeout"),
+        pytest.param(20, 1.0, 3, id="test 1s timeout 20x"),
+        pytest.param(30, 0.2, 5, id="test 0.2s timeout 30x"),
+        pytest.param(60, 0.4, 4, id="small batches"),
+        pytest.param(100, 0.1, 10, id="many small batches"),
+    ],
+)
+def test_eventconsumer_max_dequeue(
+    num_events: int,
+    batch_timeout: float,
+    max_batches_expected: int,
+    the_worker_channel: DragonCommChannel,
+    the_backbone: BackboneFeatureStore,
+) -> None:
+    """Verify that a consumer does not sit and collect messages indefinitely
+    by checking that a consumer returns after a maximum timeout is exceeded.
+
+    :param num_events: Total number of events to raise in the test
+    :param batch_timeout: Maximum wait time (in seconds) for a message to be sent
+    :param max_batches_expected: Maximum number of receives that should occur
+    :param the_worker_channel: The worker comm channel to consume from
+    :param the_backbone: The BackboneFeatureStore to use
+    """
+
+    # create a consumer to receive messages
+    wmgr_consumer = EventConsumer(
+        the_worker_channel,
+        the_backbone,
+        filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+    )
+
+    # create a broadcaster to publish messages
+    mock_client_app = EventBroadcaster(
+        the_backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # register all of the consumers even though the OnCreateConsumer really should
+    # trigger its registration. event processing is tested elsewhere.
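+    # (the setter persists these descriptors as a comma-joined string under the
+    # reserved consumer-notification key; see test_mli_consumers_write_by_backbone)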
+    the_backbone.notification_channels = [the_worker_channel.descriptor]
+
+    # simulate the app updating a model many times
+    for key in (f"key-{i}" for i in range(num_events)):
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_max_dequeue", the_backbone.descriptor, key
+        )
+        mock_client_app.send(event, timeout=0.01)
+
+    num_dequeued = 0
+    num_batches = 0
+
+    while wmgr_messages := wmgr_consumer.recv(
+        timeout=0.1,
+        batch_timeout=batch_timeout,
+    ):
+        # the consumer should not require more than `max_batches_expected` receives
+        num_dequeued += len(wmgr_messages)
+        num_batches += 1
+
+    # make sure we made all the expected dequeue calls and got everything
+    assert num_dequeued == num_events
+    assert num_batches > 0
+    assert num_batches < max_batches_expected, "too many recv calls were made"
+
+
+@pytest.mark.parametrize(
+    "buffer_size",
+    [
+        pytest.param(
+            -1,
+            id="replace negative, default to 500",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            0,
+            id="replace zero, default to 500",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            1,
+            id="non-zero buffer size: 1",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        # pytest.param(500, id="maximum size edge case: 500"),
+        pytest.param(
+            550,
+            id="larger than default: 550",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            800,
+            id="much larger than default: 800",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            1000,
+            id="very large buffer: 1000, unreliable in dragon-v0.10",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+    ],
+)
+def test_channel_buffer_size(
+    buffer_size: int,
+    the_storage: t.Any,
+) -> None:
+    """Verify that a channel used by an EventBroadcaster can buffer messages
+    until a configured maximum value is exceeded.
+
+    :param buffer_size: Maximum number of messages allowed in a channel buffer
+    :param the_storage: The dragon storage engine to use
+    """
+
+    mock_storage = the_storage
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+
+    wmgr_channel_ = create_local(buffer_size)  # <--- vary buffer size
+    wmgr_channel = DragonCommChannel(wmgr_channel_)
+    wmgr_consumer_descriptor = wmgr_channel.descriptor
+
+    # create a broadcaster to publish messages; create no consumers so
+    # the number of sent messages can exceed the allotted buffer size
+    mock_client_app = EventBroadcaster(
+        backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # register the consumer descriptor directly; automatic registration via
+    # OnCreateConsumer is exercised elsewhere
+ backbone.notification_channels = [wmgr_consumer_descriptor] + + if buffer_size < 1: + # NOTE: we set this after creating the channel above to ensure + # the default parameter value was used during instantiation + buffer_size = DEFAULT_CHANNEL_BUFFER_SIZE + + # simulate the app updating a model a lot of times + for key in (f"key-{i}" for i in range(buffer_size)): + event = OnWriteFeatureStore( + "test_channel_buffer_size", backbone.descriptor, key + ) + mock_client_app.send(event, timeout=0.01) + + # adding 1 more over the configured buffer size should report the error + with pytest.raises(Exception) as ex: + mock_client_app.send(event, timeout=0.01) diff --git a/tests/dragon/test_inference_reply.py b/tests/dragon/test_inference_reply.py new file mode 100644 index 0000000000..bdc7be14bc --- /dev/null +++ b/tests/dragon/test_inference_reply.py @@ -0,0 +1,76 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
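+
+# NOTE (illustrative only): the `has_outputs` / `has_output_keys` properties
+# exercised below are assumed to reduce to a simple truthiness check, e.g.:
+#
+#     @property
+#     def has_outputs(self) -> bool:
+#         """True only when a non-empty list of outputs is attached."""
+#         return bool(self.outputs)
+#
+# which is why both `None` and `[]` are parametrized as falsy cases.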
+ +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.storage.feature_store import TensorKey +from smartsim._core.mli.infrastructure.worker.worker import InferenceReply +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +handler = MessageHandler() + + +@pytest.fixture +def inference_reply() -> InferenceReply: + return InferenceReply() + + +@pytest.fixture +def fs_key() -> TensorKey: + return TensorKey("key", "descriptor") + + +@pytest.mark.parametrize( + "outputs, expected", + [ + ([b"output bytes"], True), + (None, False), + ([], False), + ], +) +def test_has_outputs(monkeypatch, inference_reply, outputs, expected): + """Test the has_outputs property with different values for outputs.""" + monkeypatch.setattr(inference_reply, "outputs", outputs) + assert inference_reply.has_outputs == expected + + +@pytest.mark.parametrize( + "output_keys, expected", + [ + ([fs_key], True), + (None, False), + ([], False), + ], +) +def test_has_output_keys(monkeypatch, inference_reply, output_keys, expected): + """Test the has_output_keys property with different values for output_keys.""" + monkeypatch.setattr(inference_reply, "output_keys", output_keys) + assert inference_reply.has_output_keys == expected diff --git a/tests/dragon/test_inference_request.py b/tests/dragon/test_inference_request.py new file mode 100644 index 0000000000..f5c8b9bdc7 --- /dev/null +++ b/tests/dragon/test_inference_request.py @@ -0,0 +1,118 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
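+
+# NOTE (illustrative only): the parametrized cases below build real payloads
+# through `MessageHandler`. Only the two builders used in this file are shown,
+# and the argument meanings (model bytes, name, version / tensor order, dtype,
+# dims) are inferred from usage here rather than from the API docs:
+#
+#     model = MessageHandler.build_model(b"bytes", "Model Name", "V1")
+#     desc = MessageHandler.build_tensor_descriptor("c", "float32", [1, 2, 3])
+#
+# As in the reply tests, each `has_*` property is expected to be truthy only
+# for a populated value, so `None` (and `[]` for list-valued fields) map to False.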
+ +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.storage.feature_store import TensorKey +from smartsim._core.mli.infrastructure.worker.worker import InferenceRequest +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +handler = MessageHandler() + + +@pytest.fixture +def inference_request() -> InferenceRequest: + return InferenceRequest() + + +@pytest.fixture +def fs_key() -> TensorKey: + return TensorKey("key", "descriptor") + + +@pytest.mark.parametrize( + "raw_model, expected", + [ + (handler.build_model(b"bytes", "Model Name", "V1"), True), + (None, False), + ], +) +def test_has_raw_model(monkeypatch, inference_request, raw_model, expected): + """Test the has_raw_model property with different values for raw_model.""" + monkeypatch.setattr(inference_request, "raw_model", raw_model) + assert inference_request.has_raw_model == expected + + +@pytest.mark.parametrize( + "model_key, expected", + [ + (fs_key, True), + (None, False), + ], +) +def test_has_model_key(monkeypatch, inference_request, model_key, expected): + """Test the has_model_key property with different values for model_key.""" + monkeypatch.setattr(inference_request, "model_key", model_key) + assert inference_request.has_model_key == expected + + +@pytest.mark.parametrize( + "raw_inputs, expected", + [([b"raw input bytes"], True), (None, False), ([], False)], +) +def test_has_raw_inputs(monkeypatch, inference_request, raw_inputs, expected): + """Test the has_raw_inputs property with different values for raw_inputs.""" + monkeypatch.setattr(inference_request, "raw_inputs", raw_inputs) + assert inference_request.has_raw_inputs == expected + + +@pytest.mark.parametrize( + "input_keys, expected", + [([fs_key], True), (None, False), ([], False)], +) +def test_has_input_keys(monkeypatch, inference_request, input_keys, expected): + """Test the has_input_keys property with different values for input_keys.""" + monkeypatch.setattr(inference_request, "input_keys", input_keys) + assert inference_request.has_input_keys == expected + + +@pytest.mark.parametrize( + "output_keys, expected", + [([fs_key], True), (None, False), ([], False)], +) +def test_has_output_keys(monkeypatch, inference_request, output_keys, expected): + """Test the has_output_keys property with different values for output_keys.""" + monkeypatch.setattr(inference_request, "output_keys", output_keys) + assert inference_request.has_output_keys == expected + + +@pytest.mark.parametrize( + "input_meta, expected", + [ + ([handler.build_tensor_descriptor("c", "float32", [1, 2, 3])], True), + (None, False), + ([], False), + ], +) +def test_has_input_meta(monkeypatch, inference_request, input_meta, expected): + """Test the has_input_meta property with different values for input_meta.""" + monkeypatch.setattr(inference_request, "input_meta", input_meta) + assert inference_request.has_input_meta == expected diff --git a/tests/dragon/test_protoclient.py b/tests/dragon/test_protoclient.py new file mode 100644 index 0000000000..f84417107d --- /dev/null +++ b/tests/dragon/test_protoclient.py @@ -0,0 +1,313 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import pickle +import time +import typing as t +from unittest.mock import MagicMock + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +# isort: off +from dragon import fli +from dragon.data.ddict.ddict import DDict + +# from ..ex..high_throughput_inference.mock_app import ProtoClient +from smartsim._core.mli.client.protoclient import ProtoClient + + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +WORK_QUEUE_KEY = BackboneFeatureStore.MLI_WORKER_QUEUE +logger = get_logger(__name__) + + +@pytest.fixture(scope="module") +def the_worker_queue(the_backbone: BackboneFeatureStore) -> DragonFLIChannel: + """Fixture that creates a dragon FLI channel as a stand-in for the + worker queue created by the worker. + + :param the_backbone: The backbone feature store to update + with the worker queue descriptor. 
+    :returns: The attached `DragonFLIChannel`
+    """
+
+    # create the FLI
+    to_worker_channel = create_local()
+    fli_ = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None)
+    comm_channel = DragonFLIChannel(fli_)
+
+    # store the descriptor in the backbone
+    the_backbone.worker_queue = comm_channel.descriptor
+
+    try:
+        comm_channel.send(b"foo")
+    except Exception:
+        logger.exception("Test send from worker channel failed")
+
+    return comm_channel
+
+
+@pytest.mark.parametrize(
+    "backbone_timeout, exp_wait_max",
+    [
+        # aggregate the 1+1+1 into 3 on remaining parameters
+        pytest.param(0.5, 1 + 1 + 1, id="0.5s wait, 3 cycle steps"),
+        pytest.param(2, 3 + 2, id="2s wait, 4 cycle steps"),
+        pytest.param(4, 3 + 2 + 4, id="4s wait, 5 cycle steps"),
+    ],
+)
+def test_protoclient_timeout(
+    backbone_timeout: float,
+    exp_wait_max: float,
+    the_backbone: BackboneFeatureStore,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Verify that attempts to attach to the worker queue from the protoclient
+    time out in an appropriate amount of time. Note: due to the backoff, we verify
+    the elapsed time is less than a full cycle of waits.
+
+    :param backbone_timeout: a timeout for use when configuring a proto client
+    :param exp_wait_max: a ceiling for the expected time spent waiting for
+    the timeout
+    :param the_backbone: a pre-initialized backbone featurestore for setting up
+    the environment variable required by the client
+    """
+
+    # NOTE: exp_wait_max maps to the cycled backoff of [0.1, 0.2, 0.4, 0.8]
+    # with leeway added (roughly 1s allowed per early backoff step)
+    start_time = time.time()
+
+    with monkeypatch.context() as ctx, pytest.raises(SmartSimError) as ex:
+        # remove the worker queue value from the backbone if it exists
+        # to ensure the timeout occurs
+        the_backbone.pop(BackboneFeatureStore.MLI_WORKER_QUEUE)
+
+        ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+
+        ProtoClient(timing_on=False, backbone_timeout=backbone_timeout)
+
+    # measure elapsed time outside the raises-block so it runs after the failure
+    elapsed = time.time() - start_time
+    logger.info(f"ProtoClient timeout occurred in {elapsed} seconds")
+
+    # confirm the full configured timeout was exhausted
+    assert (
+        elapsed >= backbone_timeout
+    ), f"below configured timeout {backbone_timeout}"
+
+    # confirm that the total wait time is aligned with the sleep cycle
+    assert elapsed < exp_wait_max, f"above expected max wait {exp_wait_max}"
+
+
+def test_protoclient_initialization_no_backbone(
+    monkeypatch: pytest.MonkeyPatch, the_worker_queue: DragonFLIChannel
+):
+    """Verify that attempting to start the client without the required
+    environment variables results in an exception.
+
+    :param the_worker_queue: the worker queue fixture, ensuring the worker
+    queue environment is correctly configured
+
+    NOTE: os.environ[BackboneFeatureStore.MLI_BACKBONE] is not set"""
+
+    with monkeypatch.context() as patch, pytest.raises(SmartSimError) as ex:
+        patch.setenv(BackboneFeatureStore.MLI_BACKBONE, "")
+
+        ProtoClient(timing_on=False)
+
+    # confirm the missing value error has been raised
+    assert {"backbone", "configuration"}.issubset(set(ex.value.args[0].split(" ")))
+
+
+def test_protoclient_initialization(
+    the_backbone: BackboneFeatureStore,
+    the_worker_queue: DragonFLIChannel,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Verify that attempting to start the client with required env vars results
+    in a fully initialized client.
+
+    :param the_backbone: a pre-initialized backbone featurestore
+    :param the_worker_queue: an FLI channel the client will retrieve
+    from the backbone"""
+
+    with monkeypatch.context() as ctx:
+        ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+        # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+        client = ProtoClient(timing_on=False)
+
+        fs_descriptor = the_backbone.descriptor
+        wq_descriptor = the_worker_queue.descriptor
+
+        # confirm the backbone was attached correctly
+        assert client._backbone is not None
+        assert client._backbone.descriptor == fs_descriptor
+
+        # we expect the backbone to add its descriptor to the local env
+        assert os.environ[BackboneFeatureStore.MLI_BACKBONE] == fs_descriptor
+
+        # confirm the worker queue is created and attached correctly
+        assert client._to_worker_fli is not None
+        assert client._to_worker_fli.descriptor == wq_descriptor
+
+        # we expect the worker queue descriptor to be placed into the backbone;
+        # we do NOT expect _from_worker_ch to be placed anywhere, since it is a
+        # client-specific callback channel
+        assert the_backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] == wq_descriptor
+
+        # confirm the worker channels are created
+        assert client._from_worker_ch is not None
+        assert client._to_worker_ch is not None
+
+        # wrap the channels just to easily verify they produce a descriptor
+        assert DragonCommChannel(client._from_worker_ch).descriptor
+        assert DragonCommChannel(client._to_worker_ch).descriptor
+
+        # confirm a publisher is created
+        assert client._publisher is not None
+
+
+def test_protoclient_write_model(
+    the_backbone: BackboneFeatureStore,
+    the_worker_queue: DragonFLIChannel,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Verify that writing a model using the client causes the model data to be
+    written to a feature store.
+
+    :param the_backbone: a pre-initialized backbone featurestore
+    :param the_worker_queue: the worker queue fixture, ensuring the worker
+    queue descriptor is registered in the backbone
+    """
+
+    with monkeypatch.context() as ctx:
+        ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+        # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+        client = ProtoClient(timing_on=False)
+
+        model_key = "my-model"
+        model_bytes = b"12345"
+
+        client.set_model(model_key, model_bytes)
+
+        # confirm the client modified the underlying feature store
+        assert client._backbone[model_key] == model_bytes
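+
+# NOTE (illustrative only): taken together, the write-model tests around this
+# point assume `set_model` behaves roughly like the sketch below (hypothetical
+# internals; only the observable effects are asserted by the tests):
+#
+#     def set_model(self, key: str, model: bytes) -> None:
+#         self._backbone[key] = model  # write-through to the feature store
+#         event = OnWriteFeatureStore(self._name, self._backbone.descriptor, key)
+#         # the publisher broadcasts to every registered listener channel
+#         self._publisher.send(event, timeout=0.001)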
+
+
+@pytest.mark.parametrize(
+    "num_listeners, num_model_updates",
+    [(1, 1), (1, 4), (2, 4), (16, 4), (64, 8)],
+)
+def test_protoclient_write_model_notification_sent(
+    the_backbone: BackboneFeatureStore,
+    the_worker_queue: DragonFLIChannel,
+    monkeypatch: pytest.MonkeyPatch,
+    num_listeners: int,
+    num_model_updates: int,
+):
+    """Verify that writing a model sends a key-written event.
+
+    :param the_backbone: a pre-initialized backbone featurestore
+    :param the_worker_queue: an FLI channel the client will retrieve
+    from the backbone
+    :param num_listeners: vary the number of registered listeners
+    to verify that the event is broadcast to everyone
+    :param num_model_updates: vary the number of model writes to verify
+    the broadcast counts messages sent correctly
+    """
+
+    # sends are mocked below; note that the broadcaster will not attempt
+    # a send at all without registered listeners
+    listeners = [f"mock-ch-desc-{i}" for i in range(num_listeners)]
+
+    the_backbone[BackboneFeatureStore.MLI_BACKBONE] = the_backbone.descriptor
+    the_backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] = the_worker_queue.descriptor
+    the_backbone[BackboneFeatureStore.MLI_NOTIFY_CONSUMERS] = ",".join(listeners)
+    the_backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = None
+
+    with monkeypatch.context() as ctx:
+        ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+        # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+        client = ProtoClient(timing_on=False)
+
+        publisher = t.cast(EventBroadcaster, client._publisher)
+
+        # mock attaching to a channel given the mock-ch-desc in backbone
+        mock_send = MagicMock(return_value=None)
+        mock_comm_channel = MagicMock(**{"send": mock_send}, spec=DragonCommChannel)
+        mock_get_comm_channel = MagicMock(return_value=mock_comm_channel)
+        ctx.setattr(publisher, "_get_comm_channel", mock_get_comm_channel)
+
+        model_key = "my-model"
+        model_bytes = b"12345"
+
+        for i in range(num_model_updates):
+            client.set_model(model_key, model_bytes)
+
+        # confirm that a listener channel was attached
+        # once for each registered listener in backbone
+        assert mock_get_comm_channel.call_count == num_listeners * num_model_updates
+
+        # confirm the client raised the key-written event
+        assert (
+            mock_send.call_count == num_listeners * num_model_updates
+        ), f"Expected {num_listeners * num_model_updates} sends with {num_listeners} registrations"
+
+        # with at least 1 consumer registered, we can verify the message is sent
+        for call_args in mock_send.call_args_list:
+            send_args = call_args.args
+            event_bytes, timeout = send_args[0], send_args[1]
+
+            assert event_bytes, "Expected event bytes to be supplied to send"
+            assert (
+                timeout == 0.001
+            ), "Expected default timeout on call to `publisher.send`"
+
+            # confirm the correct event was raised
+            event = t.cast(
+                OnWriteFeatureStore,
+                pickle.loads(event_bytes),
+            )
+            assert event.descriptor == the_backbone.descriptor
+            assert event.key == model_key
diff --git a/tests/dragon/test_reply_building.py b/tests/dragon/test_reply_building.py
new file mode 100644
index 0000000000..48493b3c4d
--- /dev/null
+++ b/tests/dragon/test_reply_building.py
@@ -0,0 +1,64 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.control.worker_manager import build_failure_reply + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.mark.parametrize( + "status, message", + [ + pytest.param("timeout", "Worker timed out", id="timeout"), + pytest.param("fail", "Failed while executing", id="fail"), + ], +) +def test_build_failure_reply(status: "Status", message: str): + "Ensures failure replies can be built successfully" + response = build_failure_reply(status, message) + display_name = response.schema.node.displayName # type: ignore + class_name = display_name.split(":")[-1] + assert class_name == "Response" + assert response.status == status + assert response.message == message + + +def test_build_failure_reply_fails(): + "Ensures ValueError is raised if a Status Enum is not used" + with pytest.raises(ValueError) as ex: + response = build_failure_reply("not a status enum", "message") + + assert "Error assigning status to response" in ex.value.args[0] diff --git a/tests/dragon/test_request_dispatcher.py b/tests/dragon/test_request_dispatcher.py new file mode 100644 index 0000000000..70d73e243f --- /dev/null +++ b/tests/dragon/test_request_dispatcher.py @@ -0,0 +1,233 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import gc +import os +import subprocess as sp +import time +import typing as t +from queue import Empty + +import numpy as np +import pytest + +from . import conftest +from .utils import msg_pump + +pytest.importorskip("dragon") + + +# isort: off +import dragon +import multiprocessing as mp + +import torch + +# isort: on + +from dragon import fli +from dragon.data.ddict.ddict import DDict +from dragon.managed_memory import MemoryAlloc + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.control.request_dispatcher import ( + RequestBatch, + RequestDispatcher, +) +from smartsim._core.mli.infrastructure.control.worker_manager import ( + EnvironmentConfigLoader, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim.log import get_logger + +logger = get_logger(__name__) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +try: + mp.set_start_method("dragon") +except Exception: + pass + + +@pytest.mark.parametrize("num_iterations", [4]) +def test_request_dispatcher( + num_iterations: int, + the_storage: DDict, + test_dir: str, +) -> None: + """Test the request dispatcher batching and queueing system + + This also includes setting a queue to disposable, checking that it is no + longer referenced by the dispatcher. 
+ """ + + to_worker_channel = create_local() + to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli) + + backbone_fs = BackboneFeatureStore(the_storage, allow_reserved_writes=True) + + # NOTE: env vars should be set prior to instantiating EnvironmentConfigLoader + # or test environment may be unable to send messages w/queue + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone_fs.descriptor + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=1000, + batch_size=2, + config_loader=config_loader, + worker_type=TorchWorker, + mem_pool_size=2 * 1024**2, + ) + + worker_queue = config_loader.get_queue() + if worker_queue is None: + logger.warning( + "FLI input queue not loaded correctly from config_loader: " + f"{config_loader._queue_descriptor}" + ) + + request_dispatcher._on_start() + + # put some messages into the work queue for the dispatcher to pickup + channels = [] + processes = [] + for i in range(num_iterations): + batch: t.Optional[RequestBatch] = None + mem_allocs = [] + tensors = [] + + # NOTE: creating callbacks in test to avoid a local channel being torn + # down when mock_messages terms but before the final response message is sent + + callback_channel = DragonCommChannel.from_local() + channels.append(callback_channel) + + process = conftest.function_as_dragon_proc( + msg_pump.mock_messages, + [ + worker_queue.descriptor, + backbone_fs.descriptor, + i, + callback_channel.descriptor, + ], + [], + [], + ) + processes.append(process) + process.start() + assert process.returncode is None, "The message pump failed to start" + + # give dragon some time to populate the message queues + for i in range(15): + try: + request_dispatcher._on_iteration() + batch = request_dispatcher.task_queue.get(timeout=1.0) + break + except Empty: + time.sleep(2) + logger.warning(f"Task queue is empty on iteration {i}") + continue + except Exception as exc: + logger.error(f"Task queue exception on iteration {i}") + raise exc + + assert batch is not None + assert batch.has_valid_requests + + model_key = batch.model_id.key + + try: + transform_result = batch.inputs + for transformed, dims, dtype in zip( + transform_result.transformed, + transform_result.dims, + transform_result.dtypes, + ): + mem_alloc = MemoryAlloc.attach(transformed) + mem_allocs.append(mem_alloc) + itemsize = np.empty((1), dtype=dtype).itemsize + tensors.append( + torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[0 : np.prod(dims) * itemsize], + dtype=dtype, + ).reshape(dims) + ) + ) + + assert len(batch.requests) == 2 + assert batch.model_id.key == model_key + assert model_key in request_dispatcher._queues + assert model_key in request_dispatcher._active_queues + assert len(request_dispatcher._queues[model_key]) == 1 + assert request_dispatcher._queues[model_key][0].empty() + assert request_dispatcher._queues[model_key][0].model_id.key == model_key + assert len(tensors) == 1 + assert tensors[0].shape == torch.Size([2, 2]) + + for tensor in tensors: + for sample_idx in range(tensor.shape[0]): + tensor_in = tensor[sample_idx] + tensor_out = (sample_idx + 1) * torch.ones( + (2,), dtype=torch.float32 + ) + assert torch.equal(tensor_in, tensor_out) 
+ + except Exception as exc: + raise exc + finally: + for mem_alloc in mem_allocs: + mem_alloc.free() + + request_dispatcher._active_queues[model_key].make_disposable() + assert request_dispatcher._active_queues[model_key].can_be_removed + + request_dispatcher._on_iteration() + + assert model_key not in request_dispatcher._active_queues + assert model_key not in request_dispatcher._queues + + # Try to remove the dispatcher and free the memory + del request_dispatcher + gc.collect() diff --git a/tests/dragon/test_torch_worker.py b/tests/dragon/test_torch_worker.py new file mode 100644 index 0000000000..2a9e7d01bd --- /dev/null +++ b/tests/dragon/test_torch_worker.py @@ -0,0 +1,221 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
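+
+# NOTE (illustrative only): the tests in this file walk the TorchWorker stages
+# in the order a worker manager would drive them. Condensed to just the calls
+# exercised below (device handling simplified to "cpu"):
+#
+#     worker = TorchWorker()
+#     loaded = worker.load_model(batch, FetchModelResult(raw_model), "cpu")
+#     inputs = worker.transform_input(batch, [fetch_input_result], mem_pool)
+#     result = worker.execute(batch, loaded, inputs, "cpu")
+#     replies = worker.transform_output(batch, result)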
+ +import io +import typing as t + +import numpy as np +import pytest +import torch + +dragon = pytest.importorskip("dragon") +import dragon.globalservices.pool as dragon_gs_pool +from dragon.managed_memory import MemoryAlloc, MemoryPool +from torch import nn +from torch.nn import functional as F + +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + RequestBatch, + TransformInputResult, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger(__name__) +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +# simple MNIST in PyTorch +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x, y): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +torch_device = {"cpu": "cpu", "gpu": "cuda"} + + +def get_batch() -> torch.Tensor: + return torch.rand(20, 1, 28, 28) + + +def create_torch_model(): + n = Net() + example_forward_input = get_batch() + module = torch.jit.trace(n, [example_forward_input, example_forward_input]) + model_buffer = io.BytesIO() + torch.jit.save(module, model_buffer) + return model_buffer.getvalue() + + +def get_request() -> InferenceRequest: + + tensors = [get_batch() for _ in range(2)] + tensor_numpy = [tensor.numpy() for tensor in tensors] + serialized_tensors_descriptors = [ + MessageHandler.build_tensor_descriptor("c", "float32", list(tensor.shape)) + for tensor in tensors + ] + + return InferenceRequest( + model_key=ModelKey(key="model", descriptor="xyz"), + callback=None, + raw_inputs=tensor_numpy, + input_keys=None, + input_meta=serialized_tensors_descriptors, + output_keys=None, + raw_model=create_torch_model(), + batch_size=0, + ) + + +def get_request_batch_from_request( + request: InferenceRequest, inputs: t.Optional[TransformInputResult] = None +) -> RequestBatch: + + return RequestBatch([request], inputs, request.model_key) + + +sample_request: InferenceRequest = get_request() +sample_request_batch: RequestBatch = get_request_batch_from_request(sample_request) +worker = TorchWorker() + + +def test_load_model(mlutils) -> None: + fetch_model_result = FetchModelResult(sample_request.raw_model) + load_model_result = worker.load_model( + sample_request_batch, fetch_model_result, mlutils.get_test_device().lower() + ) + + assert load_model_result.model( + get_batch().to(torch_device[mlutils.get_test_device().lower()]), + get_batch().to(torch_device[mlutils.get_test_device().lower()]), + ).shape == torch.Size((20, 10)) + + +def test_transform_input(mlutils) -> None: + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_input_result = worker.transform_input( + sample_request_batch, [fetch_input_result], 
mem_pool + ) + + batch = get_batch().numpy() + assert transform_input_result.slices[0] == slice(0, batch.shape[0]) + + for tensor_index in range(2): + assert torch.Size(transform_input_result.dims[tensor_index]) == batch.shape + assert transform_input_result.dtypes[tensor_index] == str(batch.dtype) + mem_alloc = MemoryAlloc.attach(transform_input_result.transformed[tensor_index]) + itemsize = batch.itemsize + tensor = torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[ + 0 : np.prod(transform_input_result.dims[tensor_index]) * itemsize + ], + dtype=transform_input_result.dtypes[tensor_index], + ).reshape(transform_input_result.dims[tensor_index]) + ) + + assert torch.equal( + tensor, torch.from_numpy(sample_request.raw_inputs[tensor_index]) + ) + + mem_pool.destroy() + + +def test_execute(mlutils) -> None: + load_model_result = LoadModelResult( + Net().to(torch_device[mlutils.get_test_device().lower()]) + ) + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + request_batch = get_request_batch_from_request(sample_request, fetch_input_result) + + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_result = worker.transform_input( + request_batch, [fetch_input_result], mem_pool + ) + + execute_result = worker.execute( + request_batch, + load_model_result, + transform_result, + mlutils.get_test_device().lower(), + ) + + assert all( + result.shape == torch.Size((20, 10)) for result in execute_result.predictions + ) + + mem_pool.destroy() + + +def test_transform_output(mlutils): + tensors = [torch.rand((20, 10)) for _ in range(2)] + execute_result = ExecuteResult(tensors, [slice(0, 20)]) + + transformed_output = worker.transform_output(sample_request_batch, execute_result) + + assert transformed_output[0].outputs == [item.numpy().tobytes() for item in tensors] + assert transformed_output[0].shape == None + assert transformed_output[0].order == "c" + assert transformed_output[0].dtype == "float32" diff --git a/tests/dragon/test_worker_manager.py b/tests/dragon/test_worker_manager.py new file mode 100644 index 0000000000..4047a731fc --- /dev/null +++ b/tests/dragon/test_worker_manager.py @@ -0,0 +1,314 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import io +import logging +import pathlib +import time + +import pytest + +torch = pytest.importorskip("torch") +dragon = pytest.importorskip("dragon") + +import multiprocessing as mp + +try: + mp.set_start_method("dragon") +except Exception: + pass + +import os + +import torch.nn as nn +from dragon import fli + +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.control.worker_manager import ( + EnvironmentConfigLoader, + WorkerManager, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_util import create_ddict +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +from .utils.channel import FileSystemCommChannel + +logger = get_logger(__name__) +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +class MiniModel(nn.Module): + """A torch model that can be executed by the default torch worker""" + + def __init__(self): + """Initialize the model.""" + super().__init__() + + self._name = "mini-model" + self._net = torch.nn.Linear(2, 1) + + def forward(self, input): + """Execute a forward pass.""" + return self._net(input) + + @property + def bytes(self) -> bytes: + """Retrieve the serialized model + + :returns: The byte stream of the model file + """ + buffer = io.BytesIO() + scripted = torch.jit.trace(self._net, self.get_batch()) + torch.jit.save(scripted, buffer) + return buffer.getvalue() + + @classmethod + def get_batch(cls) -> "torch.Tensor": + """Generate a single batch of data with the correct + shape for inference. + + :returns: The batch as a torch tensor + """ + return torch.randn((100, 2), dtype=torch.float32) + + +def create_model(model_path: pathlib.Path) -> pathlib.Path: + """Create a simple torch model and persist to disk for + testing purposes. + + :param model_path: The path to the torch model file + """ + if not model_path.parent.exists(): + model_path.parent.mkdir(parents=True, exist_ok=True) + + model_path.unlink(missing_ok=True) + + mini_model = MiniModel() + torch.save(mini_model, model_path) + + return model_path + + +def load_model() -> bytes: + """Create a simple torch model in memory for testing.""" + mini_model = MiniModel() + return mini_model.bytes + + +def mock_messages( + feature_store_root_dir: pathlib.Path, + comm_channel_root_dir: pathlib.Path, + kill_queue: mp.Queue, +) -> None: + """Mock event producer for triggering the inference pipeline. 
+
+    :param feature_store_root_dir: Path to a directory where a
+    FileSystemFeatureStore can read & write results
+    :param comm_channel_root_dir: Path to a directory where a
+    FileSystemCommChannel can read & write messages
+    :param kill_queue: Queue used by the unit test to stop the mock_messages process
+    """
+    feature_store_root_dir.mkdir(parents=True, exist_ok=True)
+    comm_channel_root_dir.mkdir(parents=True, exist_ok=True)
+
+    iteration_number = 0
+
+    config_loader = EnvironmentConfigLoader(
+        featurestore_factory=DragonFeatureStore.from_descriptor,
+        callback_factory=FileSystemCommChannel.from_descriptor,
+        queue_factory=DragonFLIChannel.from_descriptor,
+    )
+    backbone = config_loader.get_backbone()
+
+    worker_queue = config_loader.get_queue()
+    if worker_queue is None:
+        queue_desc = config_loader._queue_descriptor
+        logger.warning(
+            f"FLI input queue not loaded correctly from config_loader: {queue_desc}"
+        )
+
+    model_key = "mini-model"
+    model_bytes = load_model()
+    backbone[model_key] = model_bytes
+
+    while True:
+        if not kill_queue.empty():
+            return
+        iteration_number += 1
+        time.sleep(1)
+
+        channel_key = comm_channel_root_dir / f"{iteration_number}/channel.txt"
+        callback_channel = FileSystemCommChannel(pathlib.Path(channel_key))
+
+        batch = MiniModel.get_batch()
+        shape = batch.shape
+        batch_bytes = batch.numpy().tobytes()
+
+        logger.debug(f"Model content: {backbone[model_key][:20]}")
+
+        input_descriptor = MessageHandler.build_tensor_descriptor(
+            "f", "float32", list(shape)
+        )
+
+        # The first message is always the request metadata; the tensor
+        # payload is streamed immediately afterward
+        request = MessageHandler.build_request(
+            reply_channel=callback_channel.descriptor,
+            model=MessageHandler.build_model(model_bytes, "mini-model", "1.0"),
+            inputs=[input_descriptor],
+            outputs=[],
+            output_descriptors=[],
+            custom_attributes=None,
+        )
+        request_bytes = MessageHandler.serialize_request(request)
+
+        # avoid shadowing the imported `fli` module with the channel handle
+        worker_fli: DragonFLIChannel = worker_queue
+
+        with worker_fli._fli.sendh(
+            timeout=None, stream_channel=worker_fli._channel
+        ) as sendh:
+            sendh.send_bytes(request_bytes)
+            sendh.send_bytes(batch_bytes)
+
+        logger.info("published message")
+
+        if iteration_number > 5:
+            return
+
+
+def mock_mli_infrastructure_mgr() -> None:
+    """Create resources normally instantiated by the infrastructure
+    management portion of the DragonBackend.
+    """
+    config_loader = EnvironmentConfigLoader(
+        featurestore_factory=DragonFeatureStore.from_descriptor,
+        callback_factory=FileSystemCommChannel.from_descriptor,
+        queue_factory=DragonFLIChannel.from_descriptor,
+    )
+
+    integrated_worker = TorchWorker
+
+    worker_manager = WorkerManager(
+        config_loader,
+        integrated_worker,
+        as_service=True,
+        cooldown=10,
+        device="cpu",
+        dispatcher_queue=mp.Queue(maxsize=0),
+    )
+    worker_manager.execute()
+
+
+@pytest.fixture
+def prepare_environment(test_dir: str) -> pathlib.Path:
+    """Clean up prior outputs so the test can be run repeatedly.
+
+    :param test_dir: the directory to prepare
+    :returns: The path to the log file
+    """
+    path = pathlib.Path(f"{test_dir}/workermanager.log")
+    logging.basicConfig(filename=path.absolute(), level=logging.DEBUG)
+    return path
+
+
+def test_worker_manager(prepare_environment: pathlib.Path) -> None:
+    """Test the worker manager.
+ + :param prepare_environment: Pass this fixture to configure + global resources before the worker manager executes + """ + + test_path = prepare_environment + fs_path = test_path / "feature_store" + comm_path = test_path / "comm_store" + + mgr_per_node = 1 + num_nodes = 2 + mem_per_node = 128 * 1024**2 + + storage = create_ddict(num_nodes, mgr_per_node, mem_per_node) + backbone = BackboneFeatureStore(storage, allow_reserved_writes=True) + + to_worker_channel = create_local() + to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + + to_worker_fli_comm_channel = DragonFLIChannel(to_worker_fli) + + # NOTE: env vars must be set prior to instantiating EnvironmentConfigLoader + # or test environment may be unable to send messages w/queue + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = ( + to_worker_fli_comm_channel.descriptor + ) + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + integrated_worker_type = TorchWorker + + worker_manager = WorkerManager( + config_loader, + integrated_worker_type, + as_service=True, + cooldown=5, + device="cpu", + dispatcher_queue=mp.Queue(maxsize=0), + ) + + worker_queue = config_loader.get_queue() + if worker_queue is None: + logger.warn( + f"FLI input queue not loaded correctly from config_loader: {config_loader._queue_descriptor}" + ) + backbone.worker_queue = to_worker_fli_comm_channel.descriptor + + # create a mock client application to populate the request queue + kill_queue = mp.Queue() + msg_pump = mp.Process( + target=mock_messages, + args=(fs_path, comm_path, kill_queue), + ) + msg_pump.start() + + # create a process to execute commands + process = mp.Process(target=mock_mli_infrastructure_mgr) + + # let it send some messages before starting the worker manager + msg_pump.join(timeout=5) + process.start() + msg_pump.join(timeout=5) + kill_queue.put_nowait("kill!") + process.join(timeout=5) + msg_pump.kill() + process.kill() diff --git a/tests/dragon/utils/__init__.py b/tests/dragon/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py new file mode 100644 index 0000000000..4c46359c2d --- /dev/null +++ b/tests/dragon/utils/channel.py @@ -0,0 +1,125 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import pathlib
+import threading
+import typing as t
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class FileSystemCommChannel(CommChannelBase):
+    """Passes messages by writing to a file"""
+
+    def __init__(self, key: pathlib.Path) -> None:
+        """Initialize the FileSystemCommChannel instance.
+
+        :param key: path to the file backing this channel
+        """
+        self._lock = threading.RLock()
+
+        super().__init__(key.as_posix())
+        self._file_path = key
+
+        if not self._file_path.parent.exists():
+            self._file_path.parent.mkdir(parents=True)
+
+        self._file_path.touch()
+
+    def send(self, value: bytes, timeout: float = 0) -> None:
+        """Send a message through the underlying communication channel.
+
+        :param value: The value to send
+        :param timeout: maximum time to wait (in seconds) for messages to send
+        """
+        with self._lock:
+            # write as text so we can add newlines as delimiters
+            with open(self._file_path, "a") as fp:
+                encoded_value = base64.b64encode(value).decode("utf-8")
+                fp.write(f"{encoded_value}\n")
+                logger.debug(f"FileSystemCommChannel {self._file_path} sent message")
+
+    def recv(self, timeout: float = 0) -> t.List[bytes]:
+        """Receives message(s) through the underlying communication channel.
+
+        :param timeout: maximum time to wait (in seconds) for messages to arrive
+        :returns: a list containing the next message, if one exists
+        :raises SmartSimError: if the descriptor points to a missing file
+        """
+        with self._lock:
+            messages: t.List[bytes] = []
+            if not self._file_path.exists():
+                raise SmartSimError("Empty channel")
+
+            # read as text so we can split on newlines
+            with open(self._file_path, "r") as fp:
+                lines = fp.readlines()
+
+            if lines:
+                line = lines.pop(0)
+                event_bytes = base64.b64decode(line.encode("utf-8"))
+                messages.append(event_bytes)
+
+                self.clear()
+
+                # remove the first message only, write any remainder back...
+                if len(lines) > 0:
+                    with open(self._file_path, "w") as fp:
+                        fp.writelines(lines)
+
+                logger.debug(
+                    f"FileSystemCommChannel {self._file_path} received message"
+                )
+
+            return messages
+
+    def clear(self) -> None:
+        """Create an empty file for events."""
+        if self._file_path.exists():
+            self._file_path.unlink()
+            self._file_path.touch()
+
+    @classmethod
+    def from_descriptor(
+        cls,
+        descriptor: str,
+    ) -> "FileSystemCommChannel":
+        """A factory method that creates an instance from a descriptor string.
+ + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemCommChannel + """ + try: + path = pathlib.Path(descriptor) + return FileSystemCommChannel(path) + except: + logger.warning(f"failed to create fs comm channel: {descriptor}") + raise diff --git a/tests/dragon/utils/msg_pump.py b/tests/dragon/utils/msg_pump.py new file mode 100644 index 0000000000..8d69e57c63 --- /dev/null +++ b/tests/dragon/utils/msg_pump.py @@ -0,0 +1,225 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import io +import logging +import pathlib +import sys +import time +import typing as t + +import pytest + +pytest.importorskip("torch") +pytest.importorskip("dragon") + + +# isort: off +import dragon +import multiprocessing as mp +import torch +import torch.nn as nn + +# isort: on + +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger(__name__, log_level=logging.DEBUG) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +try: + mp.set_start_method("dragon") +except Exception: + pass + + +class MiniModel(nn.Module): + def __init__(self): + super().__init__() + + self._name = "mini-model" + self._net = torch.nn.Linear(2, 1) + + def forward(self, input): + return self._net(input) + + @property + def bytes(self) -> bytes: + """Returns the model serialized to a byte stream""" + buffer = io.BytesIO() + scripted = torch.jit.trace(self._net, self.get_batch()) + torch.jit.save(scripted, buffer) + return buffer.getvalue() + + @classmethod + def get_batch(cls) -> "torch.Tensor": + return torch.randn((100, 2), dtype=torch.float32) + + +def load_model() -> bytes: + """Create a simple torch model in memory for testing""" + mini_model = MiniModel() + return mini_model.bytes + + +def persist_model_file(model_path: pathlib.Path) -> pathlib.Path: + """Create a simple torch model and persist to disk for + testing purposes. 
+ + :returns: Path to the model file + """ + # test_path = pathlib.Path(work_dir) + if not model_path.parent.exists(): + model_path.parent.mkdir(parents=True, exist_ok=True) + + model_path.unlink(missing_ok=True) + + model = torch.nn.Linear(2, 1) + torch.save(model, model_path) + + return model_path + + +def _mock_messages( + dispatch_fli_descriptor: str, + fs_descriptor: str, + parent_iteration: int, + callback_descriptor: str, +) -> None: + """Mock event producer for triggering the inference pipeline.""" + model_key = "mini-model" + # mock_message sends 2 messages, so we offset by 2 * (# of iterations in caller) + offset = 2 * parent_iteration + + feature_store = BackboneFeatureStore.from_descriptor(fs_descriptor) + request_dispatcher_queue = DragonFLIChannel.from_descriptor(dispatch_fli_descriptor) + + feature_store[model_key] = load_model() + + for iteration_number in range(2): + logged_iteration = offset + iteration_number + logger.debug(f"Sending mock message {logged_iteration}") + + output_key = f"output-{iteration_number}" + + tensor = ( + (iteration_number + 1) * torch.ones((1, 2), dtype=torch.float32) + ).numpy() + fsd = feature_store.descriptor + + tensor_desc = MessageHandler.build_tensor_descriptor( + "c", "float32", list(tensor.shape) + ) + + message_tensor_output_key = MessageHandler.build_tensor_key(output_key, fsd) + message_model_key = MessageHandler.build_model_key(model_key, fsd) + + request = MessageHandler.build_request( + reply_channel=callback_descriptor, + model=message_model_key, + inputs=[tensor_desc], + outputs=[message_tensor_output_key], + output_descriptors=[], + custom_attributes=None, + ) + + logger.info(f"Sending request {iteration_number} to request_dispatcher_queue") + request_bytes = MessageHandler.serialize_request(request) + + logger.info("Sending msg_envelope") + + # cuid = request_dispatcher_queue._channel.cuid + # logger.info(f"\tInternal cuid: {cuid}") + + # send the header & body together so they arrive together + try: + request_dispatcher_queue.send_multiple([request_bytes, tensor.tobytes()]) + logger.info(f"\tenvelope 0: {request_bytes[:5]}...") + logger.info(f"\tenvelope 1: {tensor.tobytes()[:5]}...") + except Exception as ex: + logger.exception("Unable to send request envelope") + + logger.info("All messages sent") + + # keep the process alive for an extra 15 seconds to let the processor + # have access to the channels before they're destroyed + for _ in range(15): + time.sleep(1) + + +def mock_messages( + dispatch_fli_descriptor: str, + fs_descriptor: str, + parent_iteration: int, + callback_descriptor: str, +) -> int: + """Mock event producer for triggering the inference pipeline. 
Used
+    when started via multiprocessing."""
+    logger.info(f"{dispatch_fli_descriptor=}")
+    logger.info(f"{fs_descriptor=}")
+    logger.info(f"{parent_iteration=}")
+    logger.info(f"{callback_descriptor=}")
+
+    try:
+        _mock_messages(
+            dispatch_fli_descriptor,
+            fs_descriptor,
+            parent_iteration,
+            callback_descriptor,
+        )
+    except Exception:
+        logger.exception("The mock message pump failed")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--dispatch-fli-descriptor", type=str)
+    parser.add_argument("--fs-descriptor", type=str)
+    parser.add_argument("--parent-iteration", type=int)
+    parser.add_argument("--callback-descriptor", type=str)
+
+    args = parser.parse_args()
+
+    return_code = mock_messages(
+        args.dispatch_fli_descriptor,
+        args.fs_descriptor,
+        args.parent_iteration,
+        args.callback_descriptor,
+    )
+    sys.exit(return_code)
diff --git a/tests/dragon/utils/worker.py b/tests/dragon/utils/worker.py
new file mode 100644
index 0000000000..0582cae566
--- /dev/null
+++ b/tests/dragon/utils/worker.py
@@ -0,0 +1,104 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
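+
+# NOTE: a minimal usage sketch for this worker (illustrative only; the real
+# call sequence is owned by the worker manager, and the variable names below
+# are assumptions, not a public API):
+#
+#     worker = IntegratedTorchWorker
+#     load_result = worker.load_model(request, fetch_result, device)
+#     transformed = worker.transform_input(request, fetch_result, device)
+#     execute_result = worker.execute(request, load_result, transformed)
+#     output = worker.transform_output(request, execute_result, device)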
+
+import io
+import typing as t
+
+import torch
+
+import smartsim._core.mli.infrastructure.worker.worker as mliw
+import smartsim.error as sse
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class IntegratedTorchWorker(mliw.MachineLearningWorkerBase):
+    """A minimum implementation of a worker that executes a PyTorch model"""
+
+    # @staticmethod
+    # def deserialize(request: InferenceRequest) -> t.List[t.Any]:
+    #     # request.input_meta
+    #     # request.raw_inputs
+    #     return request
+
+    @staticmethod
+    def load_model(
+        request: mliw.InferenceRequest, fetch_result: mliw.FetchModelResult, device: str
+    ) -> mliw.LoadModelResult:
+        model_bytes = fetch_result.model_bytes or request.raw_model
+        if not model_bytes:
+            raise ValueError("Unable to load model without reference object")
+
+        model: torch.nn.Module = torch.load(io.BytesIO(model_bytes))
+        result = mliw.LoadModelResult(model)
+        return result
+
+    @staticmethod
+    def transform_input(
+        request: mliw.InferenceRequest,
+        fetch_result: mliw.FetchInputResult,
+        device: str,
+    ) -> mliw.TransformInputResult:
+        # extra metadata for assembly can be found in request.input_meta
+        raw_inputs = request.raw_inputs or fetch_result.inputs
+
+        result: t.List[torch.Tensor] = []
+        # should this happen here?
+        # consider - fortran to c data layout
+        # is there an intermediate representation before really doing torch.load?
+        if raw_inputs:
+            result = [torch.load(io.BytesIO(item)) for item in raw_inputs]
+
+        return mliw.TransformInputResult(result)
+
+    @staticmethod
+    def execute(
+        request: mliw.InferenceRequest,
+        load_result: mliw.LoadModelResult,
+        transform_result: mliw.TransformInputResult,
+    ) -> mliw.ExecuteResult:
+        if not load_result.model:
+            raise sse.SmartSimError("Model must be loaded to execute")
+
+        model = load_result.model
+        results = [model(tensor) for tensor in transform_result.transformed]
+
+        execute_result = mliw.ExecuteResult(results)
+        return execute_result
+
+    @staticmethod
+    def transform_output(
+        request: mliw.InferenceRequest,
+        execute_result: mliw.ExecuteResult,
+        result_device: str,
+    ) -> mliw.TransformOutputResult:
+        # send the original tensors...
+        execute_result.predictions = [
+            tensor.detach() for tensor in execute_result.predictions
+        ]
+        # todo: solve sending all tensor metadata that coincides with each prediction
+        return mliw.TransformOutputResult(
+            execute_result.predictions, [1], "c", "float32"
+        )
diff --git a/tests/mli/__init__.py b/tests/mli/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/mli/channel.py b/tests/mli/channel.py
new file mode 100644
index 0000000000..4c46359c2d
--- /dev/null
+++ b/tests/mli/channel.py
@@ -0,0 +1,125 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import pathlib
+import threading
+import typing as t
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class FileSystemCommChannel(CommChannelBase):
+    """Passes messages by writing to a file"""
+
+    def __init__(self, key: pathlib.Path) -> None:
+        """Initialize the FileSystemCommChannel instance.
+
+        :param key: a path to the root directory of the feature store
+        """
+        self._lock = threading.RLock()
+
+        super().__init__(key.as_posix())
+        self._file_path = key
+
+        if not self._file_path.parent.exists():
+            self._file_path.parent.mkdir(parents=True)
+
+        self._file_path.touch()
+
+    def send(self, value: bytes, timeout: float = 0) -> None:
+        """Send a message through the underlying communication channel.
+
+        :param value: The value to send
+        :param timeout: maximum time to wait (in seconds) for messages to send
+        """
+        with self._lock:
+            # write as text so we can add newlines as delimiters
+            with open(self._file_path, "a") as fp:
+                encoded_value = base64.b64encode(value).decode("utf-8")
+                fp.write(f"{encoded_value}\n")
+                logger.debug(f"FileSystemCommChannel {self._file_path} sent message")
+
+    def recv(self, timeout: float = 0) -> t.List[bytes]:
+        """Receive message(s) through the underlying communication channel.
+
+        :param timeout: maximum time to wait (in seconds) for messages to arrive
+        :returns: the received message(s)
+        :raises SmartSimError: if the descriptor points to a missing file
+        """
+        with self._lock:
+            messages: t.List[bytes] = []
+            if not self._file_path.exists():
+                raise SmartSimError("Empty channel")
+
+            # read as text so we can split on newlines
+            with open(self._file_path, "r") as fp:
+                lines = fp.readlines()
+
+            if lines:
+                line = lines.pop(0)
+                event_bytes = base64.b64decode(line.encode("utf-8"))
+                messages.append(event_bytes)
+
+            self.clear()
+
+            # remove the first message only, write the remainder back...
+            if len(lines) > 0:
+                with open(self._file_path, "w") as fp:
+                    fp.writelines(lines)
+
+            logger.debug(
+                f"FileSystemCommChannel {self._file_path} received message"
+            )
+
+            return messages
+
+    def clear(self) -> None:
+        """Create an empty file for events."""
+        if self._file_path.exists():
+            self._file_path.unlink()
+            self._file_path.touch()
+
+    @classmethod
+    def from_descriptor(
+        cls,
+        descriptor: str,
+    ) -> "FileSystemCommChannel":
+        """A factory method that creates an instance from a descriptor string.
+
+        :param descriptor: The descriptor that uniquely identifies the resource
+        :returns: An attached FileSystemCommChannel
+        """
+        try:
+            path = pathlib.Path(descriptor)
+            return FileSystemCommChannel(path)
+        except Exception:
+            logger.warning(f"failed to create fs comm channel: {descriptor}")
+            raise
diff --git a/tests/mli/feature_store.py b/tests/mli/feature_store.py
new file mode 100644
index 0000000000..7bc18253c8
--- /dev/null
+++ b/tests/mli/feature_store.py
@@ -0,0 +1,144 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pathlib
+import typing as t
+
+import smartsim.error as sse
+from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class MemoryFeatureStore(FeatureStore):
+    """A feature store with values persisted only in local memory"""
+
+    def __init__(self, storage: t.Optional[t.Dict[str, bytes]] = None) -> None:
+        """Initialize the MemoryFeatureStore instance"""
+        super().__init__("in-memory-fs")
+        if storage is None:
+            storage = {"_": b"abc"}
+        self._storage: t.Dict[str, bytes] = storage
+
+    def _get(self, key: str) -> bytes:
+        """Retrieve an item using key
+
+        :param key: Unique key of an item to retrieve from the feature store
+        :raises SmartSimError: if the key is not found"""
+        if key not in self._storage:
+            raise sse.SmartSimError(f"{key} not found in feature store")
+        return self._storage[key]
+
+    def _set(self, key: str, value: bytes) -> None:
+        """Assign a value using key
+
+        :param key: Unique key of an item to set in the feature store
+        :param value: Value to persist in the feature store"""
+        self._check_reserved(key)
+        self._storage[key] = value
+
+    def _contains(self, key: str) -> bool:
+        """Membership operator to test for a key existing within the feature store.
+
+        :param key: Unique key of an item to retrieve from the feature store
+        :returns: `True` if the key is found, `False` otherwise"""
+        return key in self._storage
+
+
+class FileSystemFeatureStore(FeatureStore):
+    """Alternative feature store implementation for testing. Stores all
Stores all + data on the file system""" + + def __init__(self, storage_dir: t.Union[pathlib.Path, str] = None) -> None: + """Initialize the FileSystemFeatureStore instance + + :param storage_dir: (optional) root directory to store all data relative to""" + if isinstance(storage_dir, str): + storage_dir = pathlib.Path(storage_dir) + self._storage_dir = storage_dir + super().__init__(storage_dir.as_posix()) + + def _get(self, key: str) -> bytes: + """Retrieve an item using key + + :param key: Unique key of an item to retrieve from the feature store""" + path = self._key_path(key) + if not path.exists(): + raise sse.SmartSimError(f"{path} not found in feature store") + return path.read_bytes() + + def _set(self, key: str, value: bytes) -> None: + """Assign a value using key + + :param key: Unique key of an item to set in the feature store + :param value: Value to persist in the feature store""" + path = self._key_path(key, create=True) + if isinstance(value, str): + value = value.encode("utf-8") + path.write_bytes(value) + + def _contains(self, key: str) -> bool: + """Membership operator to test for a key existing within the feature store. + + :param key: Unique key of an item to retrieve from the feature store + :returns: `True` if the key is found, `False` otherwise""" + path = self._key_path(key) + return path.exists() + + def _key_path(self, key: str, create: bool = False) -> pathlib.Path: + """Given a key, return a path that is optionally combined with a base + directory used by the FileSystemFeatureStore. + + :param key: Unique key of an item to retrieve from the feature store""" + value = pathlib.Path(key) + + if self._storage_dir: + value = self._storage_dir / key + + if create: + value.parent.mkdir(parents=True, exist_ok=True) + + return value + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "FileSystemFeatureStore": + """A factory method that creates an instance from a descriptor string + + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemFeatureStore""" + try: + path = pathlib.Path(descriptor) + path.mkdir(parents=True, exist_ok=True) + if not path.is_dir(): + raise ValueError("FileSystemFeatureStore requires a directory path") + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + return FileSystemFeatureStore(path) + except: + logger.error(f"Error while creating FileSystemFeatureStore: {descriptor}") + raise diff --git a/tests/mli/test_integrated_torch_worker.py b/tests/mli/test_integrated_torch_worker.py new file mode 100644 index 0000000000..60f1f0c6b9 --- /dev/null +++ b/tests/mli/test_integrated_torch_worker.py @@ -0,0 +1,275 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pathlib +import typing as t + +import pytest +import torch + +# import smartsim.error as sse +# from smartsim._core.mli.infrastructure.control import workermanager as mli +# from smartsim._core.mli.message_handler import MessageHandler +from smartsim._core.utils import installed_redisai_backends + +# The tests in this file belong to the group_b group +pytestmark = pytest.mark.group_b + +# retrieved from pytest fixtures +is_dragon = pytest.test_launcher == "dragon" +torch_available = "torch" in installed_redisai_backends() + + +@pytest.fixture +def persist_torch_model(test_dir: str) -> pathlib.Path: + test_path = pathlib.Path(test_dir) + model_path = test_path / "basic.pt" + + model = torch.nn.Linear(2, 1) + torch.save(model, model_path) + + return model_path + + +# todo: move deserialization tests into suite for worker manager where serialization occurs + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_deserialize_direct_request(persist_torch_model: pathlib.Path) -> None: +# """Verify that a direct requestis deserialized properly""" +# worker = mli.IntegratedTorchWorker +# # feature_store = mli.MemoryFeatureStore() + +# model_bytes = persist_torch_model.read_bytes() +# input_tensor = torch.randn(2) + +# expected_callback_channel = b"faux_channel_descriptor_bytes" +# callback_channel = mli.DragonCommChannel.find(expected_callback_channel) + +# message_tensor_input = MessageHandler.build_tensor( +# input_tensor, "c", "float32", [2] +# ) + +# request = MessageHandler.build_request( +# reply_channel=callback_channel.descriptor, +# model=model_bytes, +# inputs=[message_tensor_input], +# outputs=[], +# custom_attributes=None, +# ) + +# msg_bytes = MessageHandler.serialize_request(request) + +# inference_request = worker.deserialize(msg_bytes) +# assert inference_request.callback._descriptor == expected_callback_channel + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_deserialize_indirect_request(persist_torch_model: pathlib.Path) -> None: +# """Verify that an indirect request is deserialized correctly""" +# worker = mli.IntegratedTorchWorker +# # feature_store = mli.MemoryFeatureStore() + +# model_key = "persisted-model" +# # model_bytes = persist_torch_model.read_bytes() +# # feature_store[model_key] = model_bytes + +# input_key = f"demo-input" +# # input_tensor = torch.randn(2) +# # feature_store[input_key] = input_tensor + +# expected_callback_channel = b"faux_channel_descriptor_bytes" +# callback_channel = mli.DragonCommChannel.find(expected_callback_channel) + +# output_key = f"demo-output" + +# message_tensor_output_key = MessageHandler.build_tensor_key(output_key) +# message_tensor_input_key = MessageHandler.build_tensor_key(input_key) +# message_model_key = MessageHandler.build_model_key(model_key) + +# request = MessageHandler.build_request( +# reply_channel=callback_channel.descriptor, +# model=message_model_key, +# inputs=[message_tensor_input_key], 
+# outputs=[message_tensor_output_key], +# custom_attributes=None, +# ) + +# msg_bytes = MessageHandler.serialize_request(request) + +# inference_request = worker.deserialize(msg_bytes) +# assert inference_request.callback._descriptor == expected_callback_channel + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_deserialize_mixed_mode_indirect_inputs( +# persist_torch_model: pathlib.Path, +# ) -> None: +# """Verify that a mixed mode (combining direct and indirect inputs, models, outputs) +# with indirect inputs is deserialized correctly""" +# worker = mli.IntegratedTorchWorker +# # feature_store = mli.MemoryFeatureStore() + +# # model_key = "persisted-model" +# model_bytes = persist_torch_model.read_bytes() +# # feature_store[model_key] = model_bytes + +# input_key = f"demo-input" +# # input_tensor = torch.randn(2) +# # feature_store[input_key] = input_tensor + +# expected_callback_channel = b"faux_channel_descriptor_bytes" +# callback_channel = mli.DragonCommChannel.find(expected_callback_channel) + +# output_key = f"demo-output" + +# message_tensor_output_key = MessageHandler.build_tensor_key(output_key) +# message_tensor_input_key = MessageHandler.build_tensor_key(input_key) +# # message_model_key = MessageHandler.build_model_key(model_key) + +# request = MessageHandler.build_request( +# reply_channel=callback_channel.descriptor, +# model=model_bytes, +# inputs=[message_tensor_input_key], +# # outputs=[message_tensor_output_key], +# outputs=[], +# custom_attributes=None, +# ) + +# msg_bytes = MessageHandler.serialize_request(request) + +# inference_request = worker.deserialize(msg_bytes) +# assert inference_request.callback._descriptor == expected_callback_channel + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_deserialize_mixed_mode_indirect_outputs( +# persist_torch_model: pathlib.Path, +# ) -> None: +# """Verify that a mixed mode (combining direct and indirect inputs, models, outputs) +# with indirect outputs is deserialized correctly""" +# worker = mli.IntegratedTorchWorker +# # feature_store = mli.MemoryFeatureStore() + +# # model_key = "persisted-model" +# model_bytes = persist_torch_model.read_bytes() +# # feature_store[model_key] = model_bytes + +# input_key = f"demo-input" +# input_tensor = torch.randn(2) +# # feature_store[input_key] = input_tensor + +# expected_callback_channel = b"faux_channel_descriptor_bytes" +# callback_channel = mli.DragonCommChannel.find(expected_callback_channel) + +# output_key = f"demo-output" + +# message_tensor_output_key = MessageHandler.build_tensor_key(output_key) +# # message_tensor_input_key = MessageHandler.build_tensor_key(input_key) +# # message_model_key = MessageHandler.build_model_key(model_key) +# message_tensor_input = MessageHandler.build_tensor( +# input_tensor, "c", "float32", [2] +# ) + +# request = MessageHandler.build_request( +# reply_channel=callback_channel.descriptor, +# model=model_bytes, +# inputs=[message_tensor_input], +# # outputs=[message_tensor_output_key], +# outputs=[message_tensor_output_key], +# custom_attributes=None, +# ) + +# msg_bytes = MessageHandler.serialize_request(request) + +# inference_request = worker.deserialize(msg_bytes) +# assert inference_request.callback._descriptor == expected_callback_channel + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_deserialize_mixed_mode_indirect_model( +# persist_torch_model: pathlib.Path, +# ) -> None: +# """Verify 
that a mixed mode (combining direct and indirect inputs, models, outputs) +# with indirect outputs is deserialized correctly""" +# worker = mli.IntegratedTorchWorker +# # feature_store = mli.MemoryFeatureStore() + +# model_key = "persisted-model" +# # model_bytes = persist_torch_model.read_bytes() +# # feature_store[model_key] = model_bytes + +# # input_key = f"demo-input" +# input_tensor = torch.randn(2) +# # feature_store[input_key] = input_tensor + +# expected_callback_channel = b"faux_channel_descriptor_bytes" +# callback_channel = mli.DragonCommChannel.find(expected_callback_channel) + +# output_key = f"demo-output" + +# # message_tensor_output_key = MessageHandler.build_tensor_key(output_key) +# # message_tensor_input_key = MessageHandler.build_tensor_key(input_key) +# message_model_key = MessageHandler.build_model_key(model_key) +# message_tensor_input = MessageHandler.build_tensor( +# input_tensor, "c", "float32", [2] +# ) + +# request = MessageHandler.build_request( +# reply_channel=callback_channel.descriptor, +# model=message_model_key, +# inputs=[message_tensor_input], +# # outputs=[message_tensor_output_key], +# outputs=[], +# custom_attributes=None, +# ) + +# msg_bytes = MessageHandler.serialize_request(request) + +# inference_request = worker.deserialize(msg_bytes) +# assert inference_request.callback._descriptor == expected_callback_channel + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_serialize(test_dir: str, persist_torch_model: pathlib.Path) -> None: +# """Verify that the worker correctly executes reply serialization""" +# worker = mli.IntegratedTorchWorker + +# reply = mli.InferenceReply() +# reply.output_keys = ["foo", "bar"] + +# # use the worker implementation of reply serialization to get bytes for +# # use on the callback channel +# reply_bytes = worker.serialize_reply(reply) +# assert reply_bytes is not None + +# # deserialize to verity the mapping in the worker.serialize_reply was correct +# actual_reply = MessageHandler.deserialize_response(reply_bytes) + +# actual_tensor_keys = [tk.key for tk in actual_reply.result.keys] +# assert set(actual_tensor_keys) == set(reply.output_keys) +# assert actual_reply.status == 200 +# assert actual_reply.statusMessage == "success" diff --git a/tests/mli/test_service.py b/tests/mli/test_service.py new file mode 100644 index 0000000000..3635f6ff78 --- /dev/null +++ b/tests/mli/test_service.py @@ -0,0 +1,290 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import datetime
+import time
+import typing as t
+
+import pytest
+
+from smartsim._core.entrypoints.service import Service
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+
+class SimpleService(Service):
+    """Mock implementation of a service that counts method invocations
+    using the base class event hooks."""
+
+    def __init__(
+        self,
+        log: t.List[str],
+        quit_after: int = -1,
+        as_service: bool = False,
+        cooldown: float = 0,
+        loop_delay: float = 0,
+        hc_freq: float = -1,
+        run_for: float = 0,
+    ) -> None:
+        super().__init__(as_service, cooldown, loop_delay, hc_freq)
+        self._log = log
+        self._quit_after = quit_after
+        self.num_starts = 0
+        self.num_shutdowns = 0
+        self.num_health_checks = 0
+        self.num_cooldowns = 0
+        self.num_delays = 0
+        self.num_iterations = 0
+        self.num_can_shutdown = 0
+        self.run_for = run_for
+        self.start_time = time.time()
+
+    @property
+    def runtime(self) -> float:
+        return time.time() - self.start_time
+
+    def _can_shutdown(self) -> bool:
+        self.num_can_shutdown += 1
+
+        if self._quit_after > -1 and self.num_iterations >= self._quit_after:
+            return True
+        if self.run_for > 0:
+            return self.runtime >= self.run_for
+
+        return False
+
+    def _on_start(self) -> None:
+        self.num_starts += 1
+
+    def _on_shutdown(self) -> None:
+        self.num_shutdowns += 1
+
+    def _on_health_check(self) -> None:
+        self.num_health_checks += 1
+
+    def _on_cooldown_elapsed(self) -> None:
+        self.num_cooldowns += 1
+
+    def _on_delay(self) -> None:
+        self.num_delays += 1
+
+    def _on_iteration(self) -> None:
+        self.num_iterations += 1
+
+
+def test_service_init() -> None:
+    """Verify expected default values after Service initialization"""
+    activity_log: t.List[str] = []
+    service = SimpleService(activity_log)
+
+    assert service._as_service is False
+    assert service._cooldown == 0
+    assert service._loop_delay == 0
+
+
+def test_service_run_once() -> None:
+    """Verify the service completes after a single call to _on_iteration"""
+    activity_log: t.List[str] = []
+    service = SimpleService(activity_log)
+
+    service.execute()
+
+    assert service.num_iterations == 1
+    assert service.num_starts == 1
+    assert service.num_cooldowns == 0  # it never exceeds a cooldown period
+    assert service.num_can_shutdown == 0  # it automatically exits in run once
+    assert service.num_shutdowns == 1
+
+
+@pytest.mark.parametrize(
+    "num_iterations",
+    [
+        pytest.param(0, id="Immediate Shutdown"),
+        pytest.param(1, id="1x"),
+        pytest.param(2, id="2x"),
+        pytest.param(4, id="4x"),
+        pytest.param(8, id="8x"),
+        pytest.param(16, id="16x"),
+        pytest.param(32, id="32x"),
+    ],
+)
+def test_service_run_until_can_shutdown(num_iterations: int) -> None:
+    """Verify the service completes after a dynamic number of iterations
+    based on the return value of `_can_shutdown`"""
+    activity_log: t.List[str] = []
+
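# quit_after drives the mock's _can_shutdown hook; with as_service=True the
+    # event loop keeps iterating until that hook reports the service may stop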
+    service = SimpleService(activity_log, quit_after=num_iterations, as_service=True)
+
+    service.execute()
+
+    if num_iterations == 0:
+        # no matter what, it should always execute the _on_iteration method
+        assert service.num_iterations == 1
+    else:
+        # the shutdown check follows on_iteration. there will be one last call
+        assert service.num_iterations == num_iterations
+
+    assert service.num_starts == 1
+    assert service.num_shutdowns == 1
+
+
+@pytest.mark.parametrize(
+    "cooldown",
+    [
+        pytest.param(1, id="1s"),
+        pytest.param(3, id="3s"),
+        pytest.param(5, id="5s"),
+    ],
+)
+def test_service_cooldown(cooldown: int) -> None:
+    """Verify that the cooldown period is respected"""
+    activity_log: t.List[str] = []
+
+    service = SimpleService(
+        activity_log,
+        quit_after=1,
+        as_service=True,
+        cooldown=cooldown,
+        loop_delay=0,
+    )
+
+    ts0 = datetime.datetime.now()
+    service.execute()
+    ts1 = datetime.datetime.now()
+
+    fudge_factor = 1.1  # allow a little bit of wiggle room for the loop
+    duration_in_seconds = (ts1 - ts0).total_seconds()
+
+    assert duration_in_seconds <= cooldown * fudge_factor
+    assert service.num_cooldowns == 1
+    assert service.num_shutdowns == 1
+
+
+@pytest.mark.parametrize(
+    "delay, num_iterations",
+    [
+        pytest.param(1, 3, id="1s delay, 3x"),
+        pytest.param(3, 2, id="3s delay, 2x"),
+        pytest.param(5, 1, id="5s delay, 1x"),
+    ],
+)
+def test_service_delay(delay: int, num_iterations: int) -> None:
+    """Verify that a delay is correctly added between iterations"""
+    activity_log: t.List[str] = []
+
+    service = SimpleService(
+        activity_log,
+        quit_after=num_iterations,
+        as_service=True,
+        cooldown=0,
+        loop_delay=delay,
+    )
+
+    ts0 = datetime.datetime.now()
+    service.execute()
+    ts1 = datetime.datetime.now()
+
+    # the expected duration is the sum of the delays between iterations
+    expected_duration = (num_iterations + 1) * delay
+    duration_in_seconds = (ts1 - ts0).total_seconds()
+
+    assert duration_in_seconds <= expected_duration
+    assert service.num_cooldowns == 0
+    assert service.num_shutdowns == 1
+
+
+@pytest.mark.parametrize(
+    "health_check_freq, run_for",
+    [
+        pytest.param(1, 5.5, id="1s freq, 5x"),
+        pytest.param(5, 10.5, id="5s freq, 2x"),
+        pytest.param(0.1, 5.1, id="0.1s freq, 50x"),
+    ],
+)
+def test_service_health_check_freq(health_check_freq: float, run_for: float) -> None:
+    """Verify that the health check frequency is honored
+
+    :param health_check_freq: The desired frequency of the health check
+    :param run_for: A fixed duration to allow the service to run
+    """
+    activity_log: t.List[str] = []
+
+    service = SimpleService(
+        activity_log,
+        quit_after=-1,
+        as_service=True,
+        cooldown=0,
+        hc_freq=health_check_freq,
+        run_for=run_for,
+    )
+
+    ts0 = datetime.datetime.now()
+    service.execute()
+    ts1 = datetime.datetime.now()
+
+    # the expected number of health checks is the run duration divided by the frequency
+    expected_hc_count = run_for // health_check_freq
+
+    # allow some wiggle room for frequency comparison
+    assert expected_hc_count - 1 <= service.num_health_checks <= expected_hc_count + 1
+
+    assert service.num_cooldowns == 0
+    assert service.num_shutdowns == 1
+
+
+def test_service_health_check_freq_unbound() -> None:
+    """Verify that a health check frequency of zero is treated as
+    "always on" and is called on each loop iteration"""
+    health_check_freq: float = 0.0
+    run_for: float = 5
+
+    activity_log: t.List[str] = []
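+
+    # hc_freq=0 is the "always on" case: a health check is expected on every
+    # pass through the service loop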
+ + service = SimpleService( + activity_log, + quit_after=-1, + as_service=True, + cooldown=0, + hc_freq=health_check_freq, + run_for=run_for, + ) + + service.execute() + + # allow some wiggle room for frequency comparison + assert service.num_health_checks == service.num_iterations + assert service.num_cooldowns == 0 + assert service.num_shutdowns == 1 diff --git a/tests/mli/worker.py b/tests/mli/worker.py new file mode 100644 index 0000000000..0582cae566 --- /dev/null +++ b/tests/mli/worker.py @@ -0,0 +1,104 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import io +import typing as t + +import torch + +import smartsim._core.mli.infrastructure.worker.worker as mliw +import smartsim.error as sse +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class IntegratedTorchWorker(mliw.MachineLearningWorkerBase): + """A minimum implementation of a worker that executes a PyTorch model""" + + # @staticmethod + # def deserialize(request: InferenceRequest) -> t.List[t.Any]: + # # request.input_meta + # # request.raw_inputs + # return request + + @staticmethod + def load_model( + request: mliw.InferenceRequest, fetch_result: mliw.FetchModelResult, device: str + ) -> mliw.LoadModelResult: + model_bytes = fetch_result.model_bytes or request.raw_model + if not model_bytes: + raise ValueError("Unable to load model without reference object") + + model: torch.nn.Module = torch.load(io.BytesIO(model_bytes)) + result = mliw.LoadModelResult(model) + return result + + @staticmethod + def transform_input( + request: mliw.InferenceRequest, + fetch_result: mliw.FetchInputResult, + device: str, + ) -> mliw.TransformInputResult: + # extra metadata for assembly can be found in request.input_meta + raw_inputs = request.raw_inputs or fetch_result.inputs + + result: t.List[torch.Tensor] = [] + # should this happen here? + # consider - fortran to c data layout + # is there an intermediate representation before really doing torch.load? 
+ if raw_inputs: + result = [torch.load(io.BytesIO(item)) for item in raw_inputs] + + return mliw.TransformInputResult(result) + + @staticmethod + def execute( + request: mliw.InferenceRequest, + load_result: mliw.LoadModelResult, + transform_result: mliw.TransformInputResult, + ) -> mliw.ExecuteResult: + if not load_result.model: + raise sse.SmartSimError("Model must be loaded to execute") + + model = load_result.model + results = [model(tensor) for tensor in transform_result.transformed] + + execute_result = mliw.ExecuteResult(results) + return execute_result + + @staticmethod + def transform_output( + request: mliw.InferenceRequest, + execute_result: mliw.ExecuteResult, + result_device: str, + ) -> mliw.TransformOutputResult: + # send the original tensors... + execute_result.predictions = [t.detach() for t in execute_result.predictions] + # todo: solve sending all tensor metadata that coincisdes with each prediction + return mliw.TransformOutputResult( + execute_result.predictions, [1], "c", "float32" + ) diff --git a/tests/on_wlm/test_dragon.py b/tests/on_wlm/test_dragon.py index a05d381415..1bef3cac8d 100644 --- a/tests/on_wlm/test_dragon.py +++ b/tests/on_wlm/test_dragon.py @@ -56,7 +56,7 @@ def test_dragon_global_path(global_dragon_teardown, wlmutils, test_dir, monkeypa def test_dragon_exp_path(global_dragon_teardown, wlmutils, test_dir, monkeypatch): monkeypatch.delenv("SMARTSIM_DRAGON_SERVER_PATH", raising=False) - monkeypatch.delenv("SMARTSIM_DRAGON_SERVER_PATH_EXP", raising=False) + monkeypatch.delenv("_SMARTSIM_DRAGON_SERVER_PATH_EXP", raising=False) exp: Experiment = Experiment( "test_dragon_connection", exp_path=test_dir, diff --git a/tests/test_config.py b/tests/test_config.py index 00a1fcdd36..5a84103ffd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -66,9 +66,9 @@ def get_redisai_env( """ env = os.environ.copy() if rai_path is not None: - env["RAI_PATH"] = rai_path + env["SMARTSIM_RAI_LIB"] = rai_path else: - env.pop("RAI_PATH", None) + env.pop("SMARTSIM_RAI_LIB", None) if lib_path is not None: env["SMARTSIM_DEP_INSTALL_PATH"] = lib_path @@ -85,7 +85,7 @@ def make_file(filepath: str) -> None: def test_redisai_invalid_rai_path(test_dir, monkeypatch): - """An invalid RAI_PATH and valid SMARTSIM_DEP_INSTALL_PATH should fail""" + """An invalid SMARTSIM_RAI_LIB and valid SMARTSIM_DEP_INSTALL_PATH should fail""" rai_file_path = os.path.join(test_dir, "lib", "mock-redisai.so") make_file(os.path.join(test_dir, "lib", "redisai.so")) @@ -94,7 +94,7 @@ def test_redisai_invalid_rai_path(test_dir, monkeypatch): config = Config() - # Fail when no file exists @ RAI_PATH + # Fail when no file exists @ SMARTSIM_RAI_LIB with pytest.raises(SSConfigError) as ex: _ = config.redisai @@ -102,7 +102,7 @@ def test_redisai_invalid_rai_path(test_dir, monkeypatch): def test_redisai_valid_rai_path(test_dir, monkeypatch): - """A valid RAI_PATH should override valid SMARTSIM_DEP_INSTALL_PATH and succeed""" + """A valid SMARTSIM_RAI_LIB should override valid SMARTSIM_DEP_INSTALL_PATH and succeed""" rai_file_path = os.path.join(test_dir, "lib", "mock-redisai.so") make_file(rai_file_path) @@ -117,7 +117,7 @@ def test_redisai_valid_rai_path(test_dir, monkeypatch): def test_redisai_invalid_lib_path(test_dir, monkeypatch): - """Invalid RAI_PATH and invalid SMARTSIM_DEP_INSTALL_PATH should fail""" + """Invalid SMARTSIM_RAI_LIB and invalid SMARTSIM_DEP_INSTALL_PATH should fail""" rai_file_path = f"{test_dir}/railib/redisai.so" @@ -133,7 +133,7 @@ def test_redisai_invalid_lib_path(test_dir, 
monkeypatch): def test_redisai_valid_lib_path(test_dir, monkeypatch): - """Valid RAI_PATH and invalid SMARTSIM_DEP_INSTALL_PATH should succeed""" + """Valid SMARTSIM_RAI_LIB and invalid SMARTSIM_DEP_INSTALL_PATH should succeed""" rai_file_path = os.path.join(test_dir, "lib", "mock-redisai.so") make_file(rai_file_path) @@ -147,7 +147,7 @@ def test_redisai_valid_lib_path(test_dir, monkeypatch): def test_redisai_valid_lib_path_null_rai(test_dir, monkeypatch): - """Missing RAI_PATH and valid SMARTSIM_DEP_INSTALL_PATH should succeed""" + """Missing SMARTSIM_RAI_LIB and valid SMARTSIM_DEP_INSTALL_PATH should succeed""" rai_file_path: t.Optional[str] = None lib_file_path = os.path.join(test_dir, "lib", "redisai.so") @@ -166,11 +166,11 @@ def test_redis_conf(): assert Path(config.database_conf).is_file() assert isinstance(config.database_conf, str) - os.environ["REDIS_CONF"] = "not/a/path" + os.environ["SMARTSIM_REDIS_CONF"] = "not/a/path" config = Config() with pytest.raises(SSConfigError): config.database_conf - os.environ.pop("REDIS_CONF") + os.environ.pop("SMARTSIM_REDIS_CONF") def test_redis_exe(): @@ -178,11 +178,11 @@ def test_redis_exe(): assert Path(config.database_exe).is_file() assert isinstance(config.database_exe, str) - os.environ["REDIS_PATH"] = "not/a/path" + os.environ["SMARTSIM_REDIS_SERVER_EXE"] = "not/a/path" config = Config() with pytest.raises(SSConfigError): config.database_exe - os.environ.pop("REDIS_PATH") + os.environ.pop("SMARTSIM_REDIS_SERVER_EXE") def test_redis_cli(): @@ -190,11 +190,11 @@ def test_redis_cli(): assert Path(config.redisai).is_file() assert isinstance(config.redisai, str) - os.environ["REDIS_CLI_PATH"] = "not/a/path" + os.environ["SMARTSIM_REDIS_CLI_EXE"] = "not/a/path" config = Config() with pytest.raises(SSConfigError): config.database_cli - os.environ.pop("REDIS_CLI_PATH") + os.environ.pop("SMARTSIM_REDIS_CLI_EXE") @pytest.mark.parametrize( diff --git a/tests/test_dragon_comm_utils.py b/tests/test_dragon_comm_utils.py new file mode 100644 index 0000000000..a6f9c206a4 --- /dev/null +++ b/tests/test_dragon_comm_utils.py @@ -0,0 +1,257 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
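+
+# The dragon_util helpers exercised below convert dragon objects to and from
+# portable string descriptors. A minimal round-trip sketch (assuming a running
+# dragon backend, which these tests require):
+#
+#     channel = dch.Channel.make_process_local()
+#     descriptor = dragon_util.channel_to_descriptor(channel)
+#     same_channel = dragon_util.descriptor_to_channel(descriptor)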
+ +import base64 +import pathlib +import uuid + +import pytest + +from smartsim.error.errors import SmartSimError + +dragon = pytest.importorskip("dragon") + +# isort: off +import dragon.channels as dch +import dragon.infrastructure.parameters as dp +import dragon.managed_memory as dm +import dragon.fli as fli + +# isort: on + +from smartsim._core.mli.comm.channel import dragon_util +from smartsim.log import get_logger + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +logger = get_logger(__name__) + + +@pytest.fixture(scope="function") +def the_pool() -> dm.MemoryPool: + """Creates a memory pool.""" + raw_pool_descriptor = dp.this_process.default_pd + descriptor_ = base64.b64decode(raw_pool_descriptor) + + pool = dm.MemoryPool.attach(descriptor_) + return pool + + +@pytest.fixture(scope="function") +def the_channel() -> dch.Channel: + """Creates a Channel attached to the local memory pool.""" + channel = dch.Channel.make_process_local() + return channel + + +@pytest.fixture(scope="function") +def the_fli(the_channel) -> fli.FLInterface: + """Creates an FLI attached to the local memory pool.""" + fli_ = fli.FLInterface(main_ch=the_channel, manager_ch=None) + return fli_ + + +def test_descriptor_to_channel_empty() -> None: + """Verify that `descriptor_to_channel` raises an exception when + provided with an empty descriptor.""" + descriptor = "" + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "empty" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()], +) +def test_descriptor_to_channel_b64fail(descriptor: str) -> None: + """Verify that `descriptor_to_channel` raises an exception when + provided with an incorrectly encoded descriptor. + + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "base64" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + [str(uuid.uuid4())], +) +def test_descriptor_to_channel_channel_fail(descriptor: str) -> None: + """Verify that `descriptor_to_channel` raises an exception when a correctly + formatted descriptor that does not describe a real channel is passed. + + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + # ensure we're receiving the right exception + assert "address" in ex.value.args[0] + assert "channel" in ex.value.args[0] + + +def test_descriptor_to_channel_channel_not_available(the_channel: dch.Channel) -> None: + """Verify that `descriptor_to_channel` raises an exception when a channel + is no longer available. 
+ + :param the_channel: A dragon channel + """ + + # get a good descriptor & wipe out the channel so it can't be attached + descriptor = dragon_util.channel_to_descriptor(the_channel) + the_channel.destroy() + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "address" in ex.value.args[0] + + +def test_descriptor_to_channel_happy_path(the_channel: dch.Channel) -> None: + """Verify that `descriptor_to_channel` works as expected when provided + a valid descriptor + + :param the_channel: A dragon channel + """ + + # get a good descriptor + descriptor = dragon_util.channel_to_descriptor(the_channel) + + reattached = dragon_util.descriptor_to_channel(descriptor) + assert reattached + + # and just make sure creation of the descriptor is transitive + assert dragon_util.channel_to_descriptor(reattached) == descriptor + + +def test_descriptor_to_fli_empty() -> None: + """Verify that `descriptor_to_fli` raises an exception when + provided with an empty descriptor.""" + descriptor = "" + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + assert "empty" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()], +) +def test_descriptor_to_fli_b64fail(descriptor: str) -> None: + """Verify that `descriptor_to_fli` raises an exception when + provided with an incorrectly encoded descriptor. + + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + assert "base64" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + [str(uuid.uuid4())], +) +def test_descriptor_to_fli_fli_fail(descriptor: str) -> None: + """Verify that `descriptor_to_fli` raises an exception when a correctly + formatted descriptor that does not describe a real FLI is passed. + + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + # ensure we're receiving the right exception + assert "address" in ex.value.args[0] + assert "fli" in ex.value.args[0].lower() + + +def test_descriptor_to_fli_fli_not_available( + the_fli: fli.FLInterface, the_channel: dch.Channel +) -> None: + """Verify that `descriptor_to_fli` raises an exception when a channel + is no longer available. 
+ + :param the_fli: A dragon FLInterface + :param the_channel: A dragon channel + """ + + # get a good descriptor & wipe out the FLI so it can't be attached + descriptor = dragon_util.channel_to_descriptor(the_fli) + the_fli.destroy() + the_channel.destroy() + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + # ensure we're receiving the right exception + assert "address" in ex.value.args[0] + + +def test_descriptor_to_fli_happy_path(the_fli: dch.Channel) -> None: + """Verify that `descriptor_to_fli` works as expected when provided + a valid descriptor + + :param the_fli: A dragon FLInterface + """ + + # get a good descriptor + descriptor = dragon_util.channel_to_descriptor(the_fli) + + reattached = dragon_util.descriptor_to_fli(descriptor) + assert reattached + + # and just make sure creation of the descriptor is transitive + assert dragon_util.channel_to_descriptor(reattached) == descriptor + + +def test_pool_to_descriptor_empty() -> None: + """Verify that `pool_to_descriptor` raises an exception when + provided with a null pool.""" + + with pytest.raises(ValueError) as ex: + dragon_util.pool_to_descriptor(None) + + +def test_pool_to_happy_path(the_pool) -> None: + """Verify that `pool_to_descriptor` creates a descriptor + when supplied with a valid memory pool.""" + + descriptor = dragon_util.pool_to_descriptor(the_pool) + assert descriptor diff --git a/tests/test_dragon_installer.py b/tests/test_dragon_installer.py index b23a1a7ef0..a58d711721 100644 --- a/tests/test_dragon_installer.py +++ b/tests/test_dragon_installer.py @@ -31,12 +31,17 @@ from collections import namedtuple import pytest +from github.GitRelease import GitRelease from github.GitReleaseAsset import GitReleaseAsset from github.Requester import Requester import smartsim +import smartsim._core._install.utils import smartsim._core.utils.helpers as helpers from smartsim._core._cli.scripts.dragon_install import ( + DEFAULT_DRAGON_REPO, + DEFAULT_DRAGON_VERSION, + DragonInstallRequest, cleanup, create_dotenv, install_dragon, @@ -58,14 +63,25 @@ def test_archive(test_dir: str, archive_path: pathlib.Path) -> pathlib.Path: """Fixture for returning a simple tarfile to test on""" num_files = 10 + + archive_name = archive_path.name + archive_name = archive_name.replace(".tar.gz", "") + with tarfile.TarFile.open(archive_path, mode="w:gz") as tar: - mock_whl = pathlib.Path(test_dir) / "mock.whl" + mock_whl = pathlib.Path(test_dir) / archive_name / f"{archive_name}.whl" + mock_whl.parent.mkdir(parents=True, exist_ok=True) mock_whl.touch() + tar.add(mock_whl) + for i in range(num_files): - content = pathlib.Path(test_dir) / f"{i:04}.txt" + content = pathlib.Path(test_dir) / archive_name / f"{i:04}.txt" content.write_text(f"i am file {i}\n") tar.add(content) + content.unlink() + + mock_whl.unlink() + return archive_path @@ -118,11 +134,41 @@ def test_assets(monkeypatch: pytest.MonkeyPatch) -> t.Dict[str, GitReleaseAsset] _git_attr(value=f"http://foo/{archive_name}"), ) monkeypatch.setattr(asset, "_name", _git_attr(value=archive_name)) + monkeypatch.setattr(asset, "_id", _git_attr(value=123)) assets.append(asset) return assets +@pytest.fixture +def test_releases(monkeypatch: pytest.MonkeyPatch) -> t.Dict[str, GitRelease]: + requester = Requester( + auth=None, + base_url="https://github.com", + user_agent="mozilla", + per_page=10, + verify=False, + timeout=1, + retry=1, + pool_size=1, + ) + headers = {"mock-header": "mock-value"} + attributes = {"title": "mock-title"} + completed = True + + 
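# build one mock release per supported python/dragon version pair so the
+    # installer's asset-selection logic has a realistic matrix to choose from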
releases: t.List[GitRelease] = [] + + for python_version in ["py3.9", "py3.10", "py3.11"]: + for dragon_version in ["dragon-0.8", "dragon-0.9", "dragon-0.10"]: + attributes = { + "title": f"{python_version}-{dragon_version}-release", + "tag_name": f"v{dragon_version}-weekly", + } + releases.append(GitRelease(requester, headers, attributes, completed)) + + return releases + + def test_cleanup_no_op(archive_path: pathlib.Path) -> None: """Ensure that the cleanup method doesn't bomb when called with missing archive path; simulate a failed download""" @@ -143,17 +189,25 @@ def test_cleanup_archive_exists(test_archive: pathlib.Path) -> None: assert not test_archive.exists() -def test_retrieve_cached( - test_dir: str, - # archive_path: pathlib.Path, +@pytest.mark.skip("Deprecated due to builder.py changes") +def test_retrieve_updated( test_archive: pathlib.Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - """Verify that a previously retrieved asset archive is re-used""" - with tarfile.TarFile.open(test_archive) as tar: - tar.extractall(test_dir) + """Verify that a previously retrieved asset archive is not re-used if a new + version is found""" + + old_asset_id = 100 + asset_id = 123 - ts1 = test_archive.parent.stat().st_ctime + def mock__retrieve_archive(source_, destination_) -> None: + mock_extraction_dir = pathlib.Path(destination_) + with tarfile.TarFile.open(test_archive) as tar: + tar.extractall(mock_extraction_dir) + + # we'll use the mock extract to create the files that would normally be downloaded + expected_output_dir = test_archive.parent / str(asset_id) + old_output_dir = test_archive.parent / str(old_asset_id) requester = Requester( auth=None, @@ -174,14 +228,22 @@ def test_retrieve_cached( # ensure mocked asset has values that we use... monkeypatch.setattr(asset, "_browser_download_url", _git_attr(value="http://foo")) monkeypatch.setattr(asset, "_name", _git_attr(value=mock_archive_name)) + monkeypatch.setattr(asset, "_id", _git_attr(value=asset_id)) + monkeypatch.setattr( + smartsim._core._install.utils, + "retrieve", + lambda s_, d_: mock__retrieve_archive(s_, expected_output_dir), + ) # mock the retrieval of the updated archive + + # tell it to retrieve. it should return the path to the new download, not the old one + request = DragonInstallRequest(test_archive.parent) + asset_path = retrieve_asset(request, asset) - asset_path = retrieve_asset(test_archive.parent, asset) - ts2 = asset_path.stat().st_ctime + # sanity check we don't have the same paths + assert old_output_dir != expected_output_dir - assert ( - asset_path == test_archive.parent - ) # show that the expected path matches the output path - assert ts1 == ts2 # show that the file wasn't changed... + # verify the "cached" copy wasn't used + assert asset_path == expected_output_dir @pytest.mark.parametrize( @@ -214,11 +276,13 @@ def test_retrieve_cached( ) def test_retrieve_asset_info( test_assets: t.Collection[GitReleaseAsset], + test_releases: t.Collection[GitRelease], monkeypatch: pytest.MonkeyPatch, dragon_pin: str, pyv: str, is_found: bool, is_crayex: bool, + test_dir: str, ) -> None: """Verify that an information is retrieved correctly based on the python version, platform (e.g. 
CrayEX, !CrayEx), and target dragon pin""" @@ -234,20 +298,23 @@ def test_retrieve_asset_info( "is_crayex_platform", lambda: is_crayex, ) + # avoid hitting github API ctx.setattr( smartsim._core._cli.scripts.dragon_install, - "dragon_pin", - lambda: dragon_pin, + "_get_all_releases", + lambda x: test_releases, ) # avoid hitting github API ctx.setattr( smartsim._core._cli.scripts.dragon_install, "_get_release_assets", - lambda: test_assets, + lambda x: test_assets, ) + request = DragonInstallRequest(test_dir, version=dragon_pin) + if is_found: - chosen_asset = retrieve_asset_info() + chosen_asset = retrieve_asset_info(request) assert chosen_asset assert pyv in chosen_asset.name @@ -259,7 +326,7 @@ def test_retrieve_asset_info( assert "crayex" not in chosen_asset.name.lower() else: with pytest.raises(SmartSimCLIActionCancelled): - retrieve_asset_info() + retrieve_asset_info(request) def test_check_for_utility_missing(test_dir: str) -> None: @@ -357,11 +424,12 @@ def mock_util_check(util: str) -> bool: assert is_cray == platform_result -def test_install_package_no_wheel(extraction_dir: pathlib.Path): +def test_install_package_no_wheel(test_dir: str, extraction_dir: pathlib.Path): """Verify that a missing wheel does not blow up and has a failure retcode""" exp_path = extraction_dir + request = DragonInstallRequest(test_dir) - result = install_package(exp_path) + result = install_package(request, exp_path) assert result != 0 @@ -370,7 +438,9 @@ def test_install_macos(monkeypatch: pytest.MonkeyPatch, extraction_dir: pathlib. with monkeypatch.context() as ctx: ctx.setattr(sys, "platform", "darwin") - result = install_dragon(extraction_dir) + request = DragonInstallRequest(extraction_dir) + + result = install_dragon(request) assert result == 1 @@ -387,7 +457,7 @@ def test_create_dotenv(monkeypatch: pytest.MonkeyPatch, test_dir: str): # ensure no .env exists before trying to create it. assert not exp_env_path.exists() - create_dotenv(mock_dragon_root) + create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION) # ensure the .env is created as side-effect of create_dotenv assert exp_env_path.exists() @@ -409,7 +479,7 @@ def test_create_dotenv_existing_dir(monkeypatch: pytest.MonkeyPatch, test_dir: s # ensure no .env exists before trying to create it. 
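    # (a pre-existing .env file would mask the creation side effect under test)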
    assert not exp_env_path.exists()
-    create_dotenv(mock_dragon_root)
+    create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
 
     # ensure the .env is created as side-effect of create_dotenv
     assert exp_env_path.exists()
@@ -434,17 +504,25 @@ def test_create_dotenv_existing_dotenv(monkeypatch: pytest.MonkeyPatch, test_dir
     # ensure .env exists so we can update it
     assert exp_env_path.exists()
 
-    create_dotenv(mock_dragon_root)
+    create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
 
     # ensure the .env is created as side-effect of create_dotenv
     assert exp_env_path.exists()
 
     # ensure file was overwritten and env vars are not duplicated
     dotenv_content = exp_env_path.read_text(encoding="utf-8")
-    split_content = dotenv_content.split(var_name)
-
-    # split to confirm env var only appars once
-    assert len(split_content) == 2
+    lines = [
+        line for line in dotenv_content.split("\n") if line and "#" not in line
+    ]
+    for line in lines:
+        if line.startswith(var_name):
+            # make sure the var isn't defined recursively
+            # DRAGON_BASE_DIR=$DRAGON_BASE_DIR
+            assert var_name not in line[len(var_name) + 1 :]
+        else:
+            # make sure any values reference the original base dir var
+            if var_name in line:
+                assert f"${var_name}" in line
 
 
 def test_create_dotenv_format(monkeypatch: pytest.MonkeyPatch, test_dir: str):
@@ -456,13 +534,13 @@ def test_create_dotenv_format(monkeypatch: pytest.MonkeyPatch, test_dir: str):
 
     with monkeypatch.context() as ctx:
         ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path)
-        create_dotenv(mock_dragon_root)
+        create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
 
     # ensure the .env is created as side-effect of create_dotenv
     content = exp_env_path.read_text(encoding="utf-8")
 
     # ensure we have values written, but ignore empty lines
-    lines = [line for line in content.split("\n") if line]
+    lines = [line for line in content.split("\n") if line and "#" not in line]
     assert lines
 
     # ensure each line is formatted as key=value
diff --git a/tests/test_dragon_launcher.py b/tests/test_dragon_launcher.py
index 4bd07e920c..a894757918 100644
--- a/tests/test_dragon_launcher.py
+++ b/tests/test_dragon_launcher.py
@@ -37,7 +37,10 @@
 import zmq
 
 import smartsim._core.config
-from smartsim._core._cli.scripts.dragon_install import create_dotenv
+from smartsim._core._cli.scripts.dragon_install import (
+    DEFAULT_DRAGON_VERSION,
+    create_dotenv,
+)
 from smartsim._core.config.config import get_config
 from smartsim._core.launcher.dragon.dragonLauncher import (
     DragonConnector,
@@ -494,7 +497,7 @@ def test_load_env_env_file_created(monkeypatch: pytest.MonkeyPatch, test_dir: st
     with monkeypatch.context() as ctx:
         ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path)
-        create_dotenv(mock_dragon_root)
+        create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
 
         dragon_conf = smartsim._core.config.CONFIG.dragon_dotenv
 
         # verify config does exist
@@ -507,7 +510,26 @@ def test_load_env_env_file_created(monkeypatch: pytest.MonkeyPatch, test_dir: st
         assert loaded_env
 
         # confirm .env was parsed as expected by inspecting a key
+        assert "DRAGON_BASE_DIR" in loaded_env
+        base_dir = loaded_env["DRAGON_BASE_DIR"]
+
+        assert "DRAGON_ROOT_DIR" in loaded_env
+        assert loaded_env["DRAGON_ROOT_DIR"] == base_dir
+
+        assert "DRAGON_INCLUDE_DIR" in loaded_env
+        assert loaded_env["DRAGON_INCLUDE_DIR"] == f"{base_dir}/include"
+
+        assert "DRAGON_LIB_DIR" in loaded_env
+        assert loaded_env["DRAGON_LIB_DIR"] == f"{base_dir}/lib"
+
+        assert "DRAGON_VERSION" in loaded_env
+        assert loaded_env["DRAGON_VERSION"] == DEFAULT_DRAGON_VERSION
+
+        
assert "PATH" in loaded_env + assert loaded_env["PATH"] == f"{base_dir}/bin" + + assert "LD_LIBRARY_PATH" in loaded_env + assert loaded_env["LD_LIBRARY_PATH"] == f"{base_dir}/lib" def test_load_env_cached_env(monkeypatch: pytest.MonkeyPatch, test_dir: str): @@ -517,7 +539,7 @@ def test_load_env_cached_env(monkeypatch: pytest.MonkeyPatch, test_dir: str): with monkeypatch.context() as ctx: ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path) - create_dotenv(mock_dragon_root) + create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION) # load config w/launcher connector = DragonConnector() @@ -541,7 +563,7 @@ def test_merge_env(monkeypatch: pytest.MonkeyPatch, test_dir: str): with monkeypatch.context() as ctx: ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path) - create_dotenv(mock_dragon_root) + create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION) # load config w/launcher connector = DragonConnector() diff --git a/tests/test_dragon_run_policy.py b/tests/test_dragon_run_policy.py index 1d8d069fab..5e8642c052 100644 --- a/tests/test_dragon_run_policy.py +++ b/tests/test_dragon_run_policy.py @@ -114,9 +114,6 @@ def test_create_run_policy_non_run_request(dragon_request: DragonRequest) -> Non policy = DragonBackend.create_run_policy(dragon_request, "localhost") assert policy is not None, "Default policy was not returned" - assert ( - policy.device == Policy.Device.DEFAULT - ), "Default device was not Device.DEFAULT" assert policy.cpu_affinity == [], "Default cpu affinity was not empty" assert policy.gpu_affinity == [], "Default gpu affinity was not empty" @@ -140,10 +137,8 @@ def test_create_run_policy_run_request_no_run_policy() -> None: policy = DragonBackend.create_run_policy(run_req, "localhost") - assert policy.device == Policy.Device.DEFAULT assert set(policy.cpu_affinity) == set() assert policy.gpu_affinity == [] - assert policy.affinity == Policy.Affinity.DEFAULT @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @@ -167,7 +162,6 @@ def test_create_run_policy_run_request_default_run_policy() -> None: assert set(policy.cpu_affinity) == set() assert set(policy.gpu_affinity) == set() - assert policy.affinity == Policy.Affinity.DEFAULT @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @@ -192,7 +186,6 @@ def test_create_run_policy_run_request_cpu_affinity_no_device() -> None: assert set(policy.cpu_affinity) == affinity assert policy.gpu_affinity == [] - assert policy.affinity == Policy.Affinity.SPECIFIC @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @@ -216,7 +209,6 @@ def test_create_run_policy_run_request_cpu_affinity() -> None: assert set(policy.cpu_affinity) == affinity assert policy.gpu_affinity == [] - assert policy.affinity == Policy.Affinity.SPECIFIC @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @@ -240,7 +232,6 @@ def test_create_run_policy_run_request_gpu_affinity() -> None: assert policy.cpu_affinity == [] assert set(policy.gpu_affinity) == set(affinity) - assert policy.affinity == Policy.Affinity.SPECIFIC @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") diff --git a/tests/test_dragon_run_request.py b/tests/test_dragon_run_request.py index 7514deab19..62ac572eb2 100644 --- a/tests/test_dragon_run_request.py +++ b/tests/test_dragon_run_request.py @@ -30,18 +30,14 @@ import time from unittest.mock import MagicMock +import pydantic.error_wrappers import pytest -from pydantic import 
ValidationError
+
+from smartsim._core.launcher.dragon.pqueue import NodePrioritizer
 
 # The tests in this file belong to the group_b group
 pytestmark = pytest.mark.group_b
 
-
-try:
-    import dragon
-
-    dragon_loaded = True
-except:
-    dragon_loaded = False
+dragon = pytest.importorskip("dragon")
 
 from smartsim._core.config import CONFIG
 from smartsim._core.schemas.dragonRequests import *
 from smartsim._core.schemas.dragonResponses import *
@@ -56,38 +52,6 @@
 )
 
 
-class NodeMock(MagicMock):
-    def __init__(
-        self, name: t.Optional[str] = None, num_gpus: int = 2, num_cpus: int = 8
-    ) -> None:
-        super().__init__()
-        self._mock_id = name
-        NodeMock._num_gpus = num_gpus
-        NodeMock._num_cpus = num_cpus
-
-    @property
-    def hostname(self) -> str:
-        if self._mock_id:
-            return self._mock_id
-        return create_short_id_str()
-
-    @property
-    def num_cpus(self) -> str:
-        return NodeMock._num_cpus
-
-    @property
-    def num_gpus(self) -> str:
-        return NodeMock._num_gpus
-
-    def _set_id(self, value: str) -> None:
-        self._mock_id = value
-
-    def gpus(self, parent: t.Any = None) -> t.List[str]:
-        if self._num_gpus:
-            return [f"{self.hostname}-gpu{i}" for i in range(NodeMock._num_gpus)]
-        return []
-
-
 class GroupStateMock(MagicMock):
     def Running(self) -> MagicMock:
         running = MagicMock(**{"__str__.return_value": "Running"})
@@ -102,59 +66,59 @@ class ProcessGroupMock(MagicMock):
     puids = [121, 122]
 
 
-def node_mock() -> NodeMock:
-    return NodeMock()
-
-
 def get_mock_backend(
-    monkeypatch: pytest.MonkeyPatch, num_gpus: int = 2
+    monkeypatch: pytest.MonkeyPatch, num_cpus: int, num_gpus: int
 ) -> "DragonBackend":
-
+    # create all the necessary namespaces as raw magic mocks
+    monkeypatch.setitem(sys.modules, "dragon.data.ddict.ddict", MagicMock())
+    monkeypatch.setitem(sys.modules, "dragon.native.machine", MagicMock())
+    monkeypatch.setitem(sys.modules, "dragon.native.group_state", MagicMock())
+    monkeypatch.setitem(sys.modules, "dragon.native.process_group", MagicMock())
+    monkeypatch.setitem(sys.modules, "dragon.native.process", MagicMock())
+    monkeypatch.setitem(sys.modules, "dragon.infrastructure.connection", MagicMock())
+    monkeypatch.setitem(sys.modules, "dragon.infrastructure.policy", MagicMock())
+    monkeypatch.setitem(sys.modules, "dragon.infrastructure.process_desc", MagicMock())
+
+    node_list = ["node1", "node2", "node3"]
+    system_mock = MagicMock(return_value=MagicMock(nodes=node_list))
+    node_mock = lambda x: MagicMock(hostname=x, num_cpus=num_cpus, num_gpus=num_gpus)
+    process_group_mock = MagicMock(return_value=ProcessGroupMock())
     process_mock = MagicMock(returncode=0)
-    process_group_mock = MagicMock(**{"Process.return_value": ProcessGroupMock()})
-    process_module_mock = MagicMock()
-    process_module_mock.Process = process_mock
-    node_mock = NodeMock(num_gpus=num_gpus)
-    system_mock = MagicMock(nodes=["node1", "node2", "node3"])
+    policy_mock = MagicMock(return_value=MagicMock())
+    group_state_mock = GroupStateMock()
+
+    # customize members that must perform specific actions within the namespaces
     monkeypatch.setitem(
         sys.modules,
         "dragon",
         MagicMock(
             **{
-                "native.machine.Node.return_value": node_mock,
-                "native.machine.System.return_value": system_mock,
-                "native.group_state": GroupStateMock(),
-                "native.process_group.ProcessGroup.return_value": ProcessGroupMock(),
+                "native.machine.Node": node_mock,
+                "native.machine.System": system_mock,
+                "native.group_state": group_state_mock,
+                "native.process_group.ProcessGroup": process_group_mock,
+                "native.process_group.Process": process_mock,
+                
"native.process.Process": process_mock, + "infrastructure.policy.Policy": policy_mock, } ), ) - monkeypatch.setitem( - sys.modules, - "dragon.infrastructure.connection", - MagicMock(), - ) - monkeypatch.setitem( - sys.modules, - "dragon.infrastructure.policy", - MagicMock(**{"Policy.return_value": MagicMock()}), - ) - monkeypatch.setitem(sys.modules, "dragon.native.process", process_module_mock) - monkeypatch.setitem(sys.modules, "dragon.native.process_group", process_group_mock) - monkeypatch.setitem(sys.modules, "dragon.native.group_state", GroupStateMock()) - monkeypatch.setitem( - sys.modules, - "dragon.native.machine", - MagicMock( - **{"System.return_value": system_mock, "Node.return_value": node_mock} - ), - ) from smartsim._core.launcher.dragon.dragonBackend import DragonBackend dragon_backend = DragonBackend(pid=99999) - monkeypatch.setattr( - dragon_backend, "_free_hosts", collections.deque(dragon_backend._hosts) + + # NOTE: we're manually updating these values due to issue w/mocking namespaces + dragon_backend._prioritizer = NodePrioritizer( + [ + MagicMock(num_cpus=num_cpus, num_gpus=num_gpus, hostname=node) + for node in node_list + ], + dragon_backend._queue_lock, ) + dragon_backend._cpus = [num_cpus] * len(node_list) + dragon_backend._gpus = [num_gpus] * len(node_list) return dragon_backend @@ -212,16 +176,14 @@ def set_mock_group_infos( } monkeypatch.setattr(dragon_backend, "_group_infos", group_infos) - monkeypatch.setattr(dragon_backend, "_free_hosts", collections.deque(hosts[1:3])) - monkeypatch.setattr(dragon_backend, "_allocated_hosts", {hosts[0]: "abc123-1"}) + monkeypatch.setattr(dragon_backend, "_allocated_hosts", {hosts[0]: {"abc123-1"}}) monkeypatch.setattr(dragon_backend, "_running_steps", ["abc123-1"]) return group_infos -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_handshake_request(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) handshake_req = DragonHandshakeRequest() handshake_resp = dragon_backend.process_request(handshake_req) @@ -230,9 +192,8 @@ def test_handshake_request(monkeypatch: pytest.MonkeyPatch) -> None: assert handshake_resp.dragon_pid == 99999 -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -259,9 +220,9 @@ def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None: assert dragon_backend._running_steps == [step_id] assert len(dragon_backend._queued_steps) == 0 - assert len(dragon_backend._free_hosts) == 1 - assert dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id - assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id + assert len(dragon_backend.free_hosts) == 1 + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]] + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]] monkeypatch.setattr( dragon_backend._group_infos[step_id].process_group, "status", "Running" @@ -271,9 +232,9 @@ def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None: assert dragon_backend._running_steps == [step_id] assert len(dragon_backend._queued_steps) == 0 - assert len(dragon_backend._free_hosts) == 1 - assert 
dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id - assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id + assert len(dragon_backend.free_hosts) == 1 + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]] + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]] dragon_backend._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED @@ -281,9 +242,8 @@ def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None: assert not dragon_backend._running_steps -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_deny_run_request(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) dragon_backend._shutdown_requested = True @@ -309,7 +269,7 @@ def test_deny_run_request(monkeypatch: pytest.MonkeyPatch) -> None: def test_run_request_with_empty_policy(monkeypatch: pytest.MonkeyPatch) -> None: """Verify that a policy is applied to a run request""" - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -325,10 +285,9 @@ def test_run_request_with_empty_policy(monkeypatch: pytest.MonkeyPatch) -> None: assert run_req.policy is None -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None: """Verify that a policy is applied to a run request""" - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -356,9 +315,9 @@ def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None: assert dragon_backend._running_steps == [step_id] assert len(dragon_backend._queued_steps) == 0 - assert len(dragon_backend._free_hosts) == 1 - assert dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id - assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id + assert len(dragon_backend._prioritizer.unassigned()) == 1 + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]] + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]] monkeypatch.setattr( dragon_backend._group_infos[step_id].process_group, "status", "Running" @@ -368,9 +327,9 @@ def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None: assert dragon_backend._running_steps == [step_id] assert len(dragon_backend._queued_steps) == 0 - assert len(dragon_backend._free_hosts) == 1 - assert dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id - assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id + assert len(dragon_backend._prioritizer.unassigned()) == 1 + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]] + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]] dragon_backend._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED @@ -378,9 +337,8 @@ def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None: assert not dragon_backend._running_steps -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_udpate_status_request(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = 
get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) group_infos = set_mock_group_infos(monkeypatch, dragon_backend) @@ -395,9 +353,8 @@ def test_udpate_status_request(monkeypatch: pytest.MonkeyPatch) -> None: } -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_stop_request(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) group_infos = set_mock_group_infos(monkeypatch, dragon_backend) running_steps = [ @@ -424,10 +381,9 @@ def test_stop_request(monkeypatch: pytest.MonkeyPatch) -> None: ) assert len(dragon_backend._allocated_hosts) == 0 - assert len(dragon_backend._free_hosts) == 3 + assert len(dragon_backend._prioritizer.unassigned()) == 3 -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize( "immediate, kill_jobs, frontend_shutdown", [ @@ -446,7 +402,7 @@ def test_shutdown_request( frontend_shutdown: bool, ) -> None: monkeypatch.setenv("SMARTSIM_FLAG_TELEMETRY", "0") - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) monkeypatch.setattr(dragon_backend, "_cooldown_period", 1) set_mock_group_infos(monkeypatch, dragon_backend) @@ -486,11 +442,10 @@ def test_shutdown_request( assert dragon_backend._has_cooled_down == kill_jobs -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize("telemetry_flag", ["0", "1"]) def test_cooldown_is_set(monkeypatch: pytest.MonkeyPatch, telemetry_flag: str) -> None: monkeypatch.setenv("SMARTSIM_FLAG_TELEMETRY", telemetry_flag) - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) expected_cooldown = ( 2 * CONFIG.telemetry_frequency + 5 if int(telemetry_flag) > 0 else 5 @@ -502,19 +457,17 @@ def test_cooldown_is_set(monkeypatch: pytest.MonkeyPatch, telemetry_flag: str) - assert dragon_backend.cooldown_period == expected_cooldown -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_heartbeat_and_time(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) first_heartbeat = dragon_backend.last_heartbeat assert dragon_backend.current_time > first_heartbeat dragon_backend._heartbeat() assert dragon_backend.last_heartbeat > first_heartbeat -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize("num_nodes", [1, 3, 100]) def test_can_honor(monkeypatch: pytest.MonkeyPatch, num_nodes: int) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -527,18 +480,42 @@ def test_can_honor(monkeypatch: pytest.MonkeyPatch, num_nodes: int) -> None: pmi_enabled=False, ) - assert dragon_backend._can_honor(run_req)[0] == ( - num_nodes <= len(dragon_backend._hosts) - ) + can_honor, error_msg = dragon_backend._can_honor(run_req) + + nodes_in_range = num_nodes <= len(dragon_backend._hosts) + assert can_honor == nodes_in_range + assert error_msg is None if nodes_in_range else error_msg is not None + + +@pytest.mark.parametrize("num_nodes", [-10, -1, 0]) +def test_can_honor_invalid_num_nodes( + monkeypatch: pytest.MonkeyPatch, num_nodes: int +) -> 
None: + """Verify that requests for invalid numbers of nodes (negative, zero) are rejected""" + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) + + with pytest.raises(pydantic.error_wrappers.ValidationError) as ex: + DragonRunRequest( + exe="sleep", + exe_args=["5"], + path="/a/fake/path", + nodes=num_nodes, + tasks=1, + tasks_per_node=1, + env={}, + current_env={}, + pmi_enabled=False, + ) -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize("affinity", [[0], [0, 1], list(range(8))]) def test_can_honor_cpu_affinity( monkeypatch: pytest.MonkeyPatch, affinity: t.List[int] ) -> None: """Verify that valid CPU affinities are accepted""" - dragon_backend = get_mock_backend(monkeypatch) + num_cpus, num_gpus = 8, 0 + dragon_backend = get_mock_backend(monkeypatch, num_cpus=num_cpus, num_gpus=num_gpus) + run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -555,11 +532,10 @@ def test_can_honor_cpu_affinity( assert dragon_backend._can_honor(run_req)[0] -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_can_honor_cpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> None: """Verify that invalid CPU affinities are NOT accepted NOTE: negative values are captured by the Pydantic schema""" - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -576,13 +552,15 @@ def test_can_honor_cpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> assert not dragon_backend._can_honor(run_req)[0] -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize("affinity", [[0], [0, 1]]) def test_can_honor_gpu_affinity( monkeypatch: pytest.MonkeyPatch, affinity: t.List[int] ) -> None: """Verify that valid GPU affinities are accepted""" - dragon_backend = get_mock_backend(monkeypatch) + + num_cpus, num_gpus = 8, 2 + dragon_backend = get_mock_backend(monkeypatch, num_cpus=num_cpus, num_gpus=num_gpus) + run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -599,11 +577,10 @@ def test_can_honor_gpu_affinity( assert dragon_backend._can_honor(run_req)[0] -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_can_honor_gpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> None: """Verify that invalid GPU affinities are NOT accepted NOTE: negative values are captured by the Pydantic schema""" - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -620,46 +597,45 @@ def test_can_honor_gpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> assert not dragon_backend._can_honor(run_req)[0] -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_can_honor_gpu_device_not_available(monkeypatch: pytest.MonkeyPatch) -> None: """Verify that a request for a GPU if none exists is not accepted""" # create a mock node class that always reports no GPUs available - dragon_backend = get_mock_backend(monkeypatch, num_gpus=0) - - run_req = DragonRunRequest( - exe="sleep", - exe_args=["5"], - path="/a/fake/path", - nodes=2, - tasks=1, - tasks_per_node=1, - env={}, - current_env={}, - pmi_enabled=False, - # specify GPU device w/no affinity - policy=DragonRunPolicy(gpu_affinity=[0]), - ) - - assert not 
dragon_backend._can_honor(run_req)[0] + with monkeypatch.context() as ctx: + dragon_backend = get_mock_backend(ctx, num_cpus=8, num_gpus=0) + + run_req = DragonRunRequest( + exe="sleep", + exe_args=["5"], + path="/a/fake/path", + nodes=2, + tasks=1, + tasks_per_node=1, + env={}, + current_env={}, + pmi_enabled=False, + # specify GPU device w/no affinity + policy=DragonRunPolicy(gpu_affinity=[0]), + ) + can_honor, _ = dragon_backend._can_honor(run_req) + assert not can_honor -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_get_id(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) step_id = next(dragon_backend._step_ids) assert step_id.endswith("0") assert step_id != next(dragon_backend._step_ids) -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_view(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) set_mock_group_infos(monkeypatch, dragon_backend) hosts = dragon_backend.hosts + dragon_backend._prioritizer.increment(hosts[0]) - expected_message = textwrap.dedent(f"""\ + expected_msg = textwrap.dedent(f"""\ Dragon server backend update | Host | Status | |--------|----------| @@ -667,7 +643,7 @@ def test_view(monkeypatch: pytest.MonkeyPatch) -> None: | {hosts[1]} | Free | | {hosts[2]} | Free | | Step | Status | Hosts | Return codes | Num procs | - |----------|--------------|-------------|----------------|-------------| + |----------|--------------|-----------------|----------------|-------------| | abc123-1 | Running | {hosts[0]} | | 1 | | del999-2 | Cancelled | {hosts[1]} | -9 | 1 | | c101vz-3 | Completed | {hosts[1]},{hosts[2]} | 0 | 2 | @@ -676,6 +652,110 @@ def test_view(monkeypatch: pytest.MonkeyPatch) -> None: # get rid of white space to make the comparison easier actual_msg = dragon_backend.status_message.replace(" ", "") - expected_message = expected_message.replace(" ", "") + expected_msg = expected_msg.replace(" ", "") + + # ignore dashes in separators (hostname changes may cause column expansion) + while actual_msg.find("--") > -1: + actual_msg = actual_msg.replace("--", "-") + while expected_msg.find("--") > -1: + expected_msg = expected_msg.replace("--", "-") + + assert actual_msg == expected_msg + + +def test_can_honor_hosts_unavailable_hosts(monkeypatch: pytest.MonkeyPatch) -> None: + """Verify that requesting nodes with invalid names causes number of available + nodes check to fail due to valid # of named nodes being under num_nodes""" + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) + + # let's supply 2 invalid and 1 valid hostname + actual_hosts = list(dragon_backend._hosts) + actual_hosts[0] = f"x{actual_hosts[0]}" + actual_hosts[1] = f"x{actual_hosts[1]}" + + host_list = ",".join(actual_hosts) + + run_req = DragonRunRequest( + exe="sleep", + exe_args=["5"], + path="/a/fake/path", + nodes=2, # <----- requesting 2 of 3 available nodes + hostlist=host_list, # <--- only one valid name available + tasks=1, + tasks_per_node=1, + env={}, + current_env={}, + pmi_enabled=False, + policy=DragonRunPolicy(), + ) + + can_honor, error_msg = dragon_backend._can_honor(run_req) + + # confirm the failure is indicated + assert not can_honor + # confirm failure message indicates number of nodes requested as cause + assert "named hosts" in error_msg + + +def 
test_can_honor_hosts_unavailable_hosts_ok(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Verify that a request is honored when enough of the named hosts are
+    valid, even though some of the requested host names are invalid"""
+    dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
+
+    # let's supply 2 valid and 1 invalid hostname
+    actual_hosts = list(dragon_backend._hosts)
+    actual_hosts[0] = f"x{actual_hosts[0]}"
+
+    host_list = ",".join(actual_hosts)
+
+    run_req = DragonRunRequest(
+        exe="sleep",
+        exe_args=["5"],
+        path="/a/fake/path",
+        nodes=2,  # <----- requesting 2 of 3 available nodes
+        hostlist=host_list,  # <--- two valid names are available
+        tasks=1,
+        tasks_per_node=1,
+        env={},
+        current_env={},
+        pmi_enabled=False,
+        policy=DragonRunPolicy(),
+    )
+
+    can_honor, error_msg = dragon_backend._can_honor(run_req)
+
+    # confirm the request is honored
+    assert can_honor, error_msg
+    # confirm no failure message accompanies a successful check
+    assert error_msg is None, error_msg
+
+
+def test_can_honor_hosts_one_host_requested(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Verify that a request for a single node is honored when the supplied
+    hostlist still contains enough valid host names"""
+    dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
+
+    # let's supply 2 valid and 1 invalid hostname
+    actual_hosts = list(dragon_backend._hosts)
+    actual_hosts[0] = f"x{actual_hosts[0]}"
+
+    host_list = ",".join(actual_hosts)
+
+    run_req = DragonRunRequest(
+        exe="sleep",
+        exe_args=["5"],
+        path="/a/fake/path",
+        nodes=1,  # <----- requesting 1 of the 2 valid named nodes
+        hostlist=host_list,  # <--- two valid names are available
+        tasks=1,
+        tasks_per_node=1,
+        env={},
+        current_env={},
+        pmi_enabled=False,
+        policy=DragonRunPolicy(),
+    )
+
+    can_honor, error_msg = dragon_backend._can_honor(run_req)
 
-    assert actual_msg == expected_message
+    # confirm the request is honored
+    assert can_honor, error_msg
diff --git a/tests/test_dragon_runsettings.py b/tests/test_dragon_runsettings.py
index 34e8510e82..8c7600c74c 100644
--- a/tests/test_dragon_runsettings.py
+++ b/tests/test_dragon_runsettings.py
@@ -96,3 +96,122 @@ def test_dragon_runsettings_gpu_affinity():
     # ensure the value is not changed when we extend the list
     rs.run_args["gpu-affinity"] = "7,8,9"
     assert rs.run_args["gpu-affinity"] != ",".join(str(val) for val in exp_value)
+
+
+def test_dragon_runsettings_hostlist_null():
+    """Verify that passing a null hostlist is treated as a failure"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    # baseline check that no host list exists
+    stored_list = rs.run_args.get("host-list", None)
+    assert stored_list is None
+
+    with pytest.raises(ValueError) as ex:
+        rs.set_hostlist(None)
+
+    assert "empty hostlist" in ex.value.args[0]
+
+
+def test_dragon_runsettings_hostlist_empty():
+    """Verify that passing an empty hostlist is treated as a failure"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    # baseline check that no host list exists
+    stored_list = rs.run_args.get("host-list", None)
+    assert stored_list is None
+
+    with pytest.raises(ValueError) as ex:
+        rs.set_hostlist([])
+
+    assert "empty hostlist" in ex.value.args[0]
+
+
+@pytest.mark.parametrize("hostlist_csv", [" ", " , , , ", ",", ",,,"])
+def test_dragon_runsettings_hostlist_whitespace_handling(hostlist_csv: str):
+    """Verify that passing a hostlist with empty-string host names is treated as a failure"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    # baseline check that no host list exists
+    stored_list = rs.run_args.get("host-list", None)
+    assert stored_list is None
+
+    # empty string as hostname in list
+    with pytest.raises(ValueError) as ex:
+        rs.set_hostlist(hostlist_csv)
+
+    assert "invalid names" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+    "hostlist_csv", [[" "], [" ", "", " ", " "], ["", " "], ["", "", "", ""]]
+)
+def test_dragon_runsettings_hostlist_whitespace_handling_list(hostlist_csv):
+    """Verify that passing a hostlist with empty-string host names contained in a list
+    is treated as a failure"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    # baseline check that no host list exists
+    stored_list = rs.run_args.get("host-list", None)
+    assert stored_list is None
+
+    # empty string as hostname in list
+    with pytest.raises(ValueError) as ex:
+        rs.set_hostlist(hostlist_csv)
+
+    assert "invalid names" in ex.value.args[0]
+
+
+def test_dragon_runsettings_hostlist_as_csv():
+    """Verify that a hostlist is stored properly when passing in a CSV string"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    # baseline check that no host list exists
+    stored_list = rs.run_args.get("host-list", None)
+    assert stored_list is None
+
+    hostnames = ["host0", "host1", "host2", "host3", "host4"]
+
+    # set the host list with ideal comma separated values
+    input0 = ",".join(hostnames)
+
+    # set the host list with a string of comma separated values
+    # including extra whitespace
+    input1 = ", ".join(hostnames)
+
+    for hosts_input in [input0, input1]:
+        rs.set_hostlist(hosts_input)
+
+        stored_list = rs.run_args.get("host-list", None)
+        assert stored_list
+
+        # confirm that all values from the original list are retrieved
+        split_stored_list = stored_list.split(",")
+        assert set(hostnames) == set(split_stored_list)
diff --git a/tests/test_dragon_step.py b/tests/test_dragon_step.py
index 19f408e0bd..f933fb7bc2 100644
--- a/tests/test_dragon_step.py
+++ b/tests/test_dragon_step.py
@@ -73,12 +73,18 @@ def dragon_batch_step(test_dir: str) -> DragonBatchStep:
     cpu_affinities = [[], [0, 1, 2], [], [3, 4, 5, 6]]
     gpu_affinities = [[], [], [0, 1, 2], [3, 4, 5, 6]]
 
+    # specify 3 hostnames to select from but require only 2 nodes
+    num_nodes = 2
+    hostnames = ["host1", "host2", "host3"]
+
     # assign some unique affinities to each run setting instance
     for index, rs in enumerate(settings):
         if gpu_affinities[index]:
             rs.set_node_feature("gpu")
         rs.set_cpu_affinity(cpu_affinities[index])
         rs.set_gpu_affinity(gpu_affinities[index])
+        rs.set_hostlist(hostnames)
+        rs.set_nodes(num_nodes)
 
     steps = list(
        DragonStep(name_, test_dir, rs_) for name_, rs_ in zip(names, settings)
@@ -374,6 +380,11 @@ def test_dragon_batch_step_write_request_file(
     cpu_affinities = [[], [0, 1, 2], [], [3, 4, 5, 6]]
     gpu_affinities = [[], [], [0, 1, 2], [3, 4, 5, 6]]
+    hostnames = ["host1", "host2", "host3"]
+    num_nodes = 2
+
+    # parse the requests file path from the generated launch command
     launch_cmd = dragon_batch_step.get_launch_cmd()
     requests_file = get_request_path_from_batch_script(launch_cmd)
@@ -392,3 +403,5 @@
     assert run_request
     assert run_request.policy.cpu_affinity == cpu_affinities[index]
     assert run_request.policy.gpu_affinity == gpu_affinities[index]
+    assert run_request.nodes == num_nodes
+    assert run_request.hostlist == ",".join(hostnames)
diff --git a/tests/test_message_handler/__init__.py b/tests/test_message_handler/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/test_message_handler/test_build_model.py b/tests/test_message_handler/test_build_model.py
new file mode 100644
index 0000000000..56c1c8764c
--- /dev/null
+++ b/tests/test_message_handler/test_build_model.py
@@ -0,0 +1,72 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
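+
+# Orientation sketch for readers of this module; it simply mirrors the
+# build_model calls exercised by the tests below and is not executed itself:
+#
+#     handler = MessageHandler()
+#     model = handler.build_model(b"model data", "model name", "v0.0.1")
+#     assert model.data == b"model data"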
+ +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + + +def test_build_model_successful(): + expected_data = b"model data" + expected_name = "model name" + expected_version = "v0.0.1" + model = handler.build_model(expected_data, expected_name, expected_version) + assert model.data == expected_data + assert model.name == expected_name + assert model.version == expected_version + + +@pytest.mark.parametrize( + "data, name, version", + [ + pytest.param( + 100, + "model name", + "v0.0.1", + id="bad data type", + ), + pytest.param( + b"model data", + 1, + "v0.0.1", + id="bad name type", + ), + pytest.param( + b"model data", + "model name", + 0.1, + id="bad version type", + ), + ], +) +def test_build_model_unsuccessful(data, name, version): + with pytest.raises(ValueError): + model = handler.build_model(data, name, version) diff --git a/tests/test_message_handler/test_build_model_key.py b/tests/test_message_handler/test_build_model_key.py new file mode 100644 index 0000000000..6c9b3dc951 --- /dev/null +++ b/tests/test_message_handler/test_build_model_key.py @@ -0,0 +1,47 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
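+
+# Orientation sketch (mirrors the tests below, not executed): a model key
+# pairs a string key with the descriptor of the feature store holding it:
+#
+#     key = MessageHandler().build_model_key("model_key", "feature-store-descriptor")
+#     assert key.key == "model_key" and key.descriptor == "feature-store-descriptor"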
+ +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + + +def test_build_model_key_successful(): + fsd = "mock-feature-store-descriptor" + model_key = handler.build_model_key("tensor_key", fsd) + assert model_key.key == "tensor_key" + assert model_key.descriptor == fsd + + +def test_build_model_key_unsuccessful(): + with pytest.raises(ValueError): + fsd = "mock-feature-store-descriptor" + model_key = handler.build_model_key(100, fsd) diff --git a/tests/test_message_handler/test_build_request_attributes.py b/tests/test_message_handler/test_build_request_attributes.py new file mode 100644 index 0000000000..5b1e09b0aa --- /dev/null +++ b/tests/test_message_handler/test_build_request_attributes.py @@ -0,0 +1,55 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + + +def test_build_torch_request_attributes_successful(): + attribute = handler.build_torch_request_attributes("sparse") + assert attribute.tensorType == "sparse" + + +def test_build_torch_request_attributes_unsuccessful(): + with pytest.raises(ValueError): + attribute = handler.build_torch_request_attributes("invalid!") + + +def test_build_tf_request_attributes_successful(): + attribute = handler.build_tf_request_attributes(name="tfcnn", tensor_type="sparse") + assert attribute.tensorType == "sparse" + assert attribute.name == "tfcnn" + + +def test_build_tf_request_attributes_unsuccessful(): + with pytest.raises(ValueError): + attribute = handler.build_tf_request_attributes("tf_fail", "invalid!") diff --git a/tests/test_message_handler/test_build_tensor_desc.py b/tests/test_message_handler/test_build_tensor_desc.py new file mode 100644 index 0000000000..45126fb16c --- /dev/null +++ b/tests/test_message_handler/test_build_tensor_desc.py @@ -0,0 +1,90 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+handler = MessageHandler()
+
+
+@pytest.mark.parametrize(
+    "dtype, order, dimension",
+    [
+        pytest.param(
+            "int8",
+            "c",
+            [3, 2, 5],
+            id="small torch tensor",
+        ),
+        pytest.param(
+            "int64",
+            "c",
+            [1040, 1040, 3],
+            id="medium torch tensor",
+        ),
+    ],
+)
+def test_build_tensor_descriptor_successful(dtype, order, dimension):
+    built_tensor_descriptor = handler.build_tensor_descriptor(order, dtype, dimension)
+    assert built_tensor_descriptor is not None
+    assert built_tensor_descriptor.order == order
+    assert built_tensor_descriptor.dataType == dtype
+    for i, j in zip(built_tensor_descriptor.dimensions, dimension):
+        assert i == j
+
+
+@pytest.mark.parametrize(
+    "order, dtype, dimension",
+    [
+        pytest.param(
+            "bad_order",
+            "int8",
+            [3, 2, 5],
+            id="bad order type",
+        ),
+        pytest.param(
+            "f",
+            "bad_num_type",
+            [3, 2, 5],
+            id="bad numerical type",
+        ),
+        pytest.param(
+            "f",
+            "int8",
+            "bad shape type",
+            id="bad shape type",
+        ),
+    ],
+)
+def test_build_tensor_descriptor_unsuccessful(order, dtype, dimension):
+    with pytest.raises(ValueError):
+        built_tensor_descriptor = handler.build_tensor_descriptor(
+            order, dtype, dimension
+        )
diff --git a/tests/test_message_handler/test_build_tensor_key.py b/tests/test_message_handler/test_build_tensor_key.py
new file mode 100644
index 0000000000..6a28b80c4f
--- /dev/null
+++ b/tests/test_message_handler/test_build_tensor_key.py
@@ -0,0 +1,46 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + + +def test_build_tensor_key_successful(): + fsd = "mock-feature-store-descriptor" + tensor_key = handler.build_tensor_key("tensor_key", fsd) + assert tensor_key.key == "tensor_key" + + +def test_build_tensor_key_unsuccessful(): + with pytest.raises(ValueError): + fsd = "mock-feature-store-descriptor" + tensor_key = handler.build_tensor_key(100, fsd) diff --git a/tests/test_message_handler/test_output_descriptor.py b/tests/test_message_handler/test_output_descriptor.py new file mode 100644 index 0000000000..beb9a47657 --- /dev/null +++ b/tests/test_message_handler/test_output_descriptor.py @@ -0,0 +1,78 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
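+
+# Orientation sketch (mirrors the tests below, not executed): an output tensor
+# descriptor combines a layout order, optional tensor keys, a datatype, and
+# optional dimensions:
+#
+#     desc = MessageHandler().build_output_tensor_descriptor(
+#         "c", [tensor_key], "int8", [1, 2, 3, 4]
+#     )
+#     assert desc.order == "c"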
+ +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + +fsd = "mock-feature-store-descriptor" +tensor_key = handler.build_tensor_key("key", fsd) + + +@pytest.mark.parametrize( + "order, keys, dtype, dimension", + [ + pytest.param("c", [tensor_key], "int8", [1, 2, 3, 4], id="all specified"), + pytest.param( + "c", [tensor_key, tensor_key], "none", [1, 2, 3, 4], id="none dtype" + ), + pytest.param("c", [tensor_key], "int8", [], id="empty dimensions"), + pytest.param("c", [], "int8", [1, 2, 3, 4], id="empty keys"), + ], +) +def test_build_output_tensor_descriptor_successful(dtype, keys, order, dimension): + built_descriptor = handler.build_output_tensor_descriptor( + order, keys, dtype, dimension + ) + assert built_descriptor is not None + assert built_descriptor.order == order + assert len(built_descriptor.optionalKeys) == len(keys) + assert built_descriptor.optionalDatatype == dtype + for i, j in zip(built_descriptor.optionalDimension, dimension): + assert i == j + + +@pytest.mark.parametrize( + "order, keys, dtype, dimension", + [ + pytest.param("bad_order", [], "int8", [3, 2, 5], id="bad order type"), + pytest.param( + "f", [tensor_key], "bad_num_type", [3, 2, 5], id="bad numerical type" + ), + pytest.param("f", [tensor_key], "int8", "bad shape type", id="bad shape type"), + pytest.param("f", ["tensor_key"], "int8", [3, 2, 5], id="bad key type"), + ], +) +def test_build_output_tensor_descriptor_unsuccessful(order, keys, dtype, dimension): + with pytest.raises(ValueError): + built_tensor = handler.build_output_tensor_descriptor( + order, keys, dtype, dimension + ) diff --git a/tests/test_message_handler/test_request.py b/tests/test_message_handler/test_request.py new file mode 100644 index 0000000000..a60818f7dd --- /dev/null +++ b/tests/test_message_handler/test_request.py @@ -0,0 +1,449 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
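+
+# Orientation sketch (mirrors the round-trip tests at the end of this module,
+# not executed): a built request serializes to bytes and deserializes intact:
+#
+#     raw = MessageHandler.serialize_request(torch_direct_request)
+#     assert MessageHandler.deserialize_request(raw).to_dict() == \
+#         torch_direct_request.to_dict()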
+ +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +fsd = "mock-feature-store-descriptor" + +model_key = MessageHandler.build_model_key("model_key", fsd) +model = MessageHandler.build_model(b"model data", "model_name", "v0.0.1") + +input_key1 = MessageHandler.build_tensor_key("input_key1", fsd) +input_key2 = MessageHandler.build_tensor_key("input_key2", fsd) + +output_key1 = MessageHandler.build_tensor_key("output_key1", fsd) +output_key2 = MessageHandler.build_tensor_key("output_key2", fsd) + +output_descriptor1 = MessageHandler.build_output_tensor_descriptor( + "c", [output_key1, output_key2], "int64", [] +) +output_descriptor2 = MessageHandler.build_output_tensor_descriptor("f", [], "auto", []) +output_descriptor3 = MessageHandler.build_output_tensor_descriptor( + "c", [output_key1], "none", [1, 2, 3] +) +torch_attributes = MessageHandler.build_torch_request_attributes("sparse") +tf_attributes = MessageHandler.build_tf_request_attributes( + name="tf", tensor_type="sparse" +) + +tensor_1 = MessageHandler.build_tensor_descriptor("c", "int8", [1]) +tensor_2 = MessageHandler.build_tensor_descriptor("c", "int64", [3, 2]) +tensor_3 = MessageHandler.build_tensor_descriptor("f", "int8", [1]) +tensor_4 = MessageHandler.build_tensor_descriptor("f", "int64", [3, 2]) + + +tf_indirect_request = MessageHandler.build_request( + b"reply", + model, + [input_key1, input_key2], + [output_key1, output_key2], + [output_descriptor1, output_descriptor2, output_descriptor3], + tf_attributes, +) + +tf_direct_request = MessageHandler.build_request( + b"reply", + model, + [tensor_3, tensor_4], + [], + [output_descriptor1, output_descriptor2], + tf_attributes, +) + +torch_indirect_request = MessageHandler.build_request( + b"reply", + model, + [input_key1, input_key2], + [output_key1, output_key2], + [output_descriptor1, output_descriptor2, output_descriptor3], + torch_attributes, +) + +torch_direct_request = MessageHandler.build_request( + b"reply", + model, + [tensor_1, tensor_2], + [], + [output_descriptor1, output_descriptor2], + torch_attributes, +) + + +@pytest.mark.parametrize( + "reply_channel, model, input, output, output_descriptors, custom_attributes", + [ + pytest.param( + "reply channel", + model_key, + [input_key1, input_key2], + [output_key1, output_key2], + [output_descriptor1], + torch_attributes, + ), + pytest.param( + "another reply channel", + model, + [input_key1], + [output_key2], + [output_descriptor1], + tf_attributes, + ), + pytest.param( + "another reply channel", + model, + [input_key1], + [output_key2], + [output_descriptor1], + torch_attributes, + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + [output_key1], + [output_descriptor1], + None, + ), + ], +) +def test_build_request_indirect_successful( + reply_channel, model, input, output, output_descriptors, custom_attributes +): + built_request = MessageHandler.build_request( + reply_channel, + model, + input, + output, + output_descriptors, + custom_attributes, + ) + assert built_request is not None + assert built_request.replyChannel.descriptor == reply_channel + if built_request.model.which() == "key": + assert built_request.model.key.key == model.key + else: + assert built_request.model.data.data == model.data + assert built_request.model.data.name == model.name + assert built_request.model.data.version == model.version + assert built_request.input.which() == "keys" + assert 
built_request.input.keys[0].key == input[0].key + assert len(built_request.input.keys) == len(input) + assert len(built_request.output) == len(output) + for i, j in zip(built_request.outputDescriptors, output_descriptors): + assert i.order == j.order + if built_request.customAttributes.which() == "tf": + assert ( + built_request.customAttributes.tf.tensorType == custom_attributes.tensorType + ) + elif built_request.customAttributes.which() == "torch": + assert ( + built_request.customAttributes.torch.tensorType + == custom_attributes.tensorType + ) + else: + assert built_request.customAttributes.none == custom_attributes + + +@pytest.mark.parametrize( + "reply_channel, model, input, output, output_descriptors, custom_attributes", + [ + pytest.param( + [], + model_key, + [input_key1, input_key2], + [output_key1, output_key2], + [output_descriptor1], + tf_attributes, + id="bad channel", + ), + pytest.param( + "reply channel", + "bad model", + [input_key1], + [output_key2], + [output_descriptor1], + torch_attributes, + id="bad model", + ), + pytest.param( + "reply channel", + model_key, + ["input_key1", "input_key2"], + [output_key1, output_key2], + [output_descriptor1], + tf_attributes, + id="bad inputs", + ), + pytest.param( + "reply channel", + model_key, + [torch_attributes], + [output_key1, output_key2], + [output_descriptor1], + torch_attributes, + id="bad input schema type", + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + ["output_key1", "output_key2"], + [output_descriptor1], + tf_attributes, + id="bad outputs", + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + [torch_attributes], + [output_descriptor1], + tf_attributes, + id="bad output schema type", + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + [output_key1, output_key2], + [output_descriptor1], + "bad attributes", + id="bad custom attributes", + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + [output_key1, output_key2], + [output_descriptor1], + model_key, + id="bad custom attributes schema type", + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + [output_key1, output_key2], + "bad descriptors", + torch_attributes, + id="bad output descriptors", + ), + ], +) +def test_build_request_indirect_unsuccessful( + reply_channel, model, input, output, output_descriptors, custom_attributes +): + with pytest.raises(ValueError): + built_request = MessageHandler.build_request( + reply_channel, + model, + input, + output, + output_descriptors, + custom_attributes, + ) + + +@pytest.mark.parametrize( + "reply_channel, model, input, output, output_descriptors, custom_attributes", + [ + pytest.param( + "reply channel", + model_key, + [tensor_1, tensor_2], + [], + [output_descriptor2], + torch_attributes, + ), + pytest.param( + "another reply channel", + model, + [tensor_1], + [], + [output_descriptor3], + tf_attributes, + ), + pytest.param( + "another reply channel", + model, + [tensor_2], + [], + [output_descriptor1], + tf_attributes, + ), + pytest.param( + "another reply channel", + model, + [tensor_1], + [], + [output_descriptor1], + None, + ), + ], +) +def test_build_request_direct_successful( + reply_channel, model, input, output, output_descriptors, custom_attributes +): + built_request = MessageHandler.build_request( + reply_channel, + model, + input, + output, + output_descriptors, + custom_attributes, + ) + assert built_request is not None + assert built_request.replyChannel.descriptor == reply_channel + if 
built_request.model.which() == "key": + assert built_request.model.key.key == model.key + else: + assert built_request.model.data.data == model.data + assert built_request.model.data.name == model.name + assert built_request.model.data.version == model.version + assert built_request.input.which() == "descriptors" + assert len(built_request.input.descriptors) == len(input) + assert len(built_request.output) == len(output) + for i, j in zip(built_request.outputDescriptors, output_descriptors): + assert i.order == j.order + if built_request.customAttributes.which() == "tf": + assert ( + built_request.customAttributes.tf.tensorType == custom_attributes.tensorType + ) + elif built_request.customAttributes.which() == "torch": + assert ( + built_request.customAttributes.torch.tensorType + == custom_attributes.tensorType + ) + else: + assert built_request.customAttributes.none == custom_attributes + + +@pytest.mark.parametrize( + "reply_channel, model, input, output, output_descriptors, custom_attributes", + [ + pytest.param( + [], + model_key, + [tensor_3, tensor_4], + [], + [output_descriptor2], + tf_attributes, + id="bad channel", + ), + pytest.param( + b"reply channel", + "bad model", + [tensor_4], + [], + [output_descriptor2], + tf_attributes, + id="bad model", + ), + pytest.param( + b"reply channel", + model_key, + ["input_key1", "input_key2"], + [], + [output_descriptor2], + torch_attributes, + id="bad inputs", + ), + pytest.param( + b"reply channel", + model_key, + [], + ["output_key1", "output_key2"], + [output_descriptor2], + tf_attributes, + id="bad outputs", + ), + pytest.param( + b"reply channel", + model_key, + [tensor_4], + [], + [output_descriptor2], + "bad attributes", + id="bad custom attributes", + ), + pytest.param( + b"reply_channel", + model_key, + [tensor_3, tensor_4], + [], + ["output_descriptor2"], + torch_attributes, + id="bad output descriptors", + ), + ], +) +def test_build_request_direct_unsuccessful( + reply_channel, model, input, output, output_descriptors, custom_attributes +): + with pytest.raises(ValueError): + built_request = MessageHandler.build_request( + reply_channel, + model, + input, + output, + output_descriptors, + custom_attributes, + ) + + +@pytest.mark.parametrize( + "req", + [ + pytest.param(tf_indirect_request, id="tf indirect"), + pytest.param(tf_direct_request, id="tf direct"), + pytest.param(torch_indirect_request, id="indirect"), + pytest.param(torch_direct_request, id="direct"), + ], +) +def test_serialize_request_successful(req): + serialized = MessageHandler.serialize_request(req) + assert type(serialized) == bytes + + deserialized = MessageHandler.deserialize_request(serialized) + assert deserialized.to_dict() == req.to_dict() + + +def test_serialization_fails(): + with pytest.raises(ValueError): + bad_request = MessageHandler.serialize_request(tensor_1) + + +def test_deserialization_fails(): + with pytest.raises(ValueError): + new_req = torch_direct_request.copy() + req_bytes = MessageHandler.serialize_request(new_req) + req_bytes = req_bytes + b"extra bytes" + deser = MessageHandler.deserialize_request(req_bytes) diff --git a/tests/test_message_handler/test_response.py b/tests/test_message_handler/test_response.py new file mode 100644 index 0000000000..86774132ec --- /dev/null +++ b/tests/test_message_handler/test_response.py @@ -0,0 +1,191 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +fsd = "mock-feature-store-descriptor" + +result_key1 = MessageHandler.build_tensor_key("result_key1", fsd) +result_key2 = MessageHandler.build_tensor_key("result_key2", fsd) + +torch_attributes = MessageHandler.build_torch_response_attributes() +tf_attributes = MessageHandler.build_tf_response_attributes() + +tensor1 = MessageHandler.build_tensor_descriptor("c", "int8", [1]) +tensor2 = MessageHandler.build_tensor_descriptor("c", "int64", [3, 2]) + + +tf_indirect_response = MessageHandler.build_response( + "complete", + "Success!", + [result_key1, result_key2], + tf_attributes, +) + +tf_direct_response = MessageHandler.build_response( + "complete", + "Success again!", + [tensor2, tensor1], + tf_attributes, +) + +torch_indirect_response = MessageHandler.build_response( + "complete", + "Success!", + [result_key1, result_key2], + torch_attributes, +) + +torch_direct_response = MessageHandler.build_response( + "complete", + "Success again!", + [tensor1, tensor2], + torch_attributes, +) + + +@pytest.mark.parametrize( + "status, status_message, result, custom_attribute", + [ + pytest.param( + 200, + "Yay, it worked!", + [tensor1, tensor2], + None, + id="tensor descriptor list", + ), + pytest.param( + 200, + "Yay, it worked!", + [result_key1, result_key2], + tf_attributes, + id="tensor key list", + ), + ], +) +def test_build_response_successful(status, status_message, result, custom_attribute): + response = MessageHandler.build_response( + status=status, + message=status_message, + result=result, + custom_attributes=custom_attribute, + ) + assert response is not None + assert response.status == status + assert response.message == status_message + if response.result.which() == "keys": + assert response.result.keys[0].to_dict() == result[0].to_dict() + else: + assert response.result.descriptors[0].to_dict() == result[0].to_dict() + + +@pytest.mark.parametrize( + "status, status_message, result, custom_attribute", + [ + pytest.param( + "bad status", + "Yay, it worked!", + [tensor1, tensor2], + None, + id="bad status", + ), + pytest.param( + "complete", + 
200, + [tensor2], + torch_attributes, + id="bad status message", + ), + pytest.param( + "complete", + "Yay, it worked!", + ["result_key1", "result_key2"], + tf_attributes, + id="bad result", + ), + pytest.param( + "complete", + "Yay, it worked!", + [tf_attributes], + tf_attributes, + id="bad result type", + ), + pytest.param( + "complete", + "Yay, it worked!", + [tensor2, tensor1], + "custom attributes", + id="bad custom attributes", + ), + pytest.param( + "complete", + "Yay, it worked!", + [tensor2, tensor1], + result_key1, + id="bad custom attributes type", + ), + ], +) +def test_build_response_unsuccessful(status, status_message, result, custom_attribute): + with pytest.raises(ValueError): + response = MessageHandler.build_response( + status, status_message, result, custom_attribute + ) + + +@pytest.mark.parametrize( + "response", + [ + pytest.param(torch_indirect_response, id="indirect"), + pytest.param(torch_direct_response, id="direct"), + pytest.param(tf_indirect_response, id="tf indirect"), + pytest.param(tf_direct_response, id="tf direct"), + ], +) +def test_serialize_response(response): + serialized = MessageHandler.serialize_response(response) + assert type(serialized) == bytes + + deserialized = MessageHandler.deserialize_response(serialized) + assert deserialized.to_dict() == response.to_dict() + + +def test_serialization_fails(): + with pytest.raises(ValueError): + bad_response = MessageHandler.serialize_response(result_key1) + + +def test_deserialization_fails(): + with pytest.raises(ValueError): + new_resp = torch_direct_response.copy() + resp_bytes = MessageHandler.serialize_response(new_resp) + resp_bytes = resp_bytes + b"extra bytes" + deser = MessageHandler.deserialize_response(resp_bytes) diff --git a/tests/test_node_prioritizer.py b/tests/test_node_prioritizer.py new file mode 100644 index 0000000000..905c0ecc90 --- /dev/null +++ b/tests/test_node_prioritizer.py @@ -0,0 +1,553 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
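+
+# NodePrioritizer keeps a per-node reference count and hands out the
+# least-referenced "clean" nodes first: `increment`/`decrement` adjust a
+# node's ref count directly, `next`/`next_n` reserve nodes (indirectly
+# incrementing them), and `PrioritizerFilter` selects the CPU or GPU pool.
+# A minimal usage sketch based on the API exercised below (the node list
+# and the scheduling step are hypothetical):
+#
+#     lock = threading.RLock()
+#     prioritizer = NodePrioritizer(nodes, lock)
+#     node = prioritizer.next(PrioritizerFilter.CPU)
+#     if node is not None:
+#         ...  # schedule work on node.hostname
+#         prioritizer.decrement(node.hostname)  # release when finished
+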
+import random
+import threading
+import typing as t
+
+import pytest
+
+from smartsim._core.launcher.dragon.pqueue import NodePrioritizer, PrioritizerFilter
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+# The tests in this file belong to the group_b group
+pytestmark = pytest.mark.group_b
+
+
+logger = get_logger(__name__)
+
+
+class MockNode:
+    def __init__(self, hostname: str, num_cpus: int, num_gpus: int) -> None:
+        self.hostname = hostname
+        self.num_cpus = num_cpus
+        self.num_gpus = num_gpus
+
+
+def mock_node_hosts(
+    num_cpu_nodes: int, num_gpu_nodes: int
+) -> t.Tuple[t.List[str], t.List[str]]:
+    cpu_hosts = [f"cpu-node-{i}" for i in range(num_cpu_nodes)]
+    gpu_hosts = [f"gpu-node-{i}" for i in range(num_gpu_nodes)]
+
+    return cpu_hosts, gpu_hosts
+
+
+def mock_node_builder(num_cpu_nodes: int, num_gpu_nodes: int) -> t.List[MockNode]:
+    nodes = []
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+
+    nodes.extend(MockNode(hostname, 4, 0) for hostname in cpu_hosts)
+    nodes.extend(MockNode(hostname, 4, 4) for hostname in gpu_hosts)
+
+    return nodes
+
+
+def test_node_prioritizer_init_null() -> None:
+    """Verify that the prioritizer raises an error instead of returning a
+    valid node set when a null node list is passed"""
+    lock = threading.RLock()
+    with pytest.raises(SmartSimError) as ex:
+        NodePrioritizer(None, lock)
+
+    assert "Missing" in ex.value.args[0]
+
+
+def test_node_prioritizer_init_empty() -> None:
+    """Verify that the prioritizer raises an error instead of returning a
+    valid node set when an empty node list is passed"""
+    lock = threading.RLock()
+    with pytest.raises(SmartSimError) as ex:
+        NodePrioritizer([], lock)
+
+    assert "Missing" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+    "num_cpu_nodes,num_gpu_nodes", [(1, 1), (2, 1), (1, 2), (8, 4), (1000, 200)]
+)
+def test_node_prioritizer_init_ok(num_cpu_nodes: int, num_gpu_nodes: int) -> None:
+    """Verify that initialization with a valid node list results in the
+    appropriate cpu & gpu ref counts, and a complete ref map"""
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    # perform prioritizer initialization
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # get a copy of all the expected host names
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    all_hosts = cpu_hosts + gpu_hosts
+    assert len(all_hosts) == num_cpu_nodes + num_gpu_nodes
+
+    # verify tracking data is initialized correctly for all nodes
+    for hostname in all_hosts:
+        # show that the ref map is tracking the node
+        assert hostname in p._nodes
+
+        tracking_info = p.get_tracking_info(hostname)
+
+        # show that the node is created w/zero ref counts
+        assert tracking_info.num_refs == 0
+
+        # show that the node is created and marked as not dirty (unchanged)
+        # assert tracking_info.is_dirty == False
+
+    # iterate through known cpu node keys and verify prioritizer initialization
+    for hostname in cpu_hosts:
+        # show that the device ref counters are appropriately assigned
+        cpu_ref = next((n for n in p._cpu_refs if n.hostname == hostname), None)
+        assert cpu_ref, "CPU-only node not found in cpu ref set"
+
+        gpu_ref = next((n for n in p._gpu_refs if n.hostname == hostname), None)
+        assert not gpu_ref, "CPU-only node should not be found in gpu ref set"
+
+    # iterate through known GPU node keys and verify prioritizer initialization
+    for hostname in gpu_hosts:
+        # show that the device ref counters are appropriately assigned
+        gpu_ref = next((n for n in p._gpu_refs if n.hostname == hostname), None)
+        assert gpu_ref, "GPU node not found in gpu ref set"
+
+        cpu_ref = next((n for n in p._cpu_refs if n.hostname == hostname), None)
+        assert not cpu_ref, "GPU node should not be found in cpu ref set"
+
+    # verify we have all hosts in the ref map
+    assert set(p._nodes.keys()) == set(all_hosts)
+
+    # verify we have no extra hosts in the ref map
+    assert len(p._nodes.keys()) == len(set(all_hosts))
+
+
+def test_node_prioritizer_direct_increment() -> None:
+    """Verify that performing the increment operation causes the expected
+    side effect on the intended records"""
+
+    num_cpu_nodes, num_gpu_nodes = 32, 8
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    exclude_index = 2
+    exclude_host0 = cpu_hosts[exclude_index]
+    exclude_host1 = gpu_hosts[exclude_index]
+    exclusions = [exclude_host0, exclude_host1]
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # increment each element in a predictable way and verify
+    for node in nodes:
+        if node.hostname in exclusions:
+            # expect 1 cpu and 1 gpu node at zero and not incremented
+            continue
+
+        if node.num_gpus == 0:
+            num_increments = random.randint(0, num_cpu_nodes - 1)
+        else:
+            num_increments = random.randint(0, num_gpu_nodes - 1)
+
+        # increment this node some random number of times
+        for _ in range(num_increments):
+            p.increment(node.hostname)
+
+        # ... and verify the correct incrementing is applied
+        tracking_info = p.get_tracking_info(node.hostname)
+        assert tracking_info.num_refs == num_increments
+
+    # verify the excluded cpu node was never changed
+    tracking_info0 = p.get_tracking_info(exclude_host0)
+    assert tracking_info0.num_refs == 0
+
+    # verify the excluded gpu node was never changed
+    tracking_info1 = p.get_tracking_info(exclude_host1)
+    assert tracking_info1.num_refs == 0
+
+
+def test_node_prioritizer_indirect_increment() -> None:
+    """Verify that performing the `next` operation indirectly increments
+    each available node until we run out of nodes to return"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 0
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # verify starting state
+    for node in p._nodes.values():
+        tracking_info = p.get_tracking_info(node.hostname)
+
+        assert node.num_refs == 0  # <--- ref count starts at zero
+        assert tracking_info.num_refs == 0  # <--- ref count starts at zero
+
+    # perform indirect increments via `next` until the pool is exhausted
+    for _ in range(len(p._nodes)):
+        # apply the `next` operation and verify tracking info reflects the new ref
+        node = p.next(PrioritizerFilter.CPU)
+        tracking_info = p.get_tracking_info(node.hostname)
+
+        # verify side-effects
+        assert tracking_info.num_refs > 0  # <--- ref count should now be > 0
+
+        # we expect it to give back only "clean" nodes from next*
+        assert not tracking_info.is_dirty  # NOTE: this is "hidden" by protocol
+
+    # every node has been incremented; the prioritizer has nothing left to give
+    tracking_info = p.next(PrioritizerFilter.CPU)
+    assert tracking_info is None  # <--- `next` shouldn't have any nodes to give
+
+
+def test_node_prioritizer_indirect_decrement_availability() -> None:
+    """Verify that a node that is decremented (dirty) is made assignable
+    again on a subsequent request"""
+
+    num_cpu_nodes, num_gpu_nodes = 1, 0
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # increment our only node...
+    p.increment(cpu_hosts[0])
+
+    tracking_info = p.next()
+    assert tracking_info is None, "No nodes should be assignable"
+
+    # perform a decrement...
+    p.decrement(cpu_hosts[0])
+
+    # ... and confirm that the node is available again
+    tracking_info = p.next()
+    assert tracking_info is not None, "A node should be assignable"
+
+
+def test_node_prioritizer_multi_increment() -> None:
+    """Verify that retrieving multiple nodes via the `next_n` API correctly
+    increments reference counts and returns appropriate results"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 0
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # mark some nodes as assigned to verify retrieval skips them
+    p.increment(cpu_hosts[0])
+    assert p.get_tracking_info(cpu_hosts[0]).num_refs > 0
+
+    p.increment(cpu_hosts[2])
+    assert p.get_tracking_info(cpu_hosts[2]).num_refs > 0
+
+    p.increment(cpu_hosts[4])
+    assert p.get_tracking_info(cpu_hosts[4]).num_refs > 0
+
+    # use next_n w/the minimum allowed value
+    all_tracking_info = p.next_n(1, PrioritizerFilter.CPU)  # <---- next_n(1)
+
+    # confirm the number requested is honored
+    assert len(all_tracking_info) == 1
+    # ensure no unavailable node is returned
+    assert all_tracking_info[0].hostname not in [
+        cpu_hosts[0],
+        cpu_hosts[2],
+        cpu_hosts[4],
+    ]
+
+    # use next_n w/a value that exceeds the available number of open nodes;
+    # 3 direct increments in setup, 1 out of next_n(1), 4 left
+    all_tracking_info = p.next_n(5, PrioritizerFilter.CPU)
+
+    # confirm that no nodes are returned, even though 4 of the 5 requested are available
+    assert len(all_tracking_info) == 0
+
+
+def test_node_prioritizer_multi_increment_validate_n() -> None:
+    """Verify that retrieving multiple nodes via the `next_n` API correctly
+    reports failures when the request size is above the pool size"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 0
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # we have 8 total cpu nodes available... request too many nodes
+    all_tracking_info = p.next_n(9, PrioritizerFilter.CPU)
+    assert len(all_tracking_info) == 0
+
+    all_tracking_info = p.next_n(num_cpu_nodes * 1000, PrioritizerFilter.CPU)
+    assert len(all_tracking_info) == 0
+
+
+def test_node_prioritizer_indirect_direct_interleaved_increments() -> None:
+    """Verify that interleaving indirect and direct increments results in
+    expected ref counts"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 4
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # perform some set of non-popped increments
+    p.increment(gpu_hosts[1])
+    p.increment(gpu_hosts[3])
+    p.increment(gpu_hosts[3])
+
+    # increment 0th item 1x
+    p.increment(cpu_hosts[0])
+
+    # increment 3rd item 2x
+    p.increment(cpu_hosts[3])
+    p.increment(cpu_hosts[3])
+
+    # increment last item 3x
+    p.increment(cpu_hosts[7])
+    p.increment(cpu_hosts[7])
+    p.increment(cpu_hosts[7])
+
+    tracking_info = p.get_tracking_info(gpu_hosts[1])
+    assert tracking_info.num_refs == 1
+
+    tracking_info = p.get_tracking_info(gpu_hosts[3])
+    assert tracking_info.num_refs == 2
+
+    # we should skip the 0th item in the heap due to the direct increment
+    tracking_info = p.next(PrioritizerFilter.CPU)
+    assert tracking_info.num_refs == 1
+    # confirm we get a cpu node
+    assert "cpu-node" in tracking_info.hostname
+
+    # this should pull the next item right out
+    tracking_info = p.next(PrioritizerFilter.CPU)
+    assert tracking_info.num_refs == 1
+    assert "cpu-node" in tracking_info.hostname
+
+    # ensure we pull from gpu nodes and the 0th item is returned
+    tracking_info = p.next(PrioritizerFilter.GPU)
+    assert tracking_info.num_refs == 1
+    assert "gpu-node" in tracking_info.hostname
+
+    # we should step over the 3rd node on this iteration
+    tracking_info = p.next(PrioritizerFilter.CPU)
+    assert tracking_info.num_refs == 1
+    assert "cpu-node" in tracking_info.hostname
+
+    # and ensure that the heap also steps over a direct increment
+    tracking_info = p.next(PrioritizerFilter.GPU)
+    assert tracking_info.num_refs == 1
+    assert "gpu-node" in tracking_info.hostname
+
+    # and another GPU request should return nothing
+    tracking_info = p.next(PrioritizerFilter.GPU)
+    assert tracking_info is None
+
+
+def test_node_prioritizer_decrement_floor() -> None:
+    """Verify that repeatedly decrementing ref counts does not
+    allow negative ref counts"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 4
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # try a ton of decrements on all the items in the prioritizer
+    for _ in range(len(nodes) * 100):
+        index = random.randint(0, num_cpu_nodes - 1)
+        p.decrement(cpu_hosts[index])
+
+        index = random.randint(0, num_gpu_nodes - 1)
+        p.decrement(gpu_hosts[index])
+
+    for node in nodes:
+        tracking_info = p.get_tracking_info(node.hostname)
+        assert tracking_info.num_refs == 0
+
+
+@pytest.mark.parametrize("num_requested", [1, 2, 3])
+def test_node_prioritizer_multi_increment_subheap(num_requested: int) -> None:
+    """Verify that retrieving multiple nodes via the `next_n` API correctly
+    increments reference counts and returns appropriate results
+    when requesting an in-bounds number of nodes"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 0
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # mark some nodes as assigned to verify retrieval skips them
+    p.increment(cpu_hosts[0])
+    p.increment(cpu_hosts[2])
+    p.increment(cpu_hosts[4])
+
+    hostnames = [cpu_hosts[0], cpu_hosts[1], cpu_hosts[2], cpu_hosts[3], cpu_hosts[5]]
+
+    # request num_requested nodes from the given subset of hosts
+    all_tracking_info = p.next_n(
+        num_requested,
+        hosts=hostnames,
+    )  # <---- w/0,2,4 assigned, only 1,3,5 from hostnames can work
+
+    # all parameterizations should result in a matching output size
+    assert len(all_tracking_info) == num_requested
+
+
+def test_node_prioritizer_multi_increment_subheap_assigned() -> None:
+    """Verify that retrieving multiple nodes via the `next_n` API does
+    not return anything when the number requested cannot be satisfied
+    by the given subheap due to prior assignment"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 0
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # mark some nodes as assigned to verify retrieval skips them
+    p.increment(cpu_hosts[0])
+    p.increment(cpu_hosts[2])
+
+    hostnames = [
+        cpu_hosts[0],
+        "x" + cpu_hosts[2],
+    ]  # <--- we can't get 2 nodes from 1 valid (and already assigned) node name
+
+    # request more nodes than the given host list can satisfy
+    num_requested = 2
+    all_tracking_info = p.next_n(num_requested, hosts=hostnames)
+
+    # w/0,2 assigned, nothing can be returned
+    assert len(all_tracking_info) == 0
+
+
+def test_node_prioritizer_empty_subheap_next_w_no_hosts() -> None:
+    """Verify that retrieving a node via the `next` API with an empty host
+    list falls back to the entire set of available hosts"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 0
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # mark some nodes as assigned to verify retrieval skips them
+    p.increment(cpu_hosts[0])
+    p.increment(cpu_hosts[2])
+
+    hostnames = []
+
+    # request a node without constraining the host list
+    node = p.next(hosts=hostnames)
+    assert node
+
+
+def test_node_prioritizer_empty_subheap_next_n_w_no_hosts() -> None:
+    """Verify that retrieving multiple nodes via the `next_n` API does
+    not blow up when given an empty host list"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 0
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # mark some nodes as assigned to verify retrieval skips them
+    p.increment(cpu_hosts[0])
+    p.increment(cpu_hosts[2])
+
+    hostnames = []
+
+    # request a node without constraining the host list
+    num_requested = 1
+    node = p.next_n(num_requested, hosts=hostnames)
+    assert node is not None
+
+
+@pytest.mark.parametrize("num_requested", [-100, -1, 0])
+def test_node_prioritizer_empty_subheap_next_n(num_requested: int) -> None:
+    """Verify that the `next_n` API rejects a request with num_items < 1"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 0
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # mark some nodes as assigned
+    p.increment(cpu_hosts[0])
+    p.increment(cpu_hosts[2])
+
+    # request an invalid number of nodes
+    with pytest.raises(ValueError) as ex:
+        p.next_n(num_requested)
+
+    assert "Number of items requested" in ex.value.args[0]
+
+
+@pytest.mark.parametrize("num_requested", [-100, -1, 0])
+def test_node_prioritizer_empty_subheap_next_n_w_hosts(num_requested: int) -> None:
+    """Verify that the `next_n` API rejects a request with num_items < 1
+    when an explicit host list is supplied"""
+
+    num_cpu_nodes, num_gpu_nodes = 8, 0
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+    nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+    lock = threading.RLock()
+    p = NodePrioritizer(nodes, lock)
+
+    # mark some nodes as assigned
+    p.increment(cpu_hosts[0])
+    p.increment(cpu_hosts[2])
+
+    hostnames = [cpu_hosts[0], cpu_hosts[2]]
+
+    # request an invalid number of nodes
+    with pytest.raises(ValueError) as ex:
+        p.next_n(num_requested, hosts=hostnames)
+
+    assert "Number of items requested" in ex.value.args[0]