From 8a19dee1fd37be9e04c14a9d486c3314a676dbe3 Mon Sep 17 00:00:00 2001 From: Chris McBride Date: Sat, 19 Oct 2024 14:08:06 -0400 Subject: [PATCH 1/5] Merge MLI feature branch into v1.0 branch (#754) Create a v1.0 branch to combine ongoing efforts in `mli-feature` and `smartsim-refactor` feature branches --------- Co-authored-by: Alyssa Cote <46540273+AlyssaCote@users.noreply.github.com> Co-authored-by: Al Rigazzi --- .github/workflows/run_tests.yml | 27 +- Makefile | 13 +- conftest.py | 2 +- doc/changelog.md | 33 + doc/installation_instructions/basic.rst | 14 + ex/high_throughput_inference/mli_driver.py | 77 ++ ex/high_throughput_inference/mock_app.py | 142 +++ .../mock_app_redis.py | 90 ++ ex/high_throughput_inference/redis_driver.py | 66 ++ .../standalone_worker_manager.py | 218 +++++ pyproject.toml | 1 + setup.py | 2 + smartsim/_core/_cli/build.py | 54 +- smartsim/_core/_cli/scripts/dragon_install.py | 368 ++++++-- smartsim/_core/_install/builder.py | 8 +- smartsim/_core/config/config.py | 35 +- smartsim/_core/entrypoints/service.py | 185 ++++ .../_core/launcher/dragon/dragonBackend.py | 400 +++++++-- .../_core/launcher/dragon/dragonConnector.py | 85 +- .../_core/launcher/dragon/dragonLauncher.py | 2 + smartsim/_core/launcher/dragon/pqueue.py | 461 ++++++++++ smartsim/_core/launcher/step/dragonStep.py | 2 + smartsim/_core/mli/__init__.py | 0 smartsim/_core/mli/client/__init__.py | 0 smartsim/_core/mli/client/protoclient.py | 348 ++++++++ smartsim/_core/mli/comm/channel/__init__.py | 0 smartsim/_core/mli/comm/channel/channel.py | 82 ++ .../_core/mli/comm/channel/dragon_channel.py | 127 +++ smartsim/_core/mli/comm/channel/dragon_fli.py | 158 ++++ .../_core/mli/comm/channel/dragon_util.py | 131 +++ smartsim/_core/mli/infrastructure/__init__.py | 0 .../_core/mli/infrastructure/comm/__init__.py | 0 .../mli/infrastructure/comm/broadcaster.py | 239 +++++ .../_core/mli/infrastructure/comm/consumer.py | 281 ++++++ .../_core/mli/infrastructure/comm/event.py | 162 ++++ 
.../_core/mli/infrastructure/comm/producer.py | 44 + .../mli/infrastructure/control/__init__.py | 0 .../infrastructure/control/device_manager.py | 166 ++++ .../infrastructure/control/error_handling.py | 78 ++ .../mli/infrastructure/control/listener.py | 352 ++++++++ .../control/request_dispatcher.py | 559 ++++++++++++ .../infrastructure/control/worker_manager.py | 330 +++++++ .../mli/infrastructure/environment_loader.py | 116 +++ .../mli/infrastructure/storage/__init__.py | 0 .../storage/backbone_feature_store.py | 259 ++++++ .../storage/dragon_feature_store.py | 126 +++ .../mli/infrastructure/storage/dragon_util.py | 101 +++ .../infrastructure/storage/feature_store.py | 224 +++++ .../mli/infrastructure/worker/__init__.py | 0 .../mli/infrastructure/worker/torch_worker.py | 276 ++++++ .../_core/mli/infrastructure/worker/worker.py | 646 ++++++++++++++ smartsim/_core/mli/message_handler.py | 602 +++++++++++++ .../mli_schemas/data/data_references.capnp | 37 + .../mli_schemas/data/data_references_capnp.py | 41 + .../data/data_references_capnp.pyi | 107 +++ .../_core/mli/mli_schemas/model/__init__.py | 0 .../_core/mli/mli_schemas/model/model.capnp | 33 + .../mli/mli_schemas/model/model_capnp.py | 38 + .../mli/mli_schemas/model/model_capnp.pyi | 72 ++ .../mli/mli_schemas/request/request.capnp | 55 ++ .../request_attributes.capnp | 49 + .../request_attributes_capnp.py | 41 + .../request_attributes_capnp.pyi | 109 +++ .../mli/mli_schemas/request/request_capnp.py | 41 + .../mli/mli_schemas/request/request_capnp.pyi | 319 +++++++ .../mli/mli_schemas/response/response.capnp | 52 ++ .../response_attributes.capnp | 33 + .../response_attributes_capnp.py | 41 + .../response_attributes_capnp.pyi | 103 +++ .../mli_schemas/response/response_capnp.py | 38 + .../mli_schemas/response/response_capnp.pyi | 212 +++++ .../_core/mli/mli_schemas/tensor/tensor.capnp | 75 ++ .../mli/mli_schemas/tensor/tensor_capnp.py | 41 + .../mli/mli_schemas/tensor/tensor_capnp.pyi | 142 +++ 
smartsim/_core/schemas/utils.py | 4 +- smartsim/_core/utils/helpers.py | 2 +- smartsim/_core/utils/timings.py | 175 ++++ smartsim/database/orchestrator.py | 4 +- smartsim/experiment.py | 2 +- smartsim/log.py | 13 +- smartsim/settings/dragonRunSettings.py | 20 + tests/dragon/__init__.py | 0 tests/dragon/channel.py | 125 +++ tests/dragon/conftest.py | 129 +++ tests/dragon/feature_store.py | 152 ++++ .../test_core_machine_learning_worker.py | 377 ++++++++ tests/dragon/test_device_manager.py | 186 ++++ tests/dragon/test_dragon_backend.py | 308 +++++++ tests/dragon/test_dragon_ddict_utils.py | 117 +++ tests/dragon/test_environment_loader.py | 147 +++ tests/dragon/test_error_handling.py | 511 +++++++++++ tests/dragon/test_event_consumer.py | 386 ++++++++ tests/dragon/test_featurestore.py | 327 +++++++ tests/dragon/test_featurestore_base.py | 844 ++++++++++++++++++ tests/dragon/test_featurestore_integration.py | 213 +++++ tests/dragon/test_inference_reply.py | 76 ++ tests/dragon/test_inference_request.py | 118 +++ tests/dragon/test_protoclient.py | 313 +++++++ tests/dragon/test_reply_building.py | 64 ++ tests/dragon/test_request_dispatcher.py | 233 +++++ tests/dragon/test_torch_worker.py | 221 +++++ tests/dragon/test_worker_manager.py | 314 +++++++ tests/dragon/utils/__init__.py | 0 tests/dragon/utils/channel.py | 125 +++ tests/dragon/utils/msg_pump.py | 225 +++++ tests/dragon/utils/worker.py | 104 +++ tests/mli/__init__.py | 0 tests/mli/channel.py | 125 +++ tests/mli/feature_store.py | 144 +++ tests/mli/test_integrated_torch_worker.py | 275 ++++++ tests/mli/test_service.py | 290 ++++++ tests/mli/worker.py | 104 +++ tests/on_wlm/test_dragon.py | 2 +- tests/test_config.py | 28 +- tests/test_dragon_comm_utils.py | 257 ++++++ tests/test_dragon_installer.py | 142 ++- tests/test_dragon_launcher.py | 30 +- tests/test_dragon_run_policy.py | 9 - tests/test_dragon_run_request.py | 380 ++++---- tests/test_dragon_runsettings.py | 119 +++ tests/test_dragon_step.py | 13 + 
tests/test_message_handler/__init__.py | 0 .../test_message_handler/test_build_model.py | 72 ++ .../test_build_model_key.py | 47 + .../test_build_request_attributes.py | 55 ++ .../test_build_tensor_desc.py | 90 ++ .../test_build_tensor_key.py | 46 + .../test_output_descriptor.py | 78 ++ tests/test_message_handler/test_request.py | 449 ++++++++++ tests/test_message_handler/test_response.py | 191 ++++ tests/test_node_prioritizer.py | 553 ++++++++++++ 131 files changed, 18278 insertions(+), 427 deletions(-) create mode 100644 ex/high_throughput_inference/mli_driver.py create mode 100644 ex/high_throughput_inference/mock_app.py create mode 100644 ex/high_throughput_inference/mock_app_redis.py create mode 100644 ex/high_throughput_inference/redis_driver.py create mode 100644 ex/high_throughput_inference/standalone_worker_manager.py create mode 100644 smartsim/_core/entrypoints/service.py create mode 100644 smartsim/_core/launcher/dragon/pqueue.py create mode 100644 smartsim/_core/mli/__init__.py create mode 100644 smartsim/_core/mli/client/__init__.py create mode 100644 smartsim/_core/mli/client/protoclient.py create mode 100644 smartsim/_core/mli/comm/channel/__init__.py create mode 100644 smartsim/_core/mli/comm/channel/channel.py create mode 100644 smartsim/_core/mli/comm/channel/dragon_channel.py create mode 100644 smartsim/_core/mli/comm/channel/dragon_fli.py create mode 100644 smartsim/_core/mli/comm/channel/dragon_util.py create mode 100644 smartsim/_core/mli/infrastructure/__init__.py create mode 100644 smartsim/_core/mli/infrastructure/comm/__init__.py create mode 100644 smartsim/_core/mli/infrastructure/comm/broadcaster.py create mode 100644 smartsim/_core/mli/infrastructure/comm/consumer.py create mode 100644 smartsim/_core/mli/infrastructure/comm/event.py create mode 100644 smartsim/_core/mli/infrastructure/comm/producer.py create mode 100644 smartsim/_core/mli/infrastructure/control/__init__.py create mode 100644 
smartsim/_core/mli/infrastructure/control/device_manager.py create mode 100644 smartsim/_core/mli/infrastructure/control/error_handling.py create mode 100644 smartsim/_core/mli/infrastructure/control/listener.py create mode 100644 smartsim/_core/mli/infrastructure/control/request_dispatcher.py create mode 100644 smartsim/_core/mli/infrastructure/control/worker_manager.py create mode 100644 smartsim/_core/mli/infrastructure/environment_loader.py create mode 100644 smartsim/_core/mli/infrastructure/storage/__init__.py create mode 100644 smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py create mode 100644 smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py create mode 100644 smartsim/_core/mli/infrastructure/storage/dragon_util.py create mode 100644 smartsim/_core/mli/infrastructure/storage/feature_store.py create mode 100644 smartsim/_core/mli/infrastructure/worker/__init__.py create mode 100644 smartsim/_core/mli/infrastructure/worker/torch_worker.py create mode 100644 smartsim/_core/mli/infrastructure/worker/worker.py create mode 100644 smartsim/_core/mli/message_handler.py create mode 100644 smartsim/_core/mli/mli_schemas/data/data_references.capnp create mode 100644 smartsim/_core/mli/mli_schemas/data/data_references_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/model/__init__.py create mode 100644 smartsim/_core/mli/mli_schemas/model/model.capnp create mode 100644 smartsim/_core/mli/mli_schemas/model/model_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/model/model_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/request/request.capnp create mode 100644 smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp create mode 100644 smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py create mode 100644 
smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/request/request_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/request/request_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/response/response.capnp create mode 100644 smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp create mode 100644 smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/response/response_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/response/response_capnp.pyi create mode 100644 smartsim/_core/mli/mli_schemas/tensor/tensor.capnp create mode 100644 smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py create mode 100644 smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi create mode 100644 smartsim/_core/utils/timings.py create mode 100644 tests/dragon/__init__.py create mode 100644 tests/dragon/channel.py create mode 100644 tests/dragon/conftest.py create mode 100644 tests/dragon/feature_store.py create mode 100644 tests/dragon/test_core_machine_learning_worker.py create mode 100644 tests/dragon/test_device_manager.py create mode 100644 tests/dragon/test_dragon_backend.py create mode 100644 tests/dragon/test_dragon_ddict_utils.py create mode 100644 tests/dragon/test_environment_loader.py create mode 100644 tests/dragon/test_error_handling.py create mode 100644 tests/dragon/test_event_consumer.py create mode 100644 tests/dragon/test_featurestore.py create mode 100644 tests/dragon/test_featurestore_base.py create mode 100644 tests/dragon/test_featurestore_integration.py create mode 100644 tests/dragon/test_inference_reply.py create mode 100644 tests/dragon/test_inference_request.py create mode 100644 tests/dragon/test_protoclient.py create mode 
100644 tests/dragon/test_reply_building.py create mode 100644 tests/dragon/test_request_dispatcher.py create mode 100644 tests/dragon/test_torch_worker.py create mode 100644 tests/dragon/test_worker_manager.py create mode 100644 tests/dragon/utils/__init__.py create mode 100644 tests/dragon/utils/channel.py create mode 100644 tests/dragon/utils/msg_pump.py create mode 100644 tests/dragon/utils/worker.py create mode 100644 tests/mli/__init__.py create mode 100644 tests/mli/channel.py create mode 100644 tests/mli/feature_store.py create mode 100644 tests/mli/test_integrated_torch_worker.py create mode 100644 tests/mli/test_service.py create mode 100644 tests/mli/worker.py create mode 100644 tests/test_dragon_comm_utils.py create mode 100644 tests/test_message_handler/__init__.py create mode 100644 tests/test_message_handler/test_build_model.py create mode 100644 tests/test_message_handler/test_build_model_key.py create mode 100644 tests/test_message_handler/test_build_request_attributes.py create mode 100644 tests/test_message_handler/test_build_tensor_desc.py create mode 100644 tests/test_message_handler/test_build_tensor_key.py create mode 100644 tests/test_message_handler/test_output_descriptor.py create mode 100644 tests/test_message_handler/test_request.py create mode 100644 tests/test_message_handler/test_response.py create mode 100644 tests/test_node_prioritizer.py diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index e3c808410b..9b988520a4 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -54,7 +54,7 @@ jobs: strategy: fail-fast: false matrix: - subset: [backends, slow_tests, group_a, group_b] + subset: [backends, slow_tests, group_a, group_b, dragon] os: [macos-12, macos-14, ubuntu-22.04] # Operating systems compiler: [8] # GNU compiler version rai: [1.2.7] # Redis AI versions @@ -109,8 +109,24 @@ jobs: python -m pip install .[dev,mypy] - name: Install ML Runtimes + if: matrix.subset != 
'dragon' run: smart build --device cpu -v + + - name: Install ML Runtimes (with dragon) + if: matrix.subset == 'dragon' + env: + SMARTSIM_DRAGON_TOKEN: ${{ secrets.DRAGON_TOKEN }} + run: | + if [ -n "${SMARTSIM_DRAGON_TOKEN}" ]; then + smart build --device cpu -v --dragon-repo dragonhpc/dragon-nightly --dragon-version 0.10 + else + smart build --device cpu -v --dragon + fi + SP=$(python -c 'import site; print(site.getsitepackages()[0])')/smartsim/_core/config/dragon/.env + LLP=$(cat $SP | grep LD_LIBRARY_PATH | awk '{split($0, array, "="); print array[2]}') + echo "LD_LIBRARY_PATH=$LLP:$LD_LIBRARY_PATH" >> $GITHUB_ENV + - name: Run mypy run: | make check-mypy @@ -134,9 +150,16 @@ jobs: echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ ./tests/backends + # Run pytest (dragon subtests) + - name: Run Dragon Pytest + if: (matrix.subset == 'dragon' && matrix.os == 'ubuntu-22.04') + run: | + echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV + dragon -s py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ -m ${{ matrix.subset }} ./tests + # Run pytest (test subsets) - name: Run Pytest - if: "!contains(matrix.subset, 'backends')" # if not running backend tests + if: (matrix.subset != 'backends' && matrix.subset != 'dragon') # if not running backend tests or dragon tests run: | echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ -m ${{ matrix.subset }} ./tests diff --git a/Makefile b/Makefile index 457bb040ac..4e64033d63 100644 --- a/Makefile +++ b/Makefile @@ -164,22 +164,22 @@ tutorials-prod: # help: test - Run all tests .PHONY: test test: - @python 
-m pytest --ignore=tests/full_wlm/ + @python -m pytest --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-verbose - Run all tests verbosely .PHONY: test-verbose test-verbose: - @python -m pytest -vv --ignore=tests/full_wlm/ + @python -m pytest -vv --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-debug - Run all tests with debug output .PHONY: test-debug test-debug: - @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ + @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-cov - Run all tests with coverage .PHONY: test-cov test-cov: - @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ + @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-full - Run all WLM tests with Python coverage (full test suite) @@ -192,3 +192,8 @@ test-full: .PHONY: test-wlm test-wlm: @python -m pytest -vv tests/full_wlm/ tests/on_wlm + +# help: test-dragon - Run dragon-specific tests +.PHONY: test-dragon +test-dragon: + @dragon pytest tests/dragon diff --git a/conftest.py b/conftest.py index 991c0d17b6..54a47f9e23 100644 --- a/conftest.py +++ b/conftest.py @@ -93,6 +93,7 @@ test_hostlist = None has_aprun = shutil.which("aprun") is not None + def get_account() -> str: return test_account @@ -227,7 +228,6 @@ def kill_all_test_spawned_processes() -> None: print("Not all processes were killed after test") - def get_hostlist() -> t.Optional[t.List[str]]: global test_hostlist if not test_hostlist: diff --git a/doc/changelog.md b/doc/changelog.md index 8f93a1ae2c..bca9209f7a 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -9,6 +9,39 @@ Jump to: ## SmartSim +### MLI branch + +Description + +- Implement asynchronous notifications for shared data +- Quick bug fix in _validate +- Add helper methods to MLI classes +- Update error handling for 
consistency +- Parameterize installation of dragon package with `smart build` +- Update docstrings +- Filenames conform to snake case +- Update SmartSim environment variables using new naming convention +- Refactor `exception_handler` +- Add RequestDispatcher and the possibility of batching inference requests +- Enable hostname selection for dragon tasks +- Remove pydantic dependency from MLI code +- Update MLI environment variables using new naming convention +- Reduce a copy by using torch.from_numpy instead of torch.tensor +- Enable dynamic feature store selection +- Fix dragon package installation bug +- Adjust schemas for better performance +- Add TorchWorker first implementation and mock inference app example +- Add error handling in Worker Manager pipeline +- Add EnvironmentConfigLoader for ML Worker Manager +- Add Model schema with model metadata included +- Removed device from schemas, MessageHandler and tests +- Add ML worker manager, sample worker, and feature store +- Add schemas and MessageHandler class for de/serialization of + inference requests and response messages + + +### Develop + To be released at some point in the future Description diff --git a/doc/installation_instructions/basic.rst b/doc/installation_instructions/basic.rst index 226ccb0854..73fbceb253 100644 --- a/doc/installation_instructions/basic.rst +++ b/doc/installation_instructions/basic.rst @@ -305,6 +305,20 @@ For example, to install dragon alongside the RedisAI CPU backends, you can run smart build --device cpu --dragon # install Dragon, PT and TF for cpu +``smart build`` supports installing a specific version of dragon. It exposes the +parameters ``--dragon-repo`` and ``--dragon-version``, which can be used alone or +in combination to customize the Dragon installation. For example: + +.. 
code-block:: bash + + # using the --dragon-repo and --dragon-version flags to customize the Dragon installation + smart build --device cpu --dragon-repo userfork/dragon # install Dragon from a specific repo + smart build --device cpu --dragon-version 0.10 # install a specific Dragon release + + # combining both flags + smart build --device cpu --dragon-repo userfork/dragon --dragon-version 0.91 + + .. note:: Dragon is only supported on Linux systems. For further information, you can read :ref:`the dedicated documentation page `. diff --git a/ex/high_throughput_inference/mli_driver.py b/ex/high_throughput_inference/mli_driver.py new file mode 100644 index 0000000000..36f427937c --- /dev/null +++ b/ex/high_throughput_inference/mli_driver.py @@ -0,0 +1,77 @@ +import os +import base64 +import cloudpickle +import sys +from smartsim import Experiment +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim.status import TERMINAL_STATUSES +from smartsim.settings import DragonRunSettings +import time +import typing as t + +DEVICE = "gpu" +NUM_RANKS = 4 +NUM_WORKERS = 1 +filedir = os.path.dirname(__file__) +worker_manager_script_name = os.path.join(filedir, "standalone_worker_manager.py") +app_script_name = os.path.join(filedir, "mock_app.py") +model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt") + +transport: t.Literal["hsta", "tcp"] = "hsta" + +os.environ["SMARTSIM_DRAGON_TRANSPORT"] = transport + +exp_path = os.path.join(filedir, f"MLI_proto_{transport.upper()}") +os.makedirs(exp_path, exist_ok=True) +exp = Experiment("MLI_proto", launcher="dragon", exp_path=exp_path) + +torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii") + +worker_manager_rs: DragonRunSettings = exp.create_run_settings( + sys.executable, + [ + worker_manager_script_name, + "--device", + DEVICE, + "--worker_class", + torch_worker_str, + "--batch_size", + str(NUM_RANKS//NUM_WORKERS), + "--batch_timeout", + str(0.00), + 
"--num_workers", + str(NUM_WORKERS) + ], +) + +aff = [] + +worker_manager_rs.set_cpu_affinity(aff) + +worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs) +worker_manager.attach_generator_files(to_copy=[worker_manager_script_name]) + +app_rs: DragonRunSettings = exp.create_run_settings( + sys.executable, + exe_args=[app_script_name, "--device", DEVICE, "--log_max_batchsize", str(6)], +) +app_rs.set_tasks_per_node(NUM_RANKS) + + +app = exp.create_model("app", run_settings=app_rs) +app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) + +exp.generate(worker_manager, app, overwrite=True) +exp.start(worker_manager, app, block=False) + +while True: + if exp.get_status(app)[0] in TERMINAL_STATUSES: + time.sleep(10) + exp.stop(worker_manager) + break + if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES: + time.sleep(10) + exp.stop(app) + break + +print("Exiting.") diff --git a/ex/high_throughput_inference/mock_app.py b/ex/high_throughput_inference/mock_app.py new file mode 100644 index 0000000000..c3b3eaaf4c --- /dev/null +++ b/ex/high_throughput_inference/mock_app.py @@ -0,0 +1,142 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# isort: off +import dragon +from dragon import fli +from dragon.channels import Channel +import dragon.channels +from dragon.data.ddict.ddict import DDict +from dragon.globalservices.api_setup import connect_to_infrastructure +from dragon.utils import b64decode, b64encode + +# isort: on + +import argparse +import io + +import torch + +from smartsim.log import get_logger + +torch.set_num_interop_threads(16) +torch.set_num_threads(1) + +logger = get_logger("App") +logger.info("Started app") + +from collections import OrderedDict + +from smartsim.log import get_logger, log_to_file +from smartsim._core.mli.client.protoclient import ProtoClient + +logger = get_logger("App") + + +CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False + + +class ResNetWrapper: + """Wrapper around a pre-rained ResNet model.""" + def __init__(self, name: str, model: str): + """Initialize the instance. 
+ + :param name: The name to use for the model + :param model: The path to the pre-trained PyTorch model""" + self._model = torch.jit.load(model) + self._name = name + buffer = io.BytesIO() + scripted = torch.jit.trace(self._model, self.get_batch()) + torch.jit.save(scripted, buffer) + self._serialized_model = buffer.getvalue() + + def get_batch(self, batch_size: int = 32): + """Create a random batch of data with the correct dimensions to + invoke a ResNet model. + + :param batch_size: The desired number of samples to produce + :returns: A PyTorch tensor""" + return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32) + + @property + def model(self) -> bytes: + """The content of a model file. + + :returns: The model bytes""" + return self._serialized_model + + @property + def name(self) -> str: + """The name applied to the model. + + :returns: The name""" + return self._name + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser("Mock application") + parser.add_argument("--device", default="cpu", type=str) + parser.add_argument("--log_max_batchsize", default=8, type=int) + args = parser.parse_args() + + resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt") + + client = ProtoClient(timing_on=True) + client.set_model(resnet.name, resnet.model) + + if CHECK_RESULTS_AND_MAKE_ALL_SLOWER: + # TODO: adapt to non-Nvidia devices + torch_device = args.device.replace("gpu", "cuda") + pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to( + torch_device + ) + + TOTAL_ITERATIONS = 100 + + for log2_bsize in range(args.log_max_batchsize + 1): + b_size: int = 2**log2_bsize + logger.info(f"Batch size: {b_size}") + for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)): + logger.info(f"Iteration: {iteration_number}") + sample_batch = resnet.get_batch(b_size) + remote_result = client.run_model(resnet.name, sample_batch) + logger.info(client.perf_timer.get_last("total_time")) + if CHECK_RESULTS_AND_MAKE_ALL_SLOWER: + 
local_res = pt_model(sample_batch.to(torch_device)) + err_norm = torch.linalg.vector_norm( + torch.flatten(remote_result).to(torch_device) + - torch.flatten(local_res), + ord=1, + ).cpu() + res_norm = torch.linalg.vector_norm(remote_result, ord=1).item() + local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item() + logger.info( + f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}" + ) + torch.cuda.synchronize() + + client.perf_timer.print_timings(to_file=True) diff --git a/ex/high_throughput_inference/mock_app_redis.py b/ex/high_throughput_inference/mock_app_redis.py new file mode 100644 index 0000000000..8978bcea23 --- /dev/null +++ b/ex/high_throughput_inference/mock_app_redis.py @@ -0,0 +1,90 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import io +import numpy +import time +import torch +from mpi4py import MPI +from smartsim.log import get_logger +from smartsim._core.utils.timings import PerfTimer +from smartredis import Client + +logger = get_logger("App") + +class ResNetWrapper(): + def __init__(self, name: str, model: str): + self._model = torch.jit.load(model) + self._name = name + buffer = io.BytesIO() + scripted = torch.jit.trace(self._model, self.get_batch()) + torch.jit.save(scripted, buffer) + self._serialized_model = buffer.getvalue() + + def get_batch(self, batch_size: int=32): + return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32) + + @property + def model(self): + return self._serialized_model + + @property + def name(self): + return self._name + +if __name__ == "__main__": + + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + + parser = argparse.ArgumentParser("Mock application") + parser.add_argument("--device", default="cpu") + args = parser.parse_args() + + resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt") + + client = Client(cluster=False, address=None) + client.set_model(resnet.name, resnet.model, backend='TORCH', device=args.device.upper()) + + perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"redis{rank}_") + + total_iterations = 100 + timings=[] + for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: + logger.info(f"Batch size: {batch_size}") + for iteration_number in 
range(total_iterations + int(batch_size==1)): + perf_timer.start_timings("batch_size", batch_size) + logger.info(f"Iteration: {iteration_number}") + input_name = f"batch_{rank}" + output_name = f"result_{rank}" + client.put_tensor(name=input_name, data=resnet.get_batch(batch_size).numpy()) + client.run_model(name=resnet.name, inputs=[input_name], outputs=[output_name]) + result = client.get_tensor(name=output_name) + perf_timer.end_timings() + + + perf_timer.print_timings(True) diff --git a/ex/high_throughput_inference/redis_driver.py b/ex/high_throughput_inference/redis_driver.py new file mode 100644 index 0000000000..ff57725d40 --- /dev/null +++ b/ex/high_throughput_inference/redis_driver.py @@ -0,0 +1,66 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys +from smartsim import Experiment +from smartsim.status import TERMINAL_STATUSES +import time + +DEVICE = "gpu" +filedir = os.path.dirname(__file__) +app_script_name = os.path.join(filedir, "mock_app_redis.py") +model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt") + + +exp_path = os.path.join(filedir, "redis_ai_multi") +os.makedirs(exp_path, exist_ok=True) +exp = Experiment("redis_ai_multi", launcher="slurm", exp_path=exp_path) + +db = exp.create_database(interface="hsn0") + +app_rs = exp.create_run_settings( + sys.executable, exe_args = [app_script_name, "--device", DEVICE] + ) +app_rs.set_nodes(1) +app_rs.set_tasks(4) +app = exp.create_model("app", run_settings=app_rs) +app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) + +exp.generate(db, app, overwrite=True) + +exp.start(db, app, block=False) + +while True: + if exp.get_status(app)[0] in TERMINAL_STATUSES: + exp.stop(db) + break + if exp.get_status(db)[0] in TERMINAL_STATUSES: + exp.stop(app) + break + time.sleep(5) + +print("Exiting.") \ No newline at end of file diff --git a/ex/high_throughput_inference/standalone_worker_manager.py b/ex/high_throughput_inference/standalone_worker_manager.py new file mode 100644 index 0000000000..b4527bc5d2 --- /dev/null +++ b/ex/high_throughput_inference/standalone_worker_manager.py @@ -0,0 +1,218 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights 
reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +import dragon + +# pylint disable=import-error +import dragon.infrastructure.policy as dragon_policy +import dragon.infrastructure.process_desc as dragon_process_desc +import dragon.native.process as dragon_process +from dragon import fli +from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict +from dragon.globalservices.api_setup import connect_to_infrastructure +from dragon.managed_memory import MemoryPool +from dragon.utils import b64decode, b64encode + +# pylint enable=import-error + +# isort: off +# isort: on + +import argparse +import base64 +import multiprocessing as mp +import os +import socket +import time +import typing as t + +import cloudpickle + +from smartsim._core.entrypoints.service import Service +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.control.request_dispatcher import ( + RequestDispatcher, +) +from smartsim._core.mli.infrastructure.control.worker_manager import WorkerManager +from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim.log import get_logger + +logger = get_logger("Worker Manager Entry Point") + +mp.set_start_method("dragon") + +pid = os.getpid() +affinity = os.sched_getaffinity(pid) +logger.info(f"Entry point: {socket.gethostname()}, {affinity}") +logger.info(f"CPUS: {os.cpu_count()}") + + +def service_as_dragon_proc( + service: Service, cpu_affinity: list[int], gpu_affinity: list[int] +) -> dragon_process.Process: + + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + 
placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + return dragon_process.Process( + target=service.execute, + args=[], + cwd=os.getcwd(), + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Worker Manager") + parser.add_argument( + "--device", + type=str, + default="gpu", + choices="gpu cpu".split(), + help="Device on which the inference takes place", + ) + parser.add_argument( + "--worker_class", + type=str, + required=True, + help="Serialized class of worker to run", + ) + parser.add_argument( + "--num_workers", type=int, default=1, help="Number of workers to run" + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="How many requests the workers will try to aggregate before processing them", + ) + parser.add_argument( + "--batch_timeout", + type=float, + default=0.001, + help="How much time (in seconds) should be waited before processing an incomplete aggregated request", + ) + args = parser.parse_args() + + connect_to_infrastructure() + ddict_str = os.environ[BackboneFeatureStore.MLI_BACKBONE] + + backbone = BackboneFeatureStore.from_descriptor(ddict_str) + + to_worker_channel = create_local() + to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli) + + backbone.worker_queue = to_worker_fli_comm_ch.descriptor + + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor + + arg_worker_type = cloudpickle.loads( + base64.b64decode(args.worker_class.encode("ascii")) + ) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + 
queue_factory=DragonFLIChannel.from_descriptor, + ) + + dispatcher = RequestDispatcher( + batch_timeout=args.batch_timeout, + batch_size=args.batch_size, + config_loader=config_loader, + worker_type=arg_worker_type, + ) + + wms = [] + worker_device = args.device + for wm_idx in range(args.num_workers): + + worker_manager = WorkerManager( + config_loader=config_loader, + worker_type=arg_worker_type, + as_service=True, + cooldown=10, + device=worker_device, + dispatcher_queue=dispatcher.task_queue, + ) + + wms.append(worker_manager) + + wm_affinity: list[int] = [] + disp_affinity: list[int] = [] + + # This is hardcoded for a specific type of node: + # the GPU-to-CPU mapping is taken from the nvidia-smi tool + # TODO can this be computed on the fly? + gpu_to_cpu_aff: dict[int, list[int]] = {} + gpu_to_cpu_aff[0] = list(range(48, 64)) + list(range(112, 128)) + gpu_to_cpu_aff[1] = list(range(32, 48)) + list(range(96, 112)) + gpu_to_cpu_aff[2] = list(range(16, 32)) + list(range(80, 96)) + gpu_to_cpu_aff[3] = list(range(0, 16)) + list(range(64, 80)) + + worker_manager_procs = [] + for worker_idx in range(args.num_workers): + wm_cpus = len(gpu_to_cpu_aff[worker_idx]) - 4 + wm_affinity = gpu_to_cpu_aff[worker_idx][:wm_cpus] + disp_affinity.extend(gpu_to_cpu_aff[worker_idx][wm_cpus:]) + worker_manager_procs.append( + service_as_dragon_proc( + worker_manager, cpu_affinity=wm_affinity, gpu_affinity=[worker_idx] + ) + ) + + dispatcher_proc = service_as_dragon_proc( + dispatcher, cpu_affinity=disp_affinity, gpu_affinity=[] + ) + + # TODO: use ProcessGroup and restart=True? 
+ all_procs = [dispatcher_proc, *worker_manager_procs] + + print(f"Dispatcher proc: {dispatcher_proc}") + for proc in all_procs: + proc.start() + + while all(proc.is_alive for proc in all_procs): + time.sleep(1) diff --git a/pyproject.toml b/pyproject.toml index 62df92f0c9..61e17891b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ markers = [ "group_a: fast test subset a", "group_b: fast test subset b", "slow_tests: tests that take a long duration to complete", + "dragon: tests that must be executed in a dragon runtime", ] [tool.isort] diff --git a/setup.py b/setup.py index 571974d284..cd5ace55db 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ class BuildError(Exception): pass + # Define needed dependencies for the installation extras_require = { @@ -176,6 +177,7 @@ class BuildError(Exception): "GitPython<=3.1.43", "protobuf<=3.20.3", "jinja2>=3.1.2", + "pycapnp==2.0.0", "watchdog>4,<5", "pydantic>2", "pyzmq>=25.1.2", diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 5d094b72f4..ec9ef4aa29 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -36,7 +36,13 @@ from tabulate import tabulate -from smartsim._core._cli.scripts.dragon_install import install_dragon +from smartsim._core._cli.scripts.dragon_install import ( + DEFAULT_DRAGON_REPO, + DEFAULT_DRAGON_VERSION, + DragonInstallRequest, + display_post_install_logs, + install_dragon, +) from smartsim._core._cli.utils import SMART_LOGGER_FORMAT from smartsim._core._install import builder from smartsim._core._install.buildenv import BuildEnv, DbEngine, Version_, Versioner @@ -67,22 +73,22 @@ def check_backends_install() -> bool: """Checks if backends have already been installed. Logs details on how to proceed forward - if the RAI_PATH environment variable is set or if + if the SMARTSIM_RAI_LIB environment variable is set or if backends have already been installed. 
""" - rai_path = os.environ.get("RAI_PATH", "") + rai_path = os.environ.get("SMARTSIM_RAI_LIB", "") installed = installed_redisai_backends() msg = "" if rai_path and installed: msg = ( f"There is no need to build. backends are already built and " - f"specified in the environment at 'RAI_PATH': {CONFIG.redisai}" + f"specified in the environment at 'SMARTSIM_RAI_LIB': {CONFIG.redisai}" ) elif rai_path and not installed: msg = ( - "Before running 'smart build', unset your RAI_PATH environment " - "variable with 'unset RAI_PATH'." + "Before running 'smart build', unset your SMARTSIM_RAI_LIB environment " + "variable with 'unset SMARTSIM_RAI_LIB'." ) elif not rai_path and installed: msg = ( @@ -231,7 +237,7 @@ def _configure_keydb_build(versions: Versioner) -> None: CONFIG.conf_path = Path(CONFIG.core_path, "config", "keydb.conf") if not CONFIG.conf_path.resolve().is_file(): raise SSConfigError( - "Database configuration file at REDIS_CONF could not be found" + "Database configuration file at SMARTSIM_REDIS_CONF could not be found" ) @@ -245,6 +251,8 @@ def execute( keydb = args.keydb device = Device.from_str(args.device.lower()) is_dragon_requested = args.dragon + dragon_repo = args.dragon_repo + dragon_version = args.dragon_version if Path(CONFIG.build_path).exists(): logger.warning(f"Build path already exists, removing: {CONFIG.build_path}") @@ -294,12 +302,23 @@ def execute( logger.info("ML Packages") print(mlpackages) - if is_dragon_requested: + if is_dragon_requested or dragon_repo or dragon_version: install_to = CONFIG.core_path / ".dragon" - return_code = install_dragon(install_to) + + try: + request = DragonInstallRequest( + install_to, + dragon_repo, + dragon_version, + ) + return_code = install_dragon(request) + except ValueError as ex: + return_code = 2 + logger.error(" ".join(ex.args)) if return_code == 0: - logger.info("Dragon installation complete") + display_post_install_logs() + elif return_code == 1: logger.info("Dragon installation not supported on 
platform") else: @@ -358,6 +377,21 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: default=False, help="Install the dragon runtime", ) + parser.add_argument( + "--dragon-repo", + default=None, + type=str, + help=( + "Specify a git repo containing dragon release assets " + f"(e.g. {DEFAULT_DRAGON_REPO})" + ), + ) + parser.add_argument( + "--dragon-version", + default=None, + type=str, + help=f"Specify the dragon version to install (e.g. {DEFAULT_DRAGON_VERSION})", + ) parser.add_argument( "--skip-python-packages", action="store_true", diff --git a/smartsim/_core/_cli/scripts/dragon_install.py b/smartsim/_core/_cli/scripts/dragon_install.py index 8028b8ecfd..3a9358390b 100644 --- a/smartsim/_core/_cli/scripts/dragon_install.py +++ b/smartsim/_core/_cli/scripts/dragon_install.py @@ -1,10 +1,16 @@ import os import pathlib +import re +import shutil import sys import typing as t +from urllib.request import Request, urlopen from github import Github +from github.Auth import Token +from github.GitRelease import GitRelease from github.GitReleaseAsset import GitReleaseAsset +from github.Repository import Repository from smartsim._core._cli.utils import pip from smartsim._core._install.utils import retrieve @@ -15,20 +21,90 @@ logger = get_logger(__name__) +DEFAULT_DRAGON_REPO = "DragonHPC/dragon" +DEFAULT_DRAGON_VERSION = "0.9" +DEFAULT_DRAGON_VERSION_TAG = f"v{DEFAULT_DRAGON_VERSION}" +_GH_TOKEN = "SMARTSIM_DRAGON_TOKEN" -def create_dotenv(dragon_root_dir: pathlib.Path) -> None: + +class DragonInstallRequest: + """Encapsulates a request to install the dragon package""" + + def __init__( + self, + working_dir: pathlib.Path, + repo_name: t.Optional[str] = None, + version: t.Optional[str] = None, + ) -> None: + """Initialize an install request. + + :param working_dir: A path to store temporary files used during installation + :param repo_name: The name of a repository to install from, e.g. DragonHPC/dragon + :param version: The version to install, e.g. 
v0.10 + """ + + self.working_dir = working_dir + """A path to store temporary files used during installation""" + + self.repo_name = repo_name or DEFAULT_DRAGON_REPO + """The name of a repository to install from, e.g. DragonHPC/dragon""" + + self.pkg_version = version or DEFAULT_DRAGON_VERSION + """The version to install, e.g. 0.10""" + + self._check() + + def _check(self) -> None: + """Perform validation of this instance + + :raises ValueError: if any value fails validation""" + if not self.repo_name or len(self.repo_name.split("/")) != 2: + raise ValueError( + f"Invalid dragon repository name. Example: `dragonhpc/dragon`" + ) + + # version must match standard dragon tag & filename format `vX.YZ` + match = re.match(r"^\d\.\d+$", self.pkg_version) + if not self.pkg_version or not match: + raise ValueError("Invalid dragon version. Examples: `0.9, 0.91, 0.10`") + + # attempting to retrieve from a non-default repository requires an auth token + if self.repo_name.lower() != DEFAULT_DRAGON_REPO.lower() and not self.raw_token: + raise ValueError( + f"An access token must be available to access {self.repo_name}. " + f"Set the `{_GH_TOKEN}` env var to pass your access token." 
+ ) + + @property + def raw_token(self) -> t.Optional[str]: + """Returns the raw access token from the environment, if available""" + return os.environ.get(_GH_TOKEN, None) + + +def get_auth_token(request: DragonInstallRequest) -> t.Optional[Token]: + """Create a Github.Auth.Token if an access token can be found + in the environment + + :param request: details of a request for the installation of the dragon package + :returns: an auth token if one can be built, otherwise `None`""" + if gh_token := request.raw_token: + return Token(gh_token) + return None + + +def create_dotenv(dragon_root_dir: pathlib.Path, dragon_version: str) -> None: """Create a .env file with required environment variables for the Dragon runtime""" dragon_root = str(dragon_root_dir) - dragon_inc_dir = str(dragon_root_dir / "include") - dragon_lib_dir = str(dragon_root_dir / "lib") - dragon_bin_dir = str(dragon_root_dir / "bin") + dragon_inc_dir = dragon_root + "/include" + dragon_lib_dir = dragon_root + "/lib" + dragon_bin_dir = dragon_root + "/bin" dragon_vars = { "DRAGON_BASE_DIR": dragon_root, - "DRAGON_ROOT_DIR": dragon_root, # note: same as base_dir + "DRAGON_ROOT_DIR": dragon_root, "DRAGON_INCLUDE_DIR": dragon_inc_dir, "DRAGON_LIB_DIR": dragon_lib_dir, - "DRAGON_VERSION": dragon_pin(), + "DRAGON_VERSION": dragon_version, "PATH": dragon_bin_dir, "LD_LIBRARY_PATH": dragon_lib_dir, } @@ -48,12 +124,6 @@ def python_version() -> str: return f"py{sys.version_info.major}.{sys.version_info.minor}" -def dragon_pin() -> str: - """Return a string indicating the pinned major/minor version of the dragon - package to install""" - return "0.9" - - def _platform_filter(asset_name: str) -> bool: """Return True if the asset name matches naming standard for current platform (Cray or non-Cray). Otherwise, returns False. 
@@ -75,67 +145,125 @@ def _version_filter(asset_name: str) -> bool: return python_version() in asset_name -def _pin_filter(asset_name: str) -> bool: +def _pin_filter(asset_name: str, dragon_version: str) -> bool: """Return true if the supplied value contains a dragon version pin match - :param asset_name: A value to inspect for keywords indicating a dragon version + :param asset_name: the asset name to inspect for keywords indicating a dragon version + :param dragon_version: the dragon version to match :returns: True if supplied value is correct for current dragon version""" - return f"dragon-{dragon_pin()}" in asset_name + return f"dragon-{dragon_version}" in asset_name + + +def _get_all_releases(dragon_repo: Repository) -> t.Collection[GitRelease]: + """Retrieve all available releases for the configured dragon repository + + :param dragon_repo: A GitHub repository object for the dragon package + :returns: A list of GitRelease""" + all_releases = [release for release in list(dragon_repo.get_releases())] + return all_releases -def _get_release_assets() -> t.Collection[GitReleaseAsset]: +def _get_release_assets(request: DragonInstallRequest) -> t.Collection[GitReleaseAsset]: """Retrieve a collection of available assets for all releases that satisfy the dragon version pin + :param request: details of a request for the installation of the dragon package :returns: A collection of release assets""" - git = Github() - - dragon_repo = git.get_repo("DragonHPC/dragon") + auth = get_auth_token(request) + git = Github(auth=auth) + dragon_repo = git.get_repo(request.repo_name) if dragon_repo is None: raise SmartSimCLIActionCancelled("Unable to locate dragon repo") - # find any releases matching our pinned version requirement - tags = [tag for tag in dragon_repo.get_tags() if dragon_pin() in tag.name] - # repo.get_latest_release fails if only pre-release results are returned - pin_releases = list(dragon_repo.get_release(tag.name) for tag in tags) - releases = 
sorted(pin_releases, key=lambda r: r.published_at, reverse=True) + all_releases = sorted( + _get_all_releases(dragon_repo), key=lambda r: r.published_at, reverse=True + ) - # take the most recent release for the given pin - assets = releases[0].assets + # filter the list of releases to include only the target version + releases = [ + release + for release in all_releases + if request.pkg_version in release.title or release.tag_name + ] + + releases = sorted(releases, key=lambda r: r.published_at, reverse=True) + + if not releases: + release_titles = ", ".join(release.title for release in all_releases) + raise SmartSimCLIActionCancelled( + f"Unable to find a release for dragon version {request.pkg_version}. " + f"Available releases: {release_titles}" + ) + + assets: t.List[GitReleaseAsset] = [] + + # install the latest release of the target version (including pre-release) + for release in releases: + # delay in attaching release assets may leave us with an empty list, retry + # with the next available release + if assets := list(release.get_assets()): + logger.debug(f"Found assets for dragon release {release.title}") + break + else: + logger.debug(f"No assets for dragon release {release.title}. Retrying.") + + if not assets: + raise SmartSimCLIActionCancelled( + f"Unable to find assets for dragon release {release.title}" + ) return assets -def filter_assets(assets: t.Collection[GitReleaseAsset]) -> t.Optional[GitReleaseAsset]: +def filter_assets( + request: DragonInstallRequest, assets: t.Collection[GitReleaseAsset] +) -> t.Optional[GitReleaseAsset]: """Filter the available release assets so that HSTA agents are used when run on a Cray EX platform + :param request: details of a request for the installation of the dragon package :param assets: The collection of dragon release assets to filter :returns: An asset meeting platform & version filtering requirements""" # Expect cray & non-cray assets that require a filter, e.g. 
# 'dragon-0.8-py3.9.4.1-bafaa887f.tar.gz', # 'dragon-0.8-py3.9.4.1-CRAYEX-ac132fe95.tar.gz' - asset = next( - ( - asset - for asset in assets - if _version_filter(asset.name) - and _platform_filter(asset.name) - and _pin_filter(asset.name) - ), - None, + all_assets = [asset.name for asset in assets] + + assets = list( + asset + for asset in assets + if _version_filter(asset.name) and _pin_filter(asset.name, request.pkg_version) ) + + if len(assets) == 0: + available = "\n\t".join(all_assets) + logger.warning( + f"Please specify a dragon version (e.g. {DEFAULT_DRAGON_VERSION}) " + f"of an asset available in the repository:\n\t{available}" + ) + return None + + asset: t.Optional[GitReleaseAsset] = None + + # Apply platform filter if we have multiple matches for python/dragon version + if len(assets) > 0: + asset = next((asset for asset in assets if _platform_filter(asset.name)), None) + + if not asset: + asset = assets[0] + logger.warning(f"Platform-specific package not found. Using {asset.name}") + return asset -def retrieve_asset_info() -> GitReleaseAsset: +def retrieve_asset_info(request: DragonInstallRequest) -> GitReleaseAsset: """Find a release asset that meets all necessary filtering criteria - :param dragon_pin: identify the dragon version to install (e.g. 
dragon-0.8) + :param request: details of a request for the installation of the dragon package :returns: A GitHub release asset""" - assets = _get_release_assets() - asset = filter_assets(assets) + assets = _get_release_assets(request) + asset = filter_assets(request, assets) platform_result = check_platform() if not platform_result.is_cray: @@ -150,42 +278,79 @@ def retrieve_asset_info() -> GitReleaseAsset: return asset -def retrieve_asset(working_dir: pathlib.Path, asset: GitReleaseAsset) -> pathlib.Path: +def retrieve_asset( + request: DragonInstallRequest, asset: GitReleaseAsset +) -> pathlib.Path: """Retrieve the physical file associated to a given GitHub release asset - :param working_dir: location in file system where assets should be written + :param request: details of a request for the installation of the dragon package :param asset: GitHub release asset to retrieve - :returns: path to the downloaded asset""" - if working_dir.exists() and list(working_dir.rglob("*.whl")): - return working_dir + :returns: path to the directory containing the extracted release asset + :raises SmartSimCLIActionCancelled: if the asset cannot be downloaded or extracted + """ + download_dir = request.working_dir / str(asset.id) + + # if we've previously downloaded the release and still have + # wheels laying around, use that cached version instead + cleanup(download_dir) + download_dir.mkdir(parents=True, exist_ok=True) + + # grab a copy of the complete asset + asset_path = download_dir / str(asset.name) + + # use the asset URL instead of the browser_download_url to enable + # using auth for private repositories + headers: t.Dict[str, str] = {"Accept": "application/octet-stream"} + + if request.raw_token: + headers["Authorization"] = f"Bearer {request.raw_token}" + + try: + # a github asset endpoint causes a redirect. 
the first request + # receives a pre-signed URL to the asset to pass on to retrieve + dl_request = Request(asset.url, headers=headers) + response = urlopen(dl_request) + presigned_url = response.url + + logger.debug(f"Retrieved asset {asset.name} metadata from {asset.url}") + except Exception: + logger.exception(f"Unable to download {asset.name} from: {asset.url}") + presigned_url = asset.url + + # extract the asset + try: + retrieve(presigned_url, asset_path) - retrieve(asset.browser_download_url, working_dir) + logger.debug(f"Extracted {asset.name} to {download_dir}") + except Exception as ex: + raise SmartSimCLIActionCancelled( + f"Unable to extract {asset.name} from {download_dir}" + ) from ex - logger.debug(f"Retrieved {asset.browser_download_url} to {working_dir}") - return working_dir + return download_dir -def install_package(asset_dir: pathlib.Path) -> int: +def install_package(request: DragonInstallRequest, asset_dir: pathlib.Path) -> int: """Install the package found in `asset_dir` into the current python environment - :param asset_dir: path to a decompressed archive contents for a release asset""" - wheels = asset_dir.rglob("*.whl") - wheel_path = next(wheels, None) - if not wheel_path: - logger.error(f"No wheel found for package in {asset_dir}") + :param request: details of a request for the installation of the dragon package + :param asset_dir: path to a decompressed archive contents for a release asset + :returns: Integer return code, 0 for success, non-zero on failures""" + found_wheels = list(asset_dir.rglob("*.whl")) + if not found_wheels: + logger.error(f"No wheel(s) found for package in {asset_dir}") return 1 - create_dotenv(wheel_path.parent) - - while wheel_path is not None: - logger.info(f"Installing package: {wheel_path.absolute()}") + create_dotenv(found_wheels[0].parent, request.pkg_version) - try: - pip("install", "--force-reinstall", str(wheel_path), "numpy<2") - wheel_path = next(wheels, None) - except Exception: - logger.error(f"Unable 
to install from {asset_dir}") - return 1 + try: + wheels = list(map(str, found_wheels)) + for wheel_path in wheels: + logger.info(f"Installing package: {wheel_path}") + pip("install", wheel_path) + except Exception: + logger.error(f"Unable to install from {asset_dir}") + return 1 return 0 @@ -196,36 +361,83 @@ def cleanup( """Delete the downloaded asset and any files extracted during installation :param archive_path: path to a downloaded archive for a release asset""" - if archive_path: - archive_path.unlink(missing_ok=True) - logger.debug(f"Deleted archive: {archive_path}") + if not archive_path: + return + + if archive_path.exists() and archive_path.is_file(): + archive_path.unlink() + archive_path = archive_path.parent + if archive_path.exists() and archive_path.is_dir(): + shutil.rmtree(archive_path, ignore_errors=True) + logger.debug(f"Deleted temporary files in: {archive_path}") -def install_dragon(extraction_dir: t.Union[str, os.PathLike[str]]) -> int: + +def install_dragon(request: DragonInstallRequest) -> int: """Retrieve a dragon runtime appropriate for the current platform and install to the current python environment - :param extraction_dir: path for download and extraction of assets + + :param request: details of a request for the installation of the dragon package :returns: Integer return code, 0 for success, non-zero on failures""" if sys.platform == "darwin": logger.debug(f"Dragon not supported on platform: {sys.platform}") return 1 - extraction_dir = pathlib.Path(extraction_dir) - filename: t.Optional[pathlib.Path] = None asset_dir: t.Optional[pathlib.Path] = None try: - asset_info = retrieve_asset_info() - asset_dir = retrieve_asset(extraction_dir, asset_info) + asset_info = retrieve_asset_info(request) + if asset_info is not None: + asset_dir = retrieve_asset(request, asset_info) + return install_package(request, asset_dir) - return install_package(asset_dir) + except SmartSimCLIActionCancelled as ex: + logger.warning(*ex.args) except Exception 
as ex: - logger.error("Unable to install dragon runtime", exc_info=ex) - finally: - cleanup(filename) + logger.error("Unable to install dragon runtime", exc_info=True) return 2 +def display_post_install_logs() -> None: + """Display post-installation instructions for the user""" + + examples = { + "ofi-include": "/opt/cray/include", + "ofi-build-lib": "/opt/cray/lib64", + "ofi-runtime-lib": "/opt/cray/lib64", + } + + config = ":".join(f"{k}={v}" for k, v in examples.items()) + example_msg1 = f"dragon-config -a \\" + example_msg2 = f' "{config}"' + + logger.info( + "************************** Dragon Package Installed *****************************" + ) + logger.info("To enable Dragon to use HSTA (default: TCP), configure the following:") + + for key in examples: + logger.info(f"\t{key}") + + logger.info("Example:") + logger.info(example_msg1) + logger.info(example_msg2) + logger.info( + "*********************************************************************************" + ) + + if __name__ == "__main__": - sys.exit(install_dragon(CONFIG.core_path / ".dragon")) + # path for download and extraction of assets + extraction_dir = CONFIG.core_path / ".dragon" + dragon_repo = DEFAULT_DRAGON_REPO + dragon_version = DEFAULT_DRAGON_VERSION + + request = DragonInstallRequest( + extraction_dir, + dragon_repo, + dragon_version, + ) + + sys.exit(install_dragon(request)) diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index 17036e825e..957f2b6ef6 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -246,7 +246,9 @@ def build_from_git(self, git_url: str, branch: str) -> None: bin_path = Path(dependency_path, "bin").resolve() try: database_exe = next(bin_path.glob("*-server")) - database = Path(os.environ.get("REDIS_PATH", database_exe)).resolve() + database = Path( + os.environ.get("SMARTSIM_REDIS_SERVER_EXE", database_exe) + ).resolve() _ = expand_exe_path(str(database)) except (TypeError, FileNotFoundError) 
as e: raise BuildError("Installation of redis-server failed!") from e @@ -254,7 +256,9 @@ def build_from_git(self, git_url: str, branch: str) -> None: # validate install -- redis-cli try: redis_cli_exe = next(bin_path.glob("*-cli")) - redis_cli = Path(os.environ.get("REDIS_CLI_PATH", redis_cli_exe)).resolve() + redis_cli = Path( + os.environ.get("SMARTSIM_REDIS_CLI_EXE", redis_cli_exe) + ).resolve() _ = expand_exe_path(str(redis_cli)) except (TypeError, FileNotFoundError) as e: raise BuildError("Installation of redis-cli failed!") from e diff --git a/smartsim/_core/config/config.py b/smartsim/_core/config/config.py index 03c284edb3..c8b4ff17b9 100644 --- a/smartsim/_core/config/config.py +++ b/smartsim/_core/config/config.py @@ -40,19 +40,19 @@ # These values can be set through environment variables to # override the default behavior of SmartSim. # -# RAI_PATH +# SMARTSIM_RAI_LIB # - Path to the RAI shared library # - Default: /smartsim/smartsim/_core/lib/redisai.so # -# REDIS_CONF +# SMARTSIM_REDIS_CONF # - Path to the redis.conf file # - Default: /SmartSim/smartsim/_core/config/redis.conf # -# REDIS_PATH +# SMARTSIM_REDIS_SERVER_EXE # - Path to the redis-server executable # - Default: /SmartSim/smartsim/_core/bin/redis-server # -# REDIS_CLI_PATH +# SMARTSIM_REDIS_CLI_EXE # - Path to the redis-cli executable # - Default: /SmartSim/smartsim/_core/bin/redis-cli # @@ -120,20 +120,20 @@ def build_path(self) -> Path: @property def redisai(self) -> str: rai_path = self.lib_path / "redisai.so" - redisai = Path(os.environ.get("RAI_PATH", rai_path)).resolve() + redisai = Path(os.environ.get("SMARTSIM_RAI_LIB", rai_path)).resolve() if not redisai.is_file(): raise SSConfigError( "RedisAI dependency not found. 
Build with `smart` cli " - "or specify RAI_PATH" + "or specify SMARTSIM_RAI_LIB" ) return str(redisai) @property def database_conf(self) -> str: - conf = Path(os.environ.get("REDIS_CONF", self.conf_path)).resolve() + conf = Path(os.environ.get("SMARTSIM_REDIS_CONF", self.conf_path)).resolve() if not conf.is_file(): raise SSConfigError( - "Database configuration file at REDIS_CONF could not be found" + "Database configuration file at SMARTSIM_REDIS_CONF could not be found" ) return str(conf) @@ -141,24 +141,29 @@ def database_conf(self) -> str: def database_exe(self) -> str: try: database_exe = next(self.bin_path.glob("*-server")) - database = Path(os.environ.get("REDIS_PATH", database_exe)).resolve() + database = Path( + os.environ.get("SMARTSIM_REDIS_SERVER_EXE", database_exe) + ).resolve() exe = expand_exe_path(str(database)) return exe except (TypeError, FileNotFoundError) as e: raise SSConfigError( - "Specified database binary at REDIS_PATH could not be used" + "Specified database binary at SMARTSIM_REDIS_SERVER_EXE " + "could not be used" ) from e @property def database_cli(self) -> str: try: redis_cli_exe = next(self.bin_path.glob("*-cli")) - redis_cli = Path(os.environ.get("REDIS_CLI_PATH", redis_cli_exe)).resolve() + redis_cli = Path( + os.environ.get("SMARTSIM_REDIS_CLI_EXE", redis_cli_exe) + ).resolve() exe = expand_exe_path(str(redis_cli)) return exe except (TypeError, FileNotFoundError) as e: raise SSConfigError( - "Specified Redis binary at REDIS_CLI_PATH could not be used" + "Specified Redis binary at SMARTSIM_REDIS_CLI_EXE could not be used" ) from e @property @@ -178,7 +183,7 @@ def dragon_dotenv(self) -> Path: def dragon_server_path(self) -> t.Optional[str]: return os.getenv( "SMARTSIM_DRAGON_SERVER_PATH", - os.getenv("SMARTSIM_DRAGON_SERVER_PATH_EXP", None), + os.getenv("_SMARTSIM_DRAGON_SERVER_PATH_EXP", None), ) @property @@ -306,10 +311,6 @@ def smartsim_key_path(self) -> str: default_path = Path.home() / ".smartsim" / "keys" return 
os.environ.get("SMARTSIM_KEY_PATH", str(default_path)) - @property - def dragon_pin(self) -> str: - return "0.9" - @lru_cache(maxsize=128, typed=False) def get_config() -> Config: diff --git a/smartsim/_core/entrypoints/service.py b/smartsim/_core/entrypoints/service.py new file mode 100644 index 0000000000..719c2a60fe --- /dev/null +++ b/smartsim/_core/entrypoints/service.py @@ -0,0 +1,185 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import datetime +import time +import typing as t +from abc import ABC, abstractmethod + +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class Service(ABC): + """Core API for standalone entrypoint scripts. Makes use of overridable hook + methods to modify behaviors (event loop, automatic shutdown, cooldown) as + well as simple hooks for status changes""" + + def __init__( + self, + as_service: bool = False, + cooldown: float = 0, + loop_delay: float = 0, + health_check_frequency: float = 0, + ) -> None: + """Initialize the Service + + :param as_service: Determines lifetime of the service. When `True`, calling + execute on the service will run continuously until shutdown criteria are met. + Otherwise, `execute` performs a single pass through the service lifecycle and + automatically exits (regardless of the result of `_can_shutdown`). + :param cooldown: Period of time (in seconds) to allow the service to run + after a shutdown is permitted. Enables the service to avoid restarting if + new work is discovered. A value of 0 disables the cooldown. + :param loop_delay: Duration (in seconds) of a forced delay between + iterations of the event loop + :param health_check_frequency: Time (in seconds) between calls to a + health check handler. A value of 0 triggers the health check on every + iteration. + """ + self._as_service = as_service + """Determines lifetime of the service. When `True`, calling + `execute` on the service will run continuously until shutdown criteria are met. + Otherwise, `execute` performs a single pass through the service lifecycle and + automatically exits (regardless of the result of `_can_shutdown`).""" + self._cooldown = abs(cooldown) + """Period of time (in seconds) to allow the service to run + after a shutdown is permitted. Enables the service to avoid restarting if + new work is discovered. 
A value of 0 disables the cooldown.""" + self._loop_delay = abs(loop_delay) + """Duration (in seconds) of a forced delay between + iterations of the event loop""" + self._health_check_frequency = health_check_frequency + """Time (in seconds) between calls to a + health check handler. A value of 0 triggers the health check on every + iteration.""" + self._last_health_check = time.time() + """The timestamp of the latest health check""" + + @abstractmethod + def _on_iteration(self) -> None: + """The user-defined event handler. Executed repeatedly until shutdown + conditions are satisfied and cooldown is elapsed. + """ + + @abstractmethod + def _can_shutdown(self) -> bool: + """Return true when the criteria to shut down the service are met.""" + + def _on_start(self) -> None: + """Empty hook method for use by subclasses. Called on initial entry into + Service `execute` event loop before `_on_iteration` is invoked.""" + logger.debug(f"Starting {self.__class__.__name__}") + + def _on_shutdown(self) -> None: + """Empty hook method for use by subclasses. Called immediately after exiting + the main event loop during automatic shutdown.""" + logger.debug(f"Shutting down {self.__class__.__name__}") + + def _on_health_check(self) -> None: + """Empty hook method for use by subclasses. Invoked based on the + value of `self._health_check_frequency`.""" + logger.debug(f"Performing health check for {self.__class__.__name__}") + + def _on_cooldown_elapsed(self) -> None: + """Empty hook method for use by subclasses. Called on every event loop + iteration immediately upon exceeding the cooldown period""" + logger.debug(f"Cooldown exceeded by {self.__class__.__name__}") + + def _on_delay(self) -> None: + """Empty hook method for use by subclasses. 
Called on every event loop + iteration immediately before executing a delay before the next iteration""" + logger.debug(f"Service iteration waiting for {self.__class__.__name__}s") + + def _log_cooldown(self, elapsed: float) -> None: + """Log the remaining cooldown time, if any""" + remaining = self._cooldown - elapsed + if remaining > 0: + logger.debug(f"{abs(remaining):.2f}s remains of {self._cooldown}s cooldown") + else: + logger.info(f"exceeded cooldown {self._cooldown}s by {abs(remaining):.2f}s") + + def execute(self) -> None: + """The main event loop of a service host. Evaluates shutdown criteria and + combines with a cooldown period to allow automatic service termination. + Responsible for executing calls to subclass implementation of `_on_iteration`""" + + try: + self._on_start() + except Exception: + logger.exception("Unable to start service.") + return + + running = True + cooldown_start: t.Optional[datetime.datetime] = None + + while running: + try: + self._on_iteration() + except Exception: + running = False + logger.exception( + "Failure in event loop resulted in service termination" + ) + + if self._health_check_frequency >= 0: + hc_elapsed = time.time() - self._last_health_check + if hc_elapsed >= self._health_check_frequency: + self._on_health_check() + self._last_health_check = time.time() + + # allow immediate shutdown if not set to run as a service + if not self._as_service: + running = False + continue + + # reset cooldown period if shutdown criteria are not met + if not self._can_shutdown(): + cooldown_start = None + + # start tracking cooldown elapsed once eligible to quit + if cooldown_start is None: + cooldown_start = datetime.datetime.now() + + # change running state if cooldown period is exceeded + if self._cooldown > 0: + elapsed = datetime.datetime.now() - cooldown_start + running = elapsed.total_seconds() < self._cooldown + self._log_cooldown(elapsed.total_seconds()) + if not running: + self._on_cooldown_elapsed() + elif self._cooldown 
< 1 and self._can_shutdown(): + running = False + + if self._loop_delay: + self._on_delay() + time.sleep(self._loop_delay) + + try: + self._on_shutdown() + except Exception: + logger.exception("Service shutdown may not have completed.") diff --git a/smartsim/_core/launcher/dragon/dragonBackend.py b/smartsim/_core/launcher/dragon/dragonBackend.py index 4aba60d558..5e01299141 100644 --- a/smartsim/_core/launcher/dragon/dragonBackend.py +++ b/smartsim/_core/launcher/dragon/dragonBackend.py @@ -26,6 +26,8 @@ import collections import functools import itertools +import os +import socket import time import typing as t from dataclasses import dataclass, field @@ -34,15 +36,27 @@ from tabulate import tabulate -# pylint: disable=import-error +# pylint: disable=import-error,C0302,R0915 # isort: off + import dragon.infrastructure.connection as dragon_connection import dragon.infrastructure.policy as dragon_policy -import dragon.native.group_state as dragon_group_state +import dragon.infrastructure.process_desc as dragon_process_desc + import dragon.native.process as dragon_process import dragon.native.process_group as dragon_process_group import dragon.native.machine as dragon_machine +from smartsim._core.launcher.dragon.pqueue import NodePrioritizer, PrioritizerFilter +from smartsim._core.mli.infrastructure.control.listener import ( + ConsumerRegistrationListener, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_util import create_ddict +from smartsim.error.errors import SmartSimError + # pylint: enable=import-error # isort: on from ...._core.config import get_config @@ -68,8 +82,8 @@ class DragonStatus(str, Enum): - ERROR = str(dragon_group_state.Error()) - RUNNING = str(dragon_group_state.Running()) + ERROR = "Error" + RUNNING = "Running" def __str__(self) -> str: return self.value @@ -86,7 +100,7 @@ class ProcessGroupInfo: return_codes: 
t.Optional[t.List[int]] = None """List of return codes of completed processes""" hosts: t.List[str] = field(default_factory=list) - """List of hosts on which the Process Group """ + """List of hosts on which the Process Group should be executed""" redir_workers: t.Optional[dragon_process_group.ProcessGroup] = None """Workers used to redirect stdout and stderr to file""" @@ -143,6 +157,11 @@ class DragonBackend: by threads spawned by it. """ + _DEFAULT_NUM_MGR_PER_NODE = 2 + """The default number of manager processes for each feature store node""" + _DEFAULT_MEM_PER_NODE = 512 * 1024**2 + """The default memory capacity (in bytes) to allocate for a feaure store node""" + def __init__(self, pid: int) -> None: self._pid = pid """PID of dragon executable which launched this server""" @@ -153,7 +172,6 @@ def __init__(self, pid: int) -> None: self._step_ids = (f"{create_short_id_str()}-{id}" for id in itertools.count()) """Incremental ID to assign to new steps prior to execution""" - self._initialize_hosts() self._queued_steps: "collections.OrderedDict[str, DragonRunRequest]" = ( collections.OrderedDict() ) @@ -177,16 +195,26 @@ def __init__(self, pid: int) -> None: """Whether the server frontend should shut down when the backend does""" self._shutdown_initiation_time: t.Optional[float] = None """The time at which the server initiated shutdown""" - smartsim_config = get_config() - self._cooldown_period = ( - smartsim_config.telemetry_frequency * 2 + 5 - if smartsim_config.telemetry_enabled - else 5 - ) - """Time in seconds needed to server to complete shutdown""" + self._cooldown_period = self._initialize_cooldown() + """Time in seconds needed by the server to complete shutdown""" + self._backbone: t.Optional[BackboneFeatureStore] = None + """The backbone feature store""" + self._listener: t.Optional[dragon_process.Process] = None + """The standalone process executing the event consumer""" + + self._nodes: t.List["dragon_machine.Node"] = [] + """Node capability 
information for hosts in the allocation""" + self._hosts: t.List[str] = [] + """List of hosts available in allocation""" + self._cpus: t.List[int] = [] + """List of cpu-count by node""" + self._gpus: t.List[int] = [] + """List of gpu-count by node""" + self._allocated_hosts: t.Dict[str, t.Set[str]] = {} + """Mapping with hostnames as keys and a set of running step IDs as the value""" - self._view = DragonBackendView(self) - logger.debug(self._view.host_desc) + self._initialize_hosts() + self._prioritizer = NodePrioritizer(self._nodes, self._queue_lock) @property def hosts(self) -> list[str]: @@ -194,34 +222,39 @@ def hosts(self) -> list[str]: return self._hosts @property - def allocated_hosts(self) -> dict[str, str]: + def allocated_hosts(self) -> dict[str, t.Set[str]]: + """A map of host names to the step id executing on a host + + :returns: Dictionary with host name as key and step id as value""" with self._queue_lock: return self._allocated_hosts @property - def free_hosts(self) -> t.Deque[str]: + def free_hosts(self) -> t.Sequence[str]: + """Find hosts that do not have a step assigned + + :returns: List of host names""" with self._queue_lock: - return self._free_hosts + return list(map(lambda x: x.hostname, self._prioritizer.unassigned())) @property def group_infos(self) -> dict[str, ProcessGroupInfo]: + """Find information pertaining to process groups executing on a host + + :returns: Dictionary with host name as key and group information as value""" with self._queue_lock: return self._group_infos def _initialize_hosts(self) -> None: + """Prepare metadata about the allocation""" with self._queue_lock: self._nodes = [ dragon_machine.Node(node) for node in dragon_machine.System().nodes ] - self._hosts: t.List[str] = sorted(node.hostname for node in self._nodes) + self._hosts = sorted(node.hostname for node in self._nodes) self._cpus = [node.num_cpus for node in self._nodes] self._gpus = [node.num_gpus for node in self._nodes] - - """List of hosts available in 
allocation""" - self._free_hosts: t.Deque[str] = collections.deque(self._hosts) - """List of hosts on which steps can be launched""" - self._allocated_hosts: t.Dict[str, str] = {} - """Mapping of hosts on which a step is already running to step ID""" + self._allocated_hosts = collections.defaultdict(set) def __str__(self) -> str: return self.status_message @@ -230,21 +263,19 @@ def __str__(self) -> str: def status_message(self) -> str: """Message with status of available nodes and history of launched jobs. - :returns: Status message + :returns: a status message """ - return ( - "Dragon server backend update\n" - f"{self._view.host_table}\n{self._view.step_table}" - ) + view = DragonBackendView(self) + return "Dragon server backend update\n" f"{view.host_table}\n{view.step_table}" def _heartbeat(self) -> None: + """Update the value of the last heartbeat to the current time.""" self._last_beat = self.current_time @property def cooldown_period(self) -> int: - """Time (in seconds) the server will wait before shutting down - - when exit conditions are met (see ``should_shutdown()`` for further details). + """Time (in seconds) the server will wait before shutting down when + exit conditions are met (see ``should_shutdown()`` for further details). """ return self._cooldown_period @@ -278,6 +309,8 @@ def should_shutdown(self) -> bool: and it requested immediate shutdown, or if it did not request immediate shutdown, but all jobs have been executed. In both cases, a cooldown period may need to be waited before shutdown. 
+ + :returns: `True` if the server should terminate, otherwise `False` """ if self._shutdown_requested and self._can_shutdown: return self._has_cooled_down @@ -285,7 +318,9 @@ def should_shutdown(self) -> bool: @property def current_time(self) -> float: - """Current time for DragonBackend object, in seconds since the Epoch""" + """Current time for DragonBackend object, in seconds since the Epoch + + :returns: the current timestamp""" return time.time() def _can_honor_policy( @@ -293,63 +328,149 @@ def _can_honor_policy( ) -> t.Tuple[bool, t.Optional[str]]: """Check if the policy can be honored with resources available in the allocation. - :param request: DragonRunRequest containing policy information + + :param request: `DragonRunRequest` to validate :returns: Tuple indicating if the policy can be honored and an optional error message""" # ensure the policy can be honored if request.policy: + logger.debug(f"{request.policy=}{self._cpus=}{self._gpus=}") + if request.policy.cpu_affinity: # make sure some node has enough CPUs - available = max(self._cpus) + last_available = max(self._cpus or [-1]) requested = max(request.policy.cpu_affinity) - - if requested >= available: + if not any(self._cpus) or requested >= last_available: return False, "Cannot satisfy request, not enough CPUs available" - if request.policy.gpu_affinity: # make sure some node has enough GPUs - available = max(self._gpus) + last_available = max(self._gpus or [-1]) requested = max(request.policy.gpu_affinity) - - if requested >= available: + if not any(self._gpus) or requested >= last_available: + logger.warning( + f"failed check w/{self._gpus=}, {requested=}, {last_available=}" + ) return False, "Cannot satisfy request, not enough GPUs available" - return True, None def _can_honor(self, request: DragonRunRequest) -> t.Tuple[bool, t.Optional[str]]: - """Check if request can be honored with resources available in the allocation. 
- - Currently only checks for total number of nodes, - in the future it will also look at other constraints - such as memory, accelerators, and so on. + """Check if request can be honored with resources available in + the allocation. Currently only checks for total number of nodes, + in the future it will also look at other constraints such as memory, + accelerators, and so on. + + :param request: `DragonRunRequest` to validate + :returns: Tuple indicating if the request can be honored and + an optional error message """ - if request.nodes > len(self._hosts): - message = f"Cannot satisfy request. Requested {request.nodes} nodes, " - message += f"but only {len(self._hosts)} nodes are available." - return False, message - if self._shutdown_requested: - message = "Cannot satisfy request, server is shutting down." - return False, message + honorable, err = self._can_honor_state(request) + if not honorable: + return False, err honorable, err = self._can_honor_policy(request) if not honorable: return False, err + honorable, err = self._can_honor_hosts(request) + if not honorable: + return False, err + + return True, None + + def _can_honor_hosts( + self, request: DragonRunRequest + ) -> t.Tuple[bool, t.Optional[str]]: + """Check if the current state of the backend process inhibits executing + the request. + + :param request: `DragonRunRequest` to validate + :returns: Tuple indicating if the request can be honored and + an optional error message""" + all_hosts = frozenset(self._hosts) + num_nodes = request.nodes + + # fail if requesting more nodes than the total number available + if num_nodes > len(all_hosts): + message = f"Cannot satisfy request. {num_nodes} requested nodes" + message += f" exceeds {len(all_hosts)} available." 
+ return False, message + + requested_hosts = all_hosts + if request.hostlist: + requested_hosts = frozenset( + {host.strip() for host in request.hostlist.split(",")} + ) + + valid_hosts = all_hosts.intersection(requested_hosts) + invalid_hosts = requested_hosts - valid_hosts + + logger.debug(f"{num_nodes=}{valid_hosts=}{invalid_hosts=}") + + if invalid_hosts: + logger.warning(f"Some invalid hostnames were requested: {invalid_hosts}") + + # fail if requesting specific hostnames and there aren't enough available + if num_nodes > len(valid_hosts): + message = f"Cannot satisfy request. Requested {num_nodes} nodes, " + message += f"but only {len(valid_hosts)} named hosts are available." + return False, message + + return True, None + + def _can_honor_state( + self, _request: DragonRunRequest + ) -> t.Tuple[bool, t.Optional[str]]: + """Check if the current state of the backend process inhibits executing + the request. + :param _request: the DragonRunRequest to verify + :returns: Tuple indicating if the request can be honored and + an optional error message""" + if self._shutdown_requested: + message = "Cannot satisfy request, server is shutting down." 
+ return False, message + return True, None def _allocate_step( self, step_id: str, request: DragonRunRequest ) -> t.Optional[t.List[str]]: + """Identify the hosts on which the request will be executed + :param step_id: The identifier of a step that will be executed on the host + :param request: The request to be executed + :returns: A list of selected hostnames""" + # ensure at least one host is selected num_hosts: int = request.nodes + with self._queue_lock: - if num_hosts <= 0 or num_hosts > len(self._free_hosts): + if num_hosts <= 0 or num_hosts > len(self._hosts): + logger.debug( + f"The number of requested hosts ({num_hosts}) is invalid or" + f" cannot be satisfied with {len(self._hosts)} available nodes" + ) return None - to_allocate = [] - for _ in range(num_hosts): - host = self._free_hosts.popleft() - self._allocated_hosts[host] = step_id - to_allocate.append(host) + + hosts = [] + if request.hostlist: + # convert the comma-separated argument into a real list + hosts = [host for host in request.hostlist.split(",") if host] + + filter_on: t.Optional[PrioritizerFilter] = None + if request.policy and request.policy.gpu_affinity: + filter_on = PrioritizerFilter.GPU + + nodes = self._prioritizer.next_n(num_hosts, filter_on, step_id, hosts) + + if len(nodes) < num_hosts: + # exit if the prioritizer can't identify enough nodes + return None + + to_allocate = [node.hostname for node in nodes] + + for hostname in to_allocate: + # track assigning this step to each node + self._allocated_hosts[hostname].add(step_id) + return to_allocate @staticmethod @@ -389,6 +510,7 @@ def _create_redirect_workers( return grp_redir def _stop_steps(self) -> None: + """Trigger termination of all currently executing steps""" self._heartbeat() with self._queue_lock: while len(self._stop_requests) > 0: @@ -427,18 +549,96 @@ def _stop_steps(self) -> None: self._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED self._group_infos[step_id].return_codes = [-9] + def 
_create_backbone(self) -> BackboneFeatureStore: + """ + Creates a BackboneFeatureStore if one does not exist. Updates + environment variables of this process to include the backbone + descriptor. + + :returns: The backbone feature store + """ + if self._backbone is None: + backbone_storage = create_ddict( + len(self._hosts), + self._DEFAULT_NUM_MGR_PER_NODE, + self._DEFAULT_MEM_PER_NODE, + ) + + self._backbone = BackboneFeatureStore( + backbone_storage, allow_reserved_writes=True + ) + + # put the backbone descriptor in the env vars + os.environ.update(self._backbone.get_env()) + + return self._backbone + + @staticmethod + def _initialize_cooldown() -> int: + """Load environment configuration and determine the correct cooldown + period to apply to the backend process. + + :returns: The calculated cooldown (in seconds) + """ + smartsim_config = get_config() + return ( + smartsim_config.telemetry_frequency * 2 + 5 + if smartsim_config.telemetry_enabled + else 5 + ) + + def start_event_listener( + self, cpu_affinity: list[int], gpu_affinity: list[int] + ) -> dragon_process.Process: + """Start a standalone event listener. 
+ + :param cpu_affinity: The CPU affinity for the process + :param gpu_affinity: The GPU affinity for the process + :returns: The dragon Process managing the process + :raises SmartSimError: If the backbone is not provided + """ + if self._backbone is None: + raise SmartSimError("Backbone feature store is not available") + + service = ConsumerRegistrationListener( + self._backbone, 1.0, 2.0, as_service=True, health_check_frequency=90 + ) + + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + process = dragon_process.Process( + target=service.execute, + args=[], + cwd=os.getcwd(), + env={ + **os.environ, + **self._backbone.get_env(), + }, + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) + process.start() + return process + @staticmethod def create_run_policy( request: DragonRequest, node_name: str ) -> "dragon_policy.Policy": """Create a dragon Policy from the request and node name + :param request: DragonRunRequest containing policy information :param node_name: Name of the node on which the process will run :returns: dragon_policy.Policy object mapped from request properties""" if isinstance(request, DragonRunRequest): run_request: DragonRunRequest = request - affinity = dragon_policy.Policy.Affinity.DEFAULT cpu_affinity: t.List[int] = [] gpu_affinity: t.List[int] = [] @@ -446,25 +646,20 @@ def create_run_policy( if run_request.policy is not None: # Affinities are not mutually exclusive. 
If specified, both are used if run_request.policy.cpu_affinity: - affinity = dragon_policy.Policy.Affinity.SPECIFIC cpu_affinity = run_request.policy.cpu_affinity if run_request.policy.gpu_affinity: - affinity = dragon_policy.Policy.Affinity.SPECIFIC gpu_affinity = run_request.policy.gpu_affinity logger.debug( - f"Affinity strategy: {affinity}, " f"CPU affinity mask: {cpu_affinity}, " f"GPU affinity mask: {gpu_affinity}" ) - if affinity != dragon_policy.Policy.Affinity.DEFAULT: - return dragon_policy.Policy( - placement=dragon_policy.Policy.Placement.HOST_NAME, - host_name=node_name, - affinity=affinity, - cpu_affinity=cpu_affinity, - gpu_affinity=gpu_affinity, - ) + return dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=node_name, + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) return dragon_policy.Policy( placement=dragon_policy.Policy.Placement.HOST_NAME, @@ -472,7 +667,9 @@ def create_run_policy( ) def _start_steps(self) -> None: + """Start all new steps created since the last update.""" self._heartbeat() + with self._queue_lock: started = [] for step_id, request in self._queued_steps.items(): @@ -482,10 +679,8 @@ def _start_steps(self) -> None: logger.debug(f"Step id {step_id} allocated on {hosts}") - global_policy = dragon_policy.Policy( - placement=dragon_policy.Policy.Placement.HOST_NAME, - host_name=hosts[0], - ) + global_policy = self.create_run_policy(request, hosts[0]) + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) grp = dragon_process_group.ProcessGroup( restart=False, pmi_enabled=request.pmi_enabled, policy=global_policy ) @@ -498,10 +693,15 @@ def _start_steps(self) -> None: target=request.exe, args=request.exe_args, cwd=request.path, - env={**request.current_env, **request.env}, + env={ + **request.current_env, + **request.env, + **(self._backbone.get_env() if self._backbone else {}), + }, stdout=dragon_process.Popen.PIPE, stderr=dragon_process.Popen.PIPE, 
policy=local_policy, + options=options, ) grp.add_process(nproc=request.tasks_per_node, template=tmp_proc) @@ -567,9 +767,11 @@ def _start_steps(self) -> None: logger.error(e) def _refresh_statuses(self) -> None: + """Query underlying management system for step status and update + stored assigned and unassigned task information""" self._heartbeat() with self._queue_lock: - terminated = [] + terminated: t.Set[str] = set() for step_id in self._running_steps: group_info = self._group_infos[step_id] grp = group_info.process_group @@ -603,11 +805,15 @@ def _refresh_statuses(self) -> None: ) if group_info.status in TERMINAL_STATUSES: - terminated.append(step_id) + terminated.add(step_id) if terminated: logger.debug(f"{terminated=}") + # remove all the terminated steps from all hosts + for host in list(self._allocated_hosts.keys()): + self._allocated_hosts[host].difference_update(terminated) + for step_id in terminated: self._running_steps.remove(step_id) self._completed_steps.append(step_id) @@ -615,15 +821,20 @@ def _refresh_statuses(self) -> None: if group_info is not None: for host in group_info.hosts: logger.debug(f"Releasing host {host}") - try: - self._allocated_hosts.pop(host) - except KeyError: + if host not in self._allocated_hosts: logger.error(f"Tried to free a non-allocated host: {host}") - self._free_hosts.append(host) + else: + # remove any hosts that have had all their steps terminated + if not self._allocated_hosts[host]: + self._allocated_hosts.pop(host) + self._prioritizer.decrement(host, step_id) group_info.process_group = None group_info.redir_workers = None def _update_shutdown_status(self) -> None: + """Query the status of running tasks and update the status + of any that have completed. + """ self._heartbeat() with self._queue_lock: self._can_shutdown |= ( @@ -637,12 +848,18 @@ def _update_shutdown_status(self) -> None: ) def _should_print_status(self) -> bool: + """Determine if status messages should be printed based off the last + update. 
Returns `True` to trigger prints, `False` otherwise. + """ if self.current_time - self._last_update_time > 10: self._last_update_time = self.current_time return True return False def _update(self) -> None: + """Trigger all update queries and update local state database""" + self._create_backbone() + self._stop_steps() self._start_steps() self._refresh_statuses() @@ -650,6 +867,9 @@ def _update(self) -> None: def _kill_all_running_jobs(self) -> None: with self._queue_lock: + if self._listener and self._listener.is_alive: + self._listener.kill() + for step_id, group_info in self._group_infos.items(): if group_info.status not in TERMINAL_STATUSES: self._stop_requests.append(DragonStopRequest(step_id=step_id)) @@ -730,8 +950,14 @@ def _(self, request: DragonShutdownRequest) -> DragonShutdownResponse: class DragonBackendView: - def __init__(self, backend: DragonBackend): + def __init__(self, backend: DragonBackend) -> None: + """Initialize the instance + + :param backend: A dragon backend used to produce the view""" self._backend = backend + """A dragon backend used to produce the view""" + + logger.debug(self.host_desc) @property def host_desc(self) -> str: @@ -793,9 +1019,7 @@ def step_table(self) -> str: @property def host_table(self) -> str: """Table representation of current state of nodes available - - in the allocation. 
- """ + in the allocation.""" headers = ["Host", "Status"] hosts = self._backend.hosts free_hosts = self._backend.free_hosts diff --git a/smartsim/_core/launcher/dragon/dragonConnector.py b/smartsim/_core/launcher/dragon/dragonConnector.py index 0cd68c24e9..1144b7764e 100644 --- a/smartsim/_core/launcher/dragon/dragonConnector.py +++ b/smartsim/_core/launcher/dragon/dragonConnector.py @@ -71,17 +71,23 @@ class DragonConnector: def __init__(self) -> None: self._context: zmq.Context[t.Any] = zmq.Context.instance() + """ZeroMQ context used to share configuration across requests""" self._context.setsockopt(zmq.REQ_CORRELATE, 1) self._context.setsockopt(zmq.REQ_RELAXED, 1) self._authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None + """ZeroMQ authenticator used to secure queue access""" config = get_config() self._reset_timeout(config.dragon_server_timeout) self._dragon_head_socket: t.Optional[zmq.Socket[t.Any]] = None + """ZeroMQ socket exposing the connection to the DragonBackend""" self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] = None + """A handle to the process executing the DragonBackend""" # Returned by dragon head, useful if shutdown is to be requested # but process was started by another connector self._dragon_head_pid: t.Optional[int] = None + """Process ID of the process executing the DragonBackend""" self._dragon_server_path = config.dragon_server_path + """Path to a dragon installation""" logger.debug(f"Dragon Server path was set to {self._dragon_server_path}") self._env_vars: t.Dict[str, str] = {} if self._dragon_server_path is None: @@ -95,7 +101,7 @@ def __init__(self) -> None: @property def is_connected(self) -> bool: - """Whether the Connector established a connection to the server + """Whether the Connector established a connection to the server. 
:return: True if connected """ @@ -104,12 +110,18 @@ def is_connected(self) -> bool: @property def can_monitor(self) -> bool: """Whether the Connector knows the PID of the dragon server head process - and can monitor its status + and can monitor its status. :return: True if the server can be monitored""" return self._dragon_head_pid is not None def _handshake(self, address: str) -> None: + """Perform the handshake process with the DragonBackend and + confirm two-way communication is established. + + :param address: The address of the head node socket to initiate a + handhake with + """ self._dragon_head_socket = dragonSockets.get_secure_socket( self._context, zmq.REQ, False ) @@ -132,6 +144,11 @@ def _handshake(self, address: str) -> None: ) from e def _reset_timeout(self, timeout: int = get_config().dragon_server_timeout) -> None: + """Reset the timeout applied to the ZMQ context. If an authenticator is + enabled, also update the authenticator timeouts. + + :param timeout: The timeout value to apply to ZMQ sockets + """ self._context.setsockopt(zmq.SNDTIMEO, value=timeout) self._context.setsockopt(zmq.RCVTIMEO, value=timeout) if self._authenticator is not None and self._authenticator.thread is not None: @@ -183,11 +200,19 @@ def _get_new_authenticator( @staticmethod def _get_dragon_log_level() -> str: + """Maps the log level from SmartSim to a valid log level + for a dragon process. + + :returns: The dragon log level string + """ smartsim_to_dragon = defaultdict(lambda: "NONE") smartsim_to_dragon["developer"] = "INFO" return smartsim_to_dragon.get(get_config().log_level, "NONE") def _connect_to_existing_server(self, path: Path) -> None: + """Connects to an existing DragonBackend using address information from + a persisted dragon log file. 
+ """ config = get_config() dragon_config_log = path / config.dragon_log_filename @@ -217,6 +242,11 @@ def _connect_to_existing_server(self, path: Path) -> None: return def _start_connector_socket(self, socket_addr: str) -> zmq.Socket[t.Any]: + """Instantiate the ZMQ socket to be used by the connector. + + :param socket_addr: The socket address the connector should bind to + :returns: The bound socket + """ config = get_config() connector_socket: t.Optional[zmq.Socket[t.Any]] = None self._reset_timeout(config.dragon_server_startup_timeout) @@ -245,9 +275,14 @@ def load_persisted_env(self) -> t.Dict[str, str]: with open(config.dragon_dotenv, encoding="utf-8") as dot_env: for kvp in dot_env.readlines(): - split = kvp.strip().split("=", maxsplit=1) - key, value = split[0], split[-1] - self._env_vars[key] = value + if not kvp: + continue + + # skip any commented lines + if not kvp.startswith("#"): + split = kvp.strip().split("=", maxsplit=1) + key, value = split[0], split[-1] + self._env_vars[key] = value return self._env_vars @@ -418,6 +453,15 @@ def send_request(self, request: DragonRequest, flags: int = 0) -> DragonResponse def _parse_launched_dragon_server_info_from_iterable( stream: t.Iterable[str], num_dragon_envs: t.Optional[int] = None ) -> t.List[t.Dict[str, str]]: + """Parses dragon backend connection information from a stream. + + :param stream: The stream to inspect. Usually the stdout of the + DragonBackend process + :param num_dragon_envs: The expected number of dragon environments + to parse from the stream. 
+ :returns: A list of dictionaries, one per environment, containing + the parsed server information + """ lines = (line.strip() for line in stream) lines = (line for line in lines if line) tokenized = (line.split(maxsplit=1) for line in lines) @@ -444,6 +488,15 @@ def _parse_launched_dragon_server_info_from_files( file_paths: t.List[t.Union[str, "os.PathLike[str]"]], num_dragon_envs: t.Optional[int] = None, ) -> t.List[t.Dict[str, str]]: + """Read a known log file into a Stream and parse dragon server configuration + from the stream. + + :param file_paths: Path to a file containing dragon server configuration + :num_dragon_envs: The expected number of dragon environments to be found + in the file + :returns: The parsed server configuration, one item per + discovered dragon environment + """ with fileinput.FileInput(file_paths) as ifstream: dragon_envs = cls._parse_launched_dragon_server_info_from_iterable( ifstream, num_dragon_envs @@ -458,6 +511,15 @@ def _send_req_with_socket( send_flags: int = 0, recv_flags: int = 0, ) -> DragonResponse: + """Sends a synchronous request through a ZMQ socket. + + :param socket: Socket to send on + :param request: The request to send + :param send_flags: Configuration to apply to the send operation + :param recv_flags: Configuration to apply to the recv operation; used to + allow the receiver to immediately respond to the sent request. + :returns: The response from the target + """ client = dragonSockets.as_client(socket) with DRG_LOCK: logger.debug(f"Sending {type(request).__name__}: {request}") @@ -469,6 +531,13 @@ def _send_req_with_socket( def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT: + """Verify that objects can be sent as messages acceptable to the target. 
+ + :param obj: The message to test + :param typ: The type that is acceptable + :returns: The original `obj` if it is of the requested type + :raises TypeError: If the object fails the test and is not + an instance of the desired type""" if not isinstance(obj, typ): raise TypeError(f"Expected schema of type `{typ}`, but got {type(obj)}") return obj @@ -520,6 +589,12 @@ def _dragon_cleanup( def _resolve_dragon_path(fallback: t.Union[str, "os.PathLike[str]"]) -> Path: + """Determine the applicable dragon server path for the connector + + :param fallback: A default dragon server path to use if one is not + found in the runtime configuration + :returns: The path to the dragon libraries + """ dragon_server_path = get_config().dragon_server_path or os.path.join( fallback, ".smartsim", "dragon" ) diff --git a/smartsim/_core/launcher/dragon/dragonLauncher.py b/smartsim/_core/launcher/dragon/dragonLauncher.py index 9078fed54f..75ca675225 100644 --- a/smartsim/_core/launcher/dragon/dragonLauncher.py +++ b/smartsim/_core/launcher/dragon/dragonLauncher.py @@ -170,6 +170,7 @@ def run(self, step: Step) -> t.Optional[str]: merged_env = self._connector.merge_persisted_env(os.environ.copy()) nodes = int(run_args.get("nodes", None) or 1) tasks_per_node = int(run_args.get("tasks-per-node", None) or 1) + hosts = run_args.get("host-list", None) policy = DragonRunPolicy.from_run_args(run_args) @@ -187,6 +188,7 @@ def run(self, step: Step) -> t.Optional[str]: output_file=out, error_file=err, policy=policy, + hostlist=hosts, ) ), DragonRunResponse, diff --git a/smartsim/_core/launcher/dragon/pqueue.py b/smartsim/_core/launcher/dragon/pqueue.py new file mode 100644 index 0000000000..8c14a828f5 --- /dev/null +++ b/smartsim/_core/launcher/dragon/pqueue.py @@ -0,0 +1,461 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# import collections
+import enum
+import heapq
+import threading
+import typing as t
+
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class Node(t.Protocol):
+    """Base Node API required to support the NodePrioritizer"""
+
+    @property
+    def hostname(self) -> str:
+        """The hostname of the node"""
+
+    @property
+    def num_cpus(self) -> int:
+        """The number of CPUs in the node"""
+
+    @property
+    def num_gpus(self) -> int:
+        """The number of GPUs in the node"""
+
+
+class NodeReferenceCount(t.Protocol):
+    """Contains details pertaining to references to a node"""
+
+    @property
+    def hostname(self) -> str:
+        """The hostname of the node"""
+
+    @property
+    def num_refs(self) -> int:
+        """The number of jobs assigned to the node"""
+
+
+class _TrackedNode:
+    """Node API required to have support in the NodePrioritizer"""
+
+    def __init__(self, node: Node) -> None:
+        self._node = node
+        """The node being tracked"""
+        self._num_refs = 0
+        """The number of references to the tracked node"""
+        self._assigned_tasks: t.Set[str] = set()
+        """The unique identifiers of processes using this node"""
+        self._is_dirty = False
+        """Flag indicating that tracking information has been modified"""
+
+    @property
+    def hostname(self) -> str:
+        """Returns the hostname of the node"""
+        return self._node.hostname
+
+    @property
+    def num_cpus(self) -> int:
+        """Returns the number of CPUs in the node"""
+        return self._node.num_cpus
+
+    @property
+    def num_gpus(self) -> int:
+        """Returns the number of GPUs attached to the node"""
+        return self._node.num_gpus
+
+    @property
+    def num_refs(self) -> int:
+        """Returns the number of processes currently running on the node"""
+        return self._num_refs
+
+    @property
+    def is_assigned(self) -> bool:
+        """Returns `True` if one or more references are currently counted, `False` otherwise"""
+        return self._num_refs > 0
+
+    @property
+    def assigned_tasks(self) -> t.Set[str]:
+        """Returns the set of unique IDs for currently running processes"""
+        return self._assigned_tasks
+
+    @property
+    def is_dirty(self) -> bool:
+        """Returns a flag indicating if the reference counter has changed. `True`
+        if references have been added or removed, `False` otherwise."""
+        return self._is_dirty
+
+    def clean(self) -> None:
+        """Marks the node as unmodified"""
+        self._is_dirty = False
+
+    def add(
+        self,
+        tracking_id: t.Optional[str] = None,
+    ) -> None:
+        """Update the node to indicate the addition of a process that must be
+        reference counted.
+
+        :param tracking_id: a unique task identifier executing on the node
+        to add
+        :raises ValueError: if tracking_id is already assigned to this node"""
+        if tracking_id in self.assigned_tasks:
+            raise ValueError("Attempted adding task more than once")
+
+        self._num_refs = self._num_refs + 1
+        if tracking_id:
+            self._assigned_tasks = self._assigned_tasks.union({tracking_id})
+        self._is_dirty = True
+
+    def remove(
+        self,
+        tracking_id: t.Optional[str] = None,
+    ) -> None:
+        """Update the reference counter to indicate the removal of a process.
+
+        :param tracking_id: a unique task identifier executing on the node
+        to remove (the reference count will not drop below zero)"""
+        self._num_refs = max(self._num_refs - 1, 0)
+        if tracking_id:
+            self._assigned_tasks = self._assigned_tasks - {tracking_id}
+        self._is_dirty = True
+
+    def __lt__(self, other: "_TrackedNode") -> bool:
+        """Comparison operator used to evaluate the ordering of nodes within
+        the prioritizer. This comparison only considers reference counts.
+ + :param other: Another node to compare against + :returns: True if this node has fewer references than the other node""" + if self.num_refs < other.num_refs: + return True + + return False + + +class PrioritizerFilter(str, enum.Enum): + """A filter used to select a subset of nodes to be queried""" + + CPU = enum.auto() + GPU = enum.auto() + + +class NodePrioritizer: + def __init__(self, nodes: t.List[Node], lock: threading.RLock) -> None: + """Initialize the prioritizer + + :param nodes: node attribute information for initializing the priorizer + :param lock: a lock used to ensure threadsafe operations + :raises SmartSimError: if the nodes collection is empty + """ + if not nodes: + raise SmartSimError("Missing nodes to prioritize") + + self._lock = lock + """Lock used to ensure thread safe changes of the reference counters""" + self._cpu_refs: t.List[_TrackedNode] = [] + """Track reference counts to CPU-only nodes""" + self._gpu_refs: t.List[_TrackedNode] = [] + """Track reference counts to GPU nodes""" + self._nodes: t.Dict[str, _TrackedNode] = {} + + self._initialize_reference_counters(nodes) + + def _initialize_reference_counters(self, nodes: t.List[Node]) -> None: + """Perform initialization of reference counters for nodes in the allocation + + :param nodes: node attribute information for initializing the priorizer""" + for node in nodes: + # create a set of reference counters for the nodes + tracked = _TrackedNode(node) + + self._nodes[node.hostname] = tracked # for O(1) access + + if node.num_gpus: + self._gpu_refs.append(tracked) + else: + self._cpu_refs.append(tracked) + + def increment( + self, host: str, tracking_id: t.Optional[str] = None + ) -> NodeReferenceCount: + """Directly increment the reference count of a given node and ensure the + ref counter is marked as dirty to trigger a reordering on retrieval + + :param host: a hostname that should have a reference counter selected + :param tracking_id: a unique task identifier executing on the node + 
to add""" + with self._lock: + tracked_node = self._nodes[host] + tracked_node.add(tracking_id) + return tracked_node + + def _heapify_all_refs(self) -> t.List[_TrackedNode]: + """Combine the CPU and GPU nodes into a single heap + + :returns: list of all reference counters""" + refs = [*self._cpu_refs, *self._gpu_refs] + heapq.heapify(refs) + return refs + + def get_tracking_info(self, host: str) -> NodeReferenceCount: + """Returns the reference counter information for a single node + + :param host: a hostname that should have a reference counter selected + :returns: a reference counter for the node + :raises ValueError: if the hostname is not in the set of managed nodes""" + if host not in self._nodes: + raise ValueError("The supplied hostname was not found") + + return self._nodes[host] + + def decrement( + self, host: str, tracking_id: t.Optional[str] = None + ) -> NodeReferenceCount: + """Directly decrement the reference count of a given node and ensure the + ref counter is marked as dirty to trigger a reordering + + :param host: a hostname that should have a reference counter decremented + :param tracking_id: unique task identifier to remove""" + with self._lock: + tracked_node = self._nodes[host] + tracked_node.remove(tracking_id) + + return tracked_node + + def _create_sub_heap( + self, + hosts: t.Optional[t.List[str]] = None, + filter_on: t.Optional[PrioritizerFilter] = None, + ) -> t.List[_TrackedNode]: + """Create a new heap from the primary heap with user-specified nodes + + :param hosts: a list of hostnames used to filter the available nodes + :returns: a list of assigned reference counters + """ + nodes_tracking_info: t.List[_TrackedNode] = [] + heap = self._get_filtered_heap(filter_on) + + # Collect all the tracking info for the requested nodes... + for node in heap: + if not hosts or node.hostname in hosts: + nodes_tracking_info.append(node) + + # ... 
and use it to create a new heap from a specified subset of nodes + heapq.heapify(nodes_tracking_info) + + return nodes_tracking_info + + def unassigned( + self, heap: t.Optional[t.List[_TrackedNode]] = None + ) -> t.Sequence[Node]: + """Select nodes that are currently not assigned a task + + :param heap: a subset of the node heap to consider + :returns: a list of reference counts for all unassigned nodes""" + if heap is None: + heap = list(self._nodes.values()) + + nodes: t.List[_TrackedNode] = [] + for item in heap: + if item.num_refs == 0: + nodes.append(item) + return nodes + + def assigned( + self, heap: t.Optional[t.List[_TrackedNode]] = None + ) -> t.Sequence[Node]: + """Helper method to identify the nodes that are currently assigned + + :param heap: a subset of the node heap to consider + :returns: a list of reference counts for all assigned nodes""" + if heap is None: + heap = list(self._nodes.values()) + + nodes: t.List[_TrackedNode] = [] + for item in heap: + if item.num_refs > 0: + nodes.append(item) + return nodes + + def _check_satisfiable_n( + self, num_items: int, heap: t.Optional[t.List[_TrackedNode]] = None + ) -> bool: + """Validates that a request for some number of nodes `n` can be + satisfied by the prioritizer given the set of nodes available + + :param num_items: the desired number of nodes to allocate + :param heap: a subset of the node heap to consider + :returns: True if the request can be fulfilled, False otherwise""" + num_nodes = len(self._nodes.keys()) + + if num_items < 1: + msg = "Cannot handle request; request requires a positive integer" + logger.warning(msg) + return False + + if num_nodes < num_items: + msg = f"Cannot satisfy request for {num_items} nodes; {num_nodes} in pool" + logger.warning(msg) + return False + + num_open = len(self.unassigned(heap)) + if num_open < num_items: + msg = f"Cannot satisfy request for {num_items} nodes; {num_open} available" + logger.warning(msg) + return False + + return True + + def 
_get_next_unassigned_node( + self, + heap: t.List[_TrackedNode], + tracking_id: t.Optional[str] = None, + ) -> t.Optional[Node]: + """Finds the next node with no running processes and + ensures that any elements that were directly updated are updated in + the priority structure before being made available + + :param heap: a subset of the node heap to consider + :param tracking_id: unique task identifier to track + :returns: a reference counter for an available node if an unassigned node + exists, `None` otherwise""" + tracking_info: t.Optional[_TrackedNode] = None + + with self._lock: + # re-sort the heap to handle any tracking changes + if any(node.is_dirty for node in heap): + heapq.heapify(heap) + + # grab the min node from the heap + tracking_info = heapq.heappop(heap) + + # the node is available if it has no assigned tasks + is_assigned = tracking_info.is_assigned + if not is_assigned: + # track the new process on the node + tracking_info.add(tracking_id) + + # add the node that was popped back into the heap + heapq.heappush(heap, tracking_info) + + # mark all nodes as clean now that everything is updated & sorted + for node in heap: + node.clean() + + # next available must only return previously unassigned nodes + if is_assigned: + return None + + return tracking_info + + def _get_next_n_available_nodes( + self, + num_items: int, + heap: t.List[_TrackedNode], + tracking_id: t.Optional[str] = None, + ) -> t.List[Node]: + """Find the next N available nodes w/least amount of references using + the supplied filter to target a specific node capability + + :param num_items: number of nodes to reserve + :param heap: a subset of the node heap to consider + :param tracking_id: unique task identifier to track + :returns: a list of reference counters for a available nodes if enough + unassigned nodes exists, `None` otherwise + :raises ValueError: if the number of requested nodes is not a positive integer + """ + next_nodes: t.List[Node] = [] + + if num_items < 1: + 
raise ValueError(f"Number of items requested {num_items} is invalid") + + if not self._check_satisfiable_n(num_items, heap): + return next_nodes + + while len(next_nodes) < num_items: + if next_node := self._get_next_unassigned_node(heap, tracking_id): + next_nodes.append(next_node) + continue + break + + return next_nodes + + def _get_filtered_heap( + self, filter_on: t.Optional[PrioritizerFilter] = None + ) -> t.List[_TrackedNode]: + """Helper method to select the set of nodes to include in a filtered + heap. + + :param filter_on: A list of nodes that satisfy the filter. If no + filter is supplied, all nodes are returned""" + if filter_on == PrioritizerFilter.GPU: + return self._gpu_refs + if filter_on == PrioritizerFilter.CPU: + return self._cpu_refs + + return self._heapify_all_refs() + + def next( + self, + filter_on: t.Optional[PrioritizerFilter] = None, + tracking_id: t.Optional[str] = None, + hosts: t.Optional[t.List[str]] = None, + ) -> t.Optional[Node]: + """Find the next unsassigned node using the supplied filter to target + a specific node capability + + :param filter_on: the subset of nodes to query for available nodes + :param tracking_id: unique task identifier to track + :param hosts: a list of hostnames used to filter the available nodes + :returns: a reference counter for an available node if an unassigned node + exists, `None` otherwise""" + if results := self.next_n(1, filter_on, tracking_id, hosts): + return results[0] + return None + + def next_n( + self, + num_items: int = 1, + filter_on: t.Optional[PrioritizerFilter] = None, + tracking_id: t.Optional[str] = None, + hosts: t.Optional[t.List[str]] = None, + ) -> t.List[Node]: + """Find the next N available nodes w/least amount of references using + the supplied filter to target a specific node capability + + :param num_items: number of nodes to reserve + :param filter_on: the subset of nodes to query for available nodes + :param tracking_id: unique task identifier to track + :param hosts: a 
list of hostnames used to filter the available nodes + :returns: Collection of reserved nodes + :raises ValueError: if the hosts parameter is an empty list""" + heap = self._create_sub_heap(hosts, filter_on) + return self._get_next_n_available_nodes(num_items, heap, tracking_id) diff --git a/smartsim/_core/launcher/step/dragonStep.py b/smartsim/_core/launcher/step/dragonStep.py index dd93d7910c..8583ceeb1b 100644 --- a/smartsim/_core/launcher/step/dragonStep.py +++ b/smartsim/_core/launcher/step/dragonStep.py @@ -169,6 +169,7 @@ def _write_request_file(self) -> str: env = run_settings.env_vars nodes = int(run_args.get("nodes", None) or 1) tasks_per_node = int(run_args.get("tasks-per-node", None) or 1) + hosts_csv = run_args.get("host-list", None) policy = DragonRunPolicy.from_run_args(run_args) @@ -187,6 +188,7 @@ def _write_request_file(self) -> str: output_file=out, error_file=err, policy=policy, + hostlist=hosts_csv, ) requests.append(request_registry.to_string(request)) with open(request_file, "w", encoding="utf-8") as script_file: diff --git a/smartsim/_core/mli/__init__.py b/smartsim/_core/mli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/client/__init__.py b/smartsim/_core/mli/client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/client/protoclient.py b/smartsim/_core/mli/client/protoclient.py new file mode 100644 index 0000000000..46598a8171 --- /dev/null +++ b/smartsim/_core/mli/client/protoclient.py @@ -0,0 +1,348 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# isort: off +# pylint: disable=unused-import,import-error +import dragon +import dragon.channels +from dragon.globalservices.api_setup import connect_to_infrastructure + +try: + from mpi4py import MPI # type: ignore[import-not-found] +except Exception: + MPI = None + print("Unable to import `mpi4py` package") + +# isort: on +# pylint: enable=unused-import,import-error + +import numbers +import os +import time +import typing as t +from collections import OrderedDict + +import numpy +import torch + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from 
smartsim._core.mli.message_handler import MessageHandler +from smartsim._core.utils.timings import PerfTimer +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +_TimingDict = OrderedDict[str, list[str]] + + +logger = get_logger("App") +logger.info("Started app") +CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False + + +class ProtoClient: + """Proof of concept implementation of a client enabling user applications + to interact with MLI resources.""" + + _DEFAULT_BACKBONE_TIMEOUT = 1.0 + """A default timeout period applied to connection attempts with the + backbone feature store.""" + + _DEFAULT_WORK_QUEUE_SIZE = 500 + """A default number of events to be buffered in the work queue before + triggering QueueFull exceptions.""" + + _EVENT_SOURCE = "proto-client" + """A user-friendly name for this class instance to identify + the client as the publisher of an event.""" + + @staticmethod + def _attach_to_backbone() -> BackboneFeatureStore: + """Use the supplied environment variables to attach + to a pre-existing backbone featurestore. Requires the + environment to contain `_SMARTSIM_INFRA_BACKBONE` + environment variable. + + :returns: The attached backbone featurestore + :raises SmartSimError: If the backbone descriptor is not contained + in the appropriate environment variable + """ + descriptor = os.environ.get(BackboneFeatureStore.MLI_BACKBONE, None) + if descriptor is None or not descriptor: + raise SmartSimError( + "Missing required backbone configuration in environment: " + f"{BackboneFeatureStore.MLI_BACKBONE}" + ) + + backbone = t.cast( + BackboneFeatureStore, BackboneFeatureStore.from_descriptor(descriptor) + ) + return backbone + + def _attach_to_worker_queue(self) -> DragonFLIChannel: + """Wait until the backbone contains the worker queue configuration, + then attach an FLI to the given worker queue. 
+ + :returns: The attached FLI channel + :raises SmartSimError: if the required configuration is not found in the + backbone feature store + """ + + descriptor = "" + try: + # NOTE: without wait_for, this MUST be in the backbone.... + config = self._backbone.wait_for( + [BackboneFeatureStore.MLI_WORKER_QUEUE], self.backbone_timeout + ) + descriptor = str(config[BackboneFeatureStore.MLI_WORKER_QUEUE]) + except Exception as ex: + logger.info( + f"Unable to retrieve {BackboneFeatureStore.MLI_WORKER_QUEUE} " + "to attach to the worker queue." + ) + raise SmartSimError("Unable to locate worker queue using backbone") from ex + + return DragonFLIChannel.from_descriptor(descriptor) + + def _create_broadcaster(self) -> EventBroadcaster: + """Create an EventBroadcaster that broadcasts events to + all MLI components registered to consume them. + + :returns: An EventBroadcaster instance + """ + broadcaster = EventBroadcaster( + self._backbone, DragonCommChannel.from_descriptor + ) + return broadcaster + + def __init__( + self, + timing_on: bool, + backbone_timeout: float = _DEFAULT_BACKBONE_TIMEOUT, + ) -> None: + """Initialize the client instance. + + :param timing_on: Flag indicating if timing information should be + written to file + :param backbone_timeout: Maximum wait time (in seconds) allowed to attach to the + worker queue + :raises SmartSimError: If unable to attach to a backbone featurestore + :raises ValueError: If an invalid backbone timeout is specified + """ + if MPI is not None: + # TODO: determine a way to make MPI work in the test environment + # - consider catching the import exception and defaulting rank to 0 + comm = MPI.COMM_WORLD + rank: int = comm.Get_rank() + else: + rank = 0 + + if backbone_timeout <= 0: + raise ValueError( + f"Invalid backbone timeout provided: {backbone_timeout}. " + "The value must be greater than zero." 
+ ) + self._backbone_timeout = max(backbone_timeout, 0.1) + + connect_to_infrastructure() + + self._backbone = self._attach_to_backbone() + self._backbone.wait_timeout = self.backbone_timeout + self._to_worker_fli = self._attach_to_worker_queue() + + self._from_worker_ch = create_local(self._DEFAULT_WORK_QUEUE_SIZE) + self._to_worker_ch = create_local(self._DEFAULT_WORK_QUEUE_SIZE) + + self._publisher = self._create_broadcaster() + + self.perf_timer: PerfTimer = PerfTimer( + debug=False, timing_on=timing_on, prefix=f"a{rank}_" + ) + self._start: t.Optional[float] = None + self._interm: t.Optional[float] = None + self._timings: _TimingDict = OrderedDict() + self._timing_on = timing_on + + @property + def backbone_timeout(self) -> float: + """The timeout (in seconds) applied to retrievals + from the backbone feature store. + + :returns: A float indicating the number of seconds to allow""" + return self._backbone_timeout + + def _add_label_to_timings(self, label: str) -> None: + """Adds a new label into the timing dictionary to prepare for + receiving timing events. + + :param label: The label to create storage for + """ + if label not in self._timings: + self._timings[label] = [] + + @staticmethod + def _format_number(number: t.Union[numbers.Number, float]) -> str: + """Utility function for formatting numbers consistently for logs. + + :param number: The number to convert to a formatted string + :returns: The formatted string containing the number + """ + return f"{number:0.4e}" + + def start_timings(self, batch_size: numbers.Number) -> None: + """Configure the client to begin storing timing information. 
+ + :param batch_size: The size of batches to generate as inputs + to the model + """ + if self._timing_on: + self._add_label_to_timings("batch_size") + self._timings["batch_size"].append(self._format_number(batch_size)) + self._start = time.perf_counter() + self._interm = time.perf_counter() + + def end_timings(self) -> None: + """Configure the client to stop storing timing information.""" + if self._timing_on and self._start is not None: + self._add_label_to_timings("total_time") + self._timings["total_time"].append( + self._format_number(time.perf_counter() - self._start) + ) + + def measure_time(self, label: str) -> None: + """Measures elapsed time since the last recorded signal. + + :param label: The label to measure time for + """ + if self._timing_on and self._interm is not None: + self._add_label_to_timings(label) + self._timings[label].append( + self._format_number(time.perf_counter() - self._interm) + ) + self._interm = time.perf_counter() + + def print_timings(self, to_file: bool = False) -> None: + """Print timing information to standard output. If `to_file` + is `True`, also write results to a file. + + :param to_file: If `True`, also saves timing information + to the files `timings.npy` and `timings.txt` + """ + print(" ".join(self._timings.keys())) + + value_array = numpy.array(self._timings.values(), dtype=float) + value_array = numpy.transpose(value_array) + for i in range(value_array.shape[0]): + print(" ".join(self._format_number(value) for value in value_array[i])) + if to_file: + numpy.save("timings.npy", value_array) + numpy.savetxt("timings.txt", value_array) + + def run_model(self, model: t.Union[bytes, str], batch: torch.Tensor) -> t.Any: + """Execute a batch of inference requests with the supplied ML model. 
+ + :param model: The raw bytes or path to a pytorch model + :param batch: The tensor batch to perform inference on + :returns: The inference results + :raises ValueError: if the worker queue is not configured properly + in the environment variables + """ + tensors = [batch.numpy()] + self.perf_timer.start_timings("batch_size", batch.shape[0]) + built_tensor_desc = MessageHandler.build_tensor_descriptor( + "c", "float32", list(batch.shape) + ) + self.perf_timer.measure_time("build_tensor_descriptor") + if isinstance(model, str): + model_arg = MessageHandler.build_model_key(model, self._backbone.descriptor) + else: + model_arg = MessageHandler.build_model( + model, "resnet-50", "1.0" + ) # type: ignore + request = MessageHandler.build_request( + reply_channel=self._from_worker_ch.descriptor, + model=model_arg, + inputs=[built_tensor_desc], + outputs=[], + output_descriptors=[], + custom_attributes=None, + ) + self.perf_timer.measure_time("build_request") + request_bytes = MessageHandler.serialize_request(request) + self.perf_timer.measure_time("serialize_request") + + if self._to_worker_fli is None: + raise ValueError("No worker queue available.") + + # pylint: disable-next=protected-access + with self._to_worker_fli._channel.sendh( # type: ignore + timeout=None, + stream_channel=self._to_worker_ch.channel, + ) as to_sendh: + to_sendh.send_bytes(request_bytes) + self.perf_timer.measure_time("send_request") + for tensor in tensors: + to_sendh.send_bytes(tensor.tobytes()) # TODO NOT FAST ENOUGH!!! + logger.info(f"Message size: {len(request_bytes)} bytes") + + self.perf_timer.measure_time("send_tensors") + with self._from_worker_ch.channel.recvh(timeout=None) as from_recvh: + resp = from_recvh.recv_bytes(timeout=None) + self.perf_timer.measure_time("receive_response") + response = MessageHandler.deserialize_response(resp) + self.perf_timer.measure_time("deserialize_response") + + # recv depending on the len(response.result.descriptors)? 
+ data_blob: bytes = from_recvh.recv_bytes(timeout=None) + self.perf_timer.measure_time("receive_tensor") + result = torch.from_numpy( + numpy.frombuffer( + data_blob, + dtype=str(response.result.descriptors[0].dataType), + ) + ) + self.perf_timer.measure_time("deserialize_tensor") + + self.perf_timer.end_timings() + return result + + def set_model(self, key: str, model: bytes) -> None: + """Write the supplied model to the feature store. + + :param key: The unique key used to identify the model + :param model: The raw bytes of the model to execute + """ + self._backbone[key] = model + + # notify components of a change in the data at this key + event = OnWriteFeatureStore(self._EVENT_SOURCE, self._backbone.descriptor, key) + self._publisher.send(event) diff --git a/smartsim/_core/mli/comm/channel/__init__.py b/smartsim/_core/mli/comm/channel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/comm/channel/channel.py b/smartsim/_core/mli/comm/channel/channel.py new file mode 100644 index 0000000000..104333ce7f --- /dev/null +++ b/smartsim/_core/mli/comm/channel/channel.py @@ -0,0 +1,82 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import typing as t +import uuid +from abc import ABC, abstractmethod + +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class CommChannelBase(ABC): + """Base class for abstracting a message passing mechanism""" + + def __init__( + self, + descriptor: str, + name: t.Optional[str] = None, + ) -> None: + """Initialize the CommChannel instance. + + :param descriptor: Channel descriptor + """ + self._descriptor = descriptor + """An opaque identifier used to connect to an underlying communication channel""" + self._name = name or str(uuid.uuid4()) + """A user-friendly identifier for channel-related logging""" + + @abstractmethod + def send(self, value: bytes, timeout: float = 0.001) -> None: + """Send a message through the underlying communication channel. + + :param value: The value to send + :param timeout: Maximum time to wait (in seconds) for messages to send + :raises SmartSimError: If sending message fails + """ + + @abstractmethod + def recv(self, timeout: float = 0.001) -> t.List[bytes]: + """Receives message(s) through the underlying communication channel. 
+
+        :param timeout: Maximum time to wait (in seconds) for messages to arrive
+        :returns: The received message(s)
+        """
+
+    @property
+    def descriptor(self) -> str:
+        """Return the channel descriptor for the underlying dragon channel.
+
+        :returns: The string-encoded channel descriptor
+        """
+        return self._descriptor
+
+    def __str__(self) -> str:
+        """Build a string representation of the channel useful for printing."""
+        # type(self) is already the class; appending .__class__ would name the
+        # metaclass ("type") instead of this channel class
+        classname = type(self).__name__
+        return f"{classname}('{self._name}', '{self._descriptor}')"
diff --git a/smartsim/_core/mli/comm/channel/dragon_channel.py b/smartsim/_core/mli/comm/channel/dragon_channel.py
new file mode 100644
index 0000000000..110f19258a
--- /dev/null
+++ b/smartsim/_core/mli/comm/channel/dragon_channel.py
@@ -0,0 +1,127 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +import dragon.channels as dch + +import smartsim._core.mli.comm.channel.channel as cch +import smartsim._core.mli.comm.channel.dragon_util as drg_util +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class DragonCommChannel(cch.CommChannelBase): + """Passes messages by writing to a Dragon channel.""" + + def __init__(self, channel: "dch.Channel") -> None: + """Initialize the DragonCommChannel instance. + + :param channel: A channel to use for communications + """ + descriptor = drg_util.channel_to_descriptor(channel) + super().__init__(descriptor) + self._channel = channel + """The underlying dragon channel used by this CommChannel for communications""" + + @property + def channel(self) -> "dch.Channel": + """The underlying communication channel. + + :returns: The channel + """ + return self._channel + + def send(self, value: bytes, timeout: float = 0.001) -> None: + """Send a message through the underlying communication channel. 
+ + :param value: The value to send + :param timeout: Maximum time to wait (in seconds) for messages to send + :raises SmartSimError: If sending message fails + """ + try: + with self._channel.sendh(timeout=timeout) as sendh: + sendh.send_bytes(value, blocking=False) + logger.debug(f"DragonCommChannel {self.descriptor} sent message") + except Exception as e: + raise SmartSimError( + f"Error sending via DragonCommChannel {self.descriptor}" + ) from e + + def recv(self, timeout: float = 0.001) -> t.List[bytes]: + """Receives message(s) through the underlying communication channel. + + :param timeout: Maximum time to wait (in seconds) for messages to arrive + :returns: The received message(s) + """ + with self._channel.recvh(timeout=timeout) as recvh: + messages: t.List[bytes] = [] + + try: + message_bytes = recvh.recv_bytes(timeout=timeout) + messages.append(message_bytes) + logger.debug(f"DragonCommChannel {self.descriptor} received message") + except dch.ChannelEmpty: + # emptied the queue, ok to swallow this ex + logger.debug(f"DragonCommChannel exhausted: {self.descriptor}") + except dch.ChannelRecvTimeout: + logger.debug(f"Timeout exceeded on channel.recv: {self.descriptor}") + + return messages + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "DragonCommChannel": + """A factory method that creates an instance from a descriptor string. + + :param descriptor: The descriptor that uniquely identifies the resource. + :returns: An attached DragonCommChannel + :raises SmartSimError: If creation of comm channel fails + """ + try: + channel = drg_util.descriptor_to_channel(descriptor) + return DragonCommChannel(channel) + except Exception as ex: + raise SmartSimError( + f"Failed to create dragon comm channel: {descriptor}" + ) from ex + + @classmethod + def from_local(cls, _descriptor: t.Optional[str] = None) -> "DragonCommChannel": + """A factory method that creates a local channel instance. 
+
+        :param _descriptor: Unused placeholder
+        :returns: An attached DragonCommChannel"""
+        try:
+            channel = drg_util.create_local()
+            return DragonCommChannel(channel)
+        except Exception:
+            logger.error("Failed to create local dragon comm channel", exc_info=True)
+            raise
diff --git a/smartsim/_core/mli/comm/channel/dragon_fli.py b/smartsim/_core/mli/comm/channel/dragon_fli.py
new file mode 100644
index 0000000000..5fb0790a84
--- /dev/null
+++ b/smartsim/_core/mli/comm/channel/dragon_fli.py
@@ -0,0 +1,158 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# isort: off
+from dragon import fli
+from dragon.channels import Channel
+
+# isort: on
+
+import typing as t
+
+import smartsim._core.mli.comm.channel.channel as cch
+import smartsim._core.mli.comm.channel.dragon_util as drg_util
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class DragonFLIChannel(cch.CommChannelBase):
+    """Passes messages by writing to a Dragon FLI Channel."""
+
+    def __init__(
+        self,
+        fli_: fli.FLInterface,
+        buffer_size: int = drg_util.DEFAULT_CHANNEL_BUFFER_SIZE,
+    ) -> None:
+        """Initialize the DragonFLIChannel instance.
+
+        :param fli_: The FLInterface to use as the underlying communications channel
+        :param buffer_size: Maximum number of sent messages that can be
+            buffered before sending
+        """
+        descriptor = drg_util.channel_to_descriptor(fli_)
+        super().__init__(descriptor)
+
+        self._channel: t.Optional["Channel"] = None
+        """The underlying dragon Channel used by a sender-side DragonFLIChannel
+        to attach to the main FLI channel"""
+
+        self._fli = fli_
+        """The underlying dragon FLInterface used by this CommChannel for communications"""
+        self._buffer_size: int = buffer_size
+        """Maximum number of messages that can be buffered before sending"""
+
+    def send(self, value: bytes, timeout: float = 0.001) -> None:
+        """Send a message through the underlying communication channel.
+
+        :param value: The value to send
+        :param timeout: Maximum time to wait (in seconds) for messages to send
+        :raises SmartSimError: If sending message fails
+        """
+        try:
+            if self._channel is None:
+                self._channel = drg_util.create_local(self._buffer_size)
+
+            with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
+                sendh.send_bytes(value, timeout=timeout)
+                logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
+        except Exception as e:
+            self._channel = None
+            raise SmartSimError(
+                f"Error sending via DragonFLIChannel {self.descriptor}"
+            ) from e
+
+    def send_multiple(
+        self,
+        values: t.Sequence[bytes],
+        timeout: float = 0.001,
+    ) -> None:
+        """Send multiple messages through the underlying communication channel.
+
+        :param values: The values to send
+        :param timeout: Maximum time to wait (in seconds) for messages to send
+        :raises SmartSimError: If sending message fails
+        """
+        try:
+            if self._channel is None:
+                self._channel = drg_util.create_local(self._buffer_size)
+
+            with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
+                for value in values:
+                    sendh.send_bytes(value)
+                logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
+        except Exception as e:
+            self._channel = None
+            raise SmartSimError(
+                f"Error sending via DragonFLIChannel {self.descriptor} {e}"
+            ) from e
+
+    def recv(self, timeout: float = 0.001) -> t.List[bytes]:
+        """Receives message(s) through the underlying communication channel.
+ + :param timeout: Maximum time to wait (in seconds) for messages to arrive + :returns: The received message(s) + :raises SmartSimError: If receiving message(s) fails + """ + messages = [] + eot = False + with self._fli.recvh(timeout=timeout) as recvh: + while not eot: + try: + message, _ = recvh.recv_bytes(timeout=timeout) + messages.append(message) + logger.debug(f"DragonFLIChannel {self.descriptor} received message") + except fli.FLIEOT: + eot = True + logger.debug(f"DragonFLIChannel exhausted: {self.descriptor}") + except Exception as e: + raise SmartSimError( + f"Error receiving messages: DragonFLIChannel {self.descriptor}" + ) from e + return messages + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "DragonFLIChannel": + """A factory method that creates an instance from a descriptor string. + + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached DragonFLIChannel + :raises SmartSimError: If creation of DragonFLIChannel fails + :raises ValueError: If the descriptor is invalid + """ + if not descriptor: + raise ValueError("Invalid descriptor provided") + + try: + return DragonFLIChannel(fli_=drg_util.descriptor_to_fli(descriptor)) + except Exception as e: + raise SmartSimError( + f"Error while creating DragonFLIChannel: {descriptor}" + ) from e diff --git a/smartsim/_core/mli/comm/channel/dragon_util.py b/smartsim/_core/mli/comm/channel/dragon_util.py new file mode 100644 index 0000000000..8517979ec4 --- /dev/null +++ b/smartsim/_core/mli/comm/channel/dragon_util.py @@ -0,0 +1,131 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import binascii +import typing as t + +import dragon.channels as dch +import dragon.fli as fli +import dragon.managed_memory as dm + +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + +DEFAULT_CHANNEL_BUFFER_SIZE = 500 +"""Maximum number of messages that can be buffered. DragonCommChannel will +raise an exception if no clients consume messages before the buffer is filled.""" + +LAST_OFFSET = 0 +"""The last offset used to create a local channel. This is used to avoid +unnecessary retries when creating a local channel.""" + + +def channel_to_descriptor(channel: t.Union[dch.Channel, fli.FLInterface]) -> str: + """Convert a dragon channel to a descriptor string. 
+ + :param channel: The dragon channel to convert + :returns: The descriptor string + :raises ValueError: If a dragon channel is not provided + """ + if channel is None: + raise ValueError("Channel is not available to create a descriptor") + + serialized_ch = channel.serialize() + return base64.b64encode(serialized_ch).decode("utf-8") + + +def pool_to_descriptor(pool: dm.MemoryPool) -> str: + """Convert a dragon memory pool to a descriptor string. + + :param pool: The memory pool to convert + :returns: The descriptor string + :raises ValueError: If a memory pool is not provided + """ + if pool is None: + raise ValueError("Memory pool is not available to create a descriptor") + + serialized_pool = pool.serialize() + return base64.b64encode(serialized_pool).decode("utf-8") + + +def descriptor_to_fli(descriptor: str) -> "fli.FLInterface": + """Create and attach a new FLI instance given + the string-encoded descriptor. + + :param descriptor: The descriptor of an FLI to attach to + :returns: The attached dragon FLI + :raises ValueError: If the descriptor is empty or incorrectly formatted + :raises SmartSimError: If attachment using the descriptor fails + """ + if len(descriptor) < 1: + raise ValueError("Descriptors may not be empty") + + try: + encoded = descriptor.encode("utf-8") + descriptor_ = base64.b64decode(encoded) + return fli.FLInterface.attach(descriptor_) + except binascii.Error: + raise ValueError("The descriptor was not properly base64 encoded") + except fli.DragonFLIError: + raise SmartSimError("The descriptor did not address an available FLI") + + +def descriptor_to_channel(descriptor: str) -> dch.Channel: + """Create and attach a new Channel instance given + the string-encoded descriptor. 
+ + :param descriptor: The descriptor of a channel to attach to + :returns: The attached dragon Channel + :raises ValueError: If the descriptor is empty or incorrectly formatted + :raises SmartSimError: If attachment using the descriptor fails + """ + if len(descriptor) < 1: + raise ValueError("Descriptors may not be empty") + + try: + encoded = descriptor.encode("utf-8") + descriptor_ = base64.b64decode(encoded) + return dch.Channel.attach(descriptor_) + except binascii.Error: + raise ValueError("The descriptor was not properly base64 encoded") + except dch.ChannelError: + raise SmartSimError("The descriptor did not address an available channel") + + +def create_local(_capacity: int = 0) -> dch.Channel: + """Creates a Channel attached to the local memory pool. Replacement for + direct calls to `dch.Channel.make_process_local()` to enable + supplying a channel capacity. + + :param _capacity: The number of events the channel can buffer; uses the default + buffer size `DEFAULT_CHANNEL_BUFFER_SIZE` when not supplied + :returns: The instantiated channel + """ + channel = dch.Channel.make_process_local() + return channel diff --git a/smartsim/_core/mli/infrastructure/__init__.py b/smartsim/_core/mli/infrastructure/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/infrastructure/comm/__init__.py b/smartsim/_core/mli/infrastructure/comm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/infrastructure/comm/broadcaster.py b/smartsim/_core/mli/infrastructure/comm/broadcaster.py new file mode 100644 index 0000000000..56dcf549f7 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/comm/broadcaster.py @@ -0,0 +1,239 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import typing as t +import uuid +from collections import defaultdict, deque + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim._core.mli.infrastructure.comm.event import EventBase +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class BroadcastResult(t.NamedTuple): + """Contains summary details about a broadcast.""" + + num_sent: int + """The total number of messages delivered across all consumers""" + num_failed: int + """The total number of messages not delivered across all consumers""" + + +class EventBroadcaster: + """Performs fan-out publishing of system events.""" + + def __init__( + self, + backbone: BackboneFeatureStore, + channel_factory: t.Optional[t.Callable[[str], CommChannelBase]] = None, + name: t.Optional[str] = None, + ) -> None: + """Initialize the EventPublisher instance. + + :param backbone: The MLI backbone feature store + :param channel_factory: Factory method to construct new channel instances + :param name: A user-friendly name for logging. If not provided, an + auto-generated GUID will be used + """ + self._backbone = backbone + """The backbone feature store used to retrieve consumer descriptors""" + self._channel_factory = channel_factory + """A factory method used to instantiate channels from descriptors""" + self._channel_cache: t.Dict[str, t.Optional[CommChannelBase]] = defaultdict( + lambda: None + ) + """A mapping of instantiated channels that can be re-used. Automatically + calls the channel factory if a descriptor is not already in the collection""" + self._event_buffer: t.Deque[EventBase] = deque() + """A buffer for storing events when a consumer list is not found""" + self._descriptors: t.Set[str] + """Stores the most recent list of broadcast consumers. 
Updated automatically + on each broadcast""" + self._name = name or str(uuid.uuid4()) + """A unique identifer assigned to the broadcaster for logging""" + + @property + def name(self) -> str: + """The friendly name assigned to the broadcaster. + + :returns: The broadcaster name if one is assigned, otherwise a unique + id assigned by the system. + """ + return self._name + + @property + def num_buffered(self) -> int: + """Return the number of events currently buffered to send. + + :returns: Number of buffered events + """ + return len(self._event_buffer) + + def _save_to_buffer(self, event: EventBase) -> None: + """Places the event in the buffer to be sent once a consumer + list is available. + + :param event: The event to buffer + :raises ValueError: If the event cannot be buffered + """ + try: + self._event_buffer.append(event) + logger.debug(f"Buffered event {event=}") + except Exception as ex: + raise ValueError( + f"Unable to buffer event {event} in broadcaster {self.name}" + ) from ex + + def _log_broadcast_start(self) -> None: + """Logs broadcast statistics.""" + num_events = len(self._event_buffer) + num_copies = len(self._descriptors) + logger.debug( + f"Broadcast {num_events} events to {num_copies} consumers from {self.name}" + ) + + def _prune_unused_consumers(self) -> None: + """Performs maintenance on the channel cache by pruning any channel + that has been removed from the consumers list.""" + active_consumers = set(self._descriptors) + current_channels = set(self._channel_cache.keys()) + + # find any cached channels that are now unused + inactive_channels = current_channels.difference(active_consumers) + new_channels = active_consumers.difference(current_channels) + + for descriptor in inactive_channels: + self._channel_cache.pop(descriptor) + + logger.debug( + f"Pruning {len(inactive_channels)} stale consumers and" + f" found {len(new_channels)} new channels for {self.name}" + ) + + def _get_comm_channel(self, descriptor: str) -> CommChannelBase: + 
"""Helper method to build and cache a comm channel. + + :param descriptor: The descriptor to pass to the channel factory + :returns: The instantiated channel + :raises SmartSimError: If the channel fails to attach + """ + comm_channel = self._channel_cache[descriptor] + if comm_channel is not None: + return comm_channel + + if self._channel_factory is None: + raise SmartSimError("No channel factory provided for consumers") + + try: + channel = self._channel_factory(descriptor) + self._channel_cache[descriptor] = channel + return channel + except Exception as ex: + msg = f"Unable to construct channel with descriptor: {descriptor}" + logger.error(msg, exc_info=True) + raise SmartSimError(msg) from ex + + def _get_next_event(self) -> t.Optional[EventBase]: + """Pop the next event to be sent from the queue. + + :returns: The next event to send if any events are enqueued, otherwise `None`. + """ + try: + return self._event_buffer.popleft() + except IndexError: + logger.debug(f"Broadcast buffer exhausted for {self.name}") + + return None + + def _broadcast(self, timeout: float = 0.001) -> BroadcastResult: + """Broadcasts all buffered events to registered event consumers. 
+ + :param timeout: Maximum time to wait (in seconds) for messages to send + :returns: BroadcastResult containing the number of messages that were + successfully and unsuccessfully sent for all consumers + :raises SmartSimError: If the channel fails to attach or broadcasting fails + """ + # allow descriptors to be empty since events are buffered + self._descriptors = set(x for x in self._backbone.notification_channels if x) + if not self._descriptors: + msg = f"No event consumers are registered for {self.name}" + logger.warning(msg) + return BroadcastResult(0, 0) + + self._prune_unused_consumers() + self._log_broadcast_start() + + num_listeners = len(self._descriptors) + num_sent = 0 + num_failures = 0 + + # send each event to every consumer + while event := self._get_next_event(): + logger.debug(f"Broadcasting {event=} to {num_listeners} listeners") + event_bytes = bytes(event) + + for i, descriptor in enumerate(self._descriptors): + comm_channel = self._get_comm_channel(descriptor) + + try: + comm_channel.send(event_bytes, timeout) + num_sent += 1 + except Exception: + msg = ( + f"Broadcast {i+1}/{num_listeners} for event {event.uid} to " + f"channel {descriptor} from {self.name} failed." + ) + logger.exception(msg) + num_failures += 1 + + return BroadcastResult(num_sent, num_failures) + + def send(self, event: EventBase, timeout: float = 0.001) -> int: + """Implementation of `send` method of the `EventPublisher` protocol. Publishes + the supplied event to all registered broadcast consumers. 
+ + :param event: An event to publish + :param timeout: Maximum time to wait (in seconds) for messages to send + :returns: The total number of events successfully published to consumers + :raises ValueError: If event serialization fails + :raises AttributeError: If event cannot be serialized + :raises KeyError: If channel fails to attach using registered descriptors + :raises SmartSimError: If any unexpected error occurs during send + """ + try: + self._save_to_buffer(event) + result = self._broadcast(timeout) + return result.num_sent + except (KeyError, ValueError, AttributeError, SmartSimError): + raise + except Exception as ex: + raise SmartSimError("An unexpected failure occurred while sending") from ex diff --git a/smartsim/_core/mli/infrastructure/comm/consumer.py b/smartsim/_core/mli/infrastructure/comm/consumer.py new file mode 100644 index 0000000000..08b5c47852 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/comm/consumer.py @@ -0,0 +1,281 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
class EventConsumer:
    """Reads system events published to a communications channel."""

    _BACKBONE_WAIT_TIMEOUT = 10.0
    """Maximum time (in seconds) to wait for the backbone to register the consumer"""

    def __init__(
        self,
        comm_channel: CommChannelBase,
        backbone: BackboneFeatureStore,
        filters: t.Optional[t.List[str]] = None,
        name: t.Optional[str] = None,
        event_handler: t.Optional[t.Callable[[EventBase], None]] = None,
    ) -> None:
        """Initialize the EventConsumer instance.

        :param comm_channel: Communications channel to listen to for events
        :param backbone: The MLI backbone feature store
        :param filters: A list of event types to deliver. when empty, all
        events will be delivered
        :param name: A user-friendly name for logging. If not provided, an
        auto-generated GUID will be used
        :param event_handler: An optional callback invoked by `listen_once` for
        each event that passes the configured filters
        """
        self._comm_channel = comm_channel
        """The comm channel used by the consumer to receive messages. The channel
        descriptor will be published for senders to discover."""
        self._backbone = backbone
        """The backbone instance used to bootstrap the instance. The EventConsumer
        uses the backbone to discover where it can publish its descriptor."""
        self._global_filters = filters or []
        """A set of global filters to apply to incoming events. Global filters are
        combined with per-call filters. Filters act as an allow-list."""
        self._name = name or str(uuid.uuid4())
        """User-friendly name assigned to a consumer for logging. Automatically
        assigned if not provided."""
        self._event_handler = event_handler
        """The function that should be executed when an event
        passed by the filters is received."""
        self.listening = True
        """Flag indicating that the consumer is currently listening for new
        events. Setting this flag to `False` will cause any active calls to
        `listen` to terminate."""

    @property
    def descriptor(self) -> str:
        """The descriptor of the underlying comm channel.

        :returns: The comm channel descriptor"""
        return self._comm_channel.descriptor

    @property
    def name(self) -> str:
        """The friendly name assigned to the consumer.

        :returns: The consumer name if one is assigned, otherwise a unique
        id assigned by the system.
        """
        return self._name

    def recv(
        self,
        filters: t.Optional[t.List[str]] = None,
        timeout: float = 0.001,
        batch_timeout: float = 1.0,
    ) -> t.List[EventBase]:
        """Receives available published event(s).

        :param filters: Additional filters to add to the global filters configured
        on the EventConsumer instance
        :param timeout: Maximum time to wait for a single message to arrive
        :param batch_timeout: Maximum time to wait for messages to arrive; allows
        multiple batches to be retrieved in one call to `recv`
        :returns: A list of events that pass any configured filters
        :raises ValueError: If a positive, non-zero value is not provided for the
        timeout or batch_timeout.
        """
        if filters is None:
            filters = []

        if timeout is not None and timeout <= 0:
            raise ValueError("request timeout must be a non-zero, positive value")

        if batch_timeout is not None and batch_timeout <= 0:
            raise ValueError("batch_timeout must be a non-zero, positive value")

        filter_set = {*self._global_filters, *filters}
        all_message_bytes: t.List[bytes] = []

        # firehose as many messages as possible within the batch_timeout
        start_at = time.time()

        batch_message_bytes = self._comm_channel.recv(timeout=timeout)
        while batch_message_bytes:
            all_message_bytes.extend(batch_message_bytes)
            batch_message_bytes = []

            # avoid getting stuck indefinitely waiting for the channel
            remaining = batch_timeout - (time.time() - start_at)
            if remaining > 0:
                batch_message_bytes = self._comm_channel.recv(timeout=timeout)

        events_received: t.List[EventBase] = []

        for message in all_message_bytes:
            # remove any empty messages that will fail to decode
            if not message:
                continue

            # NOTE(review): messages are assumed to be internally produced
            # pickles; never feed this channel data from untrusted sources
            event = pickle.loads(message)
            if not event:
                logger.warning(f"Consumer {self.name} is unable to unpickle message")
                continue

            # skip events that don't pass a filter
            if filter_set and event.category not in filter_set:
                continue

            events_received.append(event)

        return events_received

    def _send_to_registrar(self, event: EventBase) -> None:
        """Send an event direct to the registrar listener.

        :param event: The event to deliver to the registrar
        """
        registrar_key = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER
        config = self._backbone.wait_for([registrar_key], self._BACKBONE_WAIT_TIMEOUT)

        # NOTE: do not wrap in `str(...)` before the truthiness check; a missing
        # value would become the truthy string "None" and defeat the guard
        registrar_descriptor = config.get(registrar_key, None)

        if not registrar_descriptor:
            logger.warning(
                f"Unable to send {event.category} from {self.name}. "
                "No registrar channel found."
            )
            return

        logger.debug(f"Sending {event.category} from {self.name}")

        registrar_channel = DragonCommChannel.from_descriptor(str(registrar_descriptor))
        registrar_channel.send(bytes(event), timeout=1.0)

        logger.debug(f"{event.category} from {self.name} sent")

    def register(self) -> None:
        """Send an event to register this consumer as a listener."""
        descriptor = self._comm_channel.descriptor
        event = OnCreateConsumer(self.name, descriptor, self._global_filters)

        self._send_to_registrar(event)

    def unregister(self) -> None:
        """Send an event to un-register this consumer as a listener."""
        descriptor = self._comm_channel.descriptor
        event = OnRemoveConsumer(self.name, descriptor)

        self._send_to_registrar(event)

    def _on_handler_missing(self, event: EventBase) -> None:
        """A "dead letter" event handler that is called to perform
        processing on events before they're discarded.

        :param event: The event to handle
        """
        logger.warning(
            "No event handler is registered in consumer "
            f"{self.name}. Discarding {event=}"
        )

    def listen_once(self, timeout: float = 0.001, batch_timeout: float = 1.0) -> None:
        """Receives messages for the consumer a single time. Delivers
        all messages that pass the consumer filters. Shutdown requests
        are handled by a default event handler.

        NOTE: Executes a single batch-retrieval to receive the maximum
        number of messages available under batch timeout. To continually
        listen, use `listen` in a non-blocking thread/process

        :param timeout: Maximum time to wait (in seconds) for a message to arrive
        :param batch_timeout: Maximum time to wait (in seconds) for a batch to arrive
        """
        logger.info(
            f"Consumer {self.name} listening with {timeout} second timeout"
            f" on channel {self._comm_channel.descriptor}"
        )

        if not self._event_handler:
            logger.info("Unable to handle messages. No event handler is registered.")

        incoming_messages = self.recv(timeout=timeout, batch_timeout=batch_timeout)

        if not incoming_messages:
            logger.info(f"Consumer {self.name} received empty message list")

        for message in incoming_messages:
            logger.info(f"Consumer {self.name} is handling event {message=}")
            self._handle_shutdown(message)

            if self._event_handler:
                self._event_handler(message)
            else:
                self._on_handler_missing(message)

    def _handle_shutdown(self, event: EventBase) -> bool:
        """Handles shutdown requests sent to the consumer by setting the
        `self.listening` attribute to `False`.

        :param event: The event to handle
        :returns: A bool indicating if the event was a shutdown request
        """
        if isinstance(event, OnShutdownRequested):
            logger.debug(f"Shutdown requested from: {event.source}")
            self.listening = False
            return True
        return False

    def listen(self, timeout: float = 0.001, batch_timeout: float = 1.0) -> None:
        """Receives messages for the consumer until a shutdown request is received.

        :param timeout: Maximum time to wait (in seconds) for a message to arrive
        :param batch_timeout: Maximum time to wait (in seconds) for a batch to arrive
        """

        logger.debug(f"Consumer {self.name} is now listening for events.")

        while self.listening:
            self.listen_once(timeout, batch_timeout)

        logger.debug(f"Consumer {self.name} is no longer listening.")
@dataclass
class EventBase:
    """Core API for an event."""

    category: str
    """Unique category name for an event class"""
    source: str
    """A unique identifier for the publisher of the event"""
    uid: str = field(default_factory=lambda: str(uuid.uuid4()))
    """A unique identifier for this event"""

    def __bytes__(self) -> bytes:
        """Default conversion to bytes for an event required to publish
        messages using byte-oriented communication channels.

        :returns: This entity encoded as bytes"""
        return pickle.dumps(self)

    def __str__(self) -> str:
        """Convert the event to a string.

        :returns: A string representation of this instance"""
        return f"{self.uid}|{self.category}"


class OnShutdownRequested(EventBase):
    """Publish this event to trigger the listener to shutdown."""

    SHUTDOWN: t.ClassVar[str] = "consumer-unregister"
    """Unique category name for an event raised to request a listener shutdown"""

    def __init__(self, source: str) -> None:
        """Initialize the event instance.

        :param source: A unique identifier for the publisher of the event
        creating the event
        """
        super().__init__(self.SHUTDOWN, source)


class OnCreateConsumer(EventBase):
    """Publish this event when a new event consumer registration is required."""

    # NOTE: these subclasses are not decorated with @dataclass, so class-level
    # `field(...)` defaults would leave a raw `dataclasses.Field` object as the
    # class attribute; attributes are therefore assigned only in __init__
    descriptor: str
    """Descriptor of the comm channel exposed by the consumer"""
    filters: t.List[str]
    """The collection of filters indicating messages of interest to this consumer"""

    CONSUMER_CREATED: t.ClassVar[str] = "consumer-created"
    """Unique category name for an event raised when a new consumer is registered"""

    def __init__(self, source: str, descriptor: str, filters: t.Sequence[str]) -> None:
        """Initialize the event instance.

        :param source: A unique identifier for the publisher of the event
        :param descriptor: Descriptor of the comm channel exposed by the consumer
        :param filters: Collection of filters indicating messages of interest
        """
        super().__init__(self.CONSUMER_CREATED, source)
        self.descriptor = descriptor
        self.filters = list(filters)

    def __str__(self) -> str:
        """Convert the event to a string.

        :returns: A string representation of this instance
        """
        _filters = ",".join(self.filters)
        # NOTE: `str(super())` renders the super proxy itself (special method
        # lookup bypasses delegation), so the parent is invoked explicitly
        return f"{super().__str__()}|{self.descriptor}|{_filters}"


class OnRemoveConsumer(EventBase):
    """Publish this event when a consumer is shutting down and
    should be removed from notification lists."""

    descriptor: str
    """Descriptor of the comm channel exposed by the consumer"""

    CONSUMER_REMOVED: t.ClassVar[str] = "consumer-removed"
    """Unique category name for an event raised when a consumer is unregistered"""

    def __init__(self, source: str, descriptor: str) -> None:
        """Initialize the OnRemoveConsumer event.

        :param source: A unique identifier for the publisher of the event
        :param descriptor: Descriptor of the comm channel exposed by the consumer
        """
        super().__init__(self.CONSUMER_REMOVED, source)
        self.descriptor = descriptor

    def __str__(self) -> str:
        """Convert the event to a string.

        :returns: A string representation of this instance
        """
        return f"{super().__str__()}|{self.descriptor}"


class OnWriteFeatureStore(EventBase):
    """Publish this event when a feature store key is written."""

    descriptor: str
    """The descriptor of the feature store where the write occurred"""
    key: str
    """The key identifying where the write occurred"""

    FEATURE_STORE_WRITTEN: t.ClassVar[str] = "feature-store-written"
    """Event category for an event raised when a feature store key is written"""

    def __init__(self, source: str, descriptor: str, key: str) -> None:
        """Initialize the OnWriteFeatureStore event.

        :param source: A unique identifier for the publisher of the event
        :param descriptor: The descriptor of the feature store where the write occurred
        :param key: The key identifying where the write occurred
        """
        super().__init__(self.FEATURE_STORE_WRITTEN, source)
        self.descriptor = descriptor
        self.key = key

    def __str__(self) -> str:
        """Convert the event to a string.

        :returns: A string representation of this instance
        """
        return f"{super().__str__()}|{self.descriptor}|{self.key}"
class EventProducer(t.Protocol):
    """Structural (duck-typed) interface for any class that publishes events."""

    def send(self, event: EventBase, timeout: float = 0.001) -> int:
        """Send an event using the configured comm channel.

        :param event: The event to send
        :param timeout: Maximum time to wait (in seconds) for messages to send
        :returns: The number of messages that were sent
        """
        ...
class WorkerDevice:
    def __init__(self, name: str) -> None:
        """Wrapper around a device to keep track of loaded Models and availability.

        :param name: Name used by the toolkit to identify this device, e.g. ``cuda:0``
        """
        self._name = name
        """The name used by the toolkit to identify this device"""
        self._models: dict[str, t.Any] = {}
        """Dict of keys to models which are loaded on this device"""

    @property
    def name(self) -> str:
        """The identifier of the device represented by this object

        :returns: Name used by the toolkit to identify this device
        """
        return self._name

    def add_model(self, key: str, model: t.Any) -> None:
        """Add a reference to a model loaded on this device and assign it a key.

        :param key: The key under which the model is saved
        :param model: The model which is added
        """
        self._models[key] = model

    def remove_model(self, key: str) -> None:
        """Remove the reference to a model loaded on this device.

        :param key: The key of the model to remove
        :raises KeyError: If key does not exist for removal
        """
        try:
            self._models.pop(key)
        except KeyError:
            logger.warning(f"An unknown key was requested for removal: {key}")
            raise

    def get_model(self, key: str) -> t.Any:
        """Get the model corresponding to a given key.

        :param key: The model key
        :returns: The model for the given key
        :raises KeyError: If key does not exist
        """
        try:
            return self._models[key]
        except KeyError:
            logger.warning(f"An unknown key was requested: {key}")
            raise

    def __contains__(self, key: str) -> bool:
        """Check if model with a given key is available on the device.

        :param key: The key of the model to check for existence
        :returns: Whether the model is available on the device
        """
        return key in self._models

    @contextmanager
    def get(self, key_to_remove: t.Optional[str]) -> t.Iterator["WorkerDevice"]:
        """Loan out the device, optionally removing a model when done.

        :param key_to_remove: The key of the model to optionally remove
        :returns: WorkerDevice generator
        """
        try:
            yield self
        finally:
            # evict the transient model even if the caller's block raised,
            # so a failed inference cannot strand a request-scoped model
            if key_to_remove is not None:
                self.remove_model(key_to_remove)


class DeviceManager:
    def __init__(self, device: WorkerDevice):
        """An object to manage devices such as GPUs and CPUs.

        The main goal of the ``DeviceManager`` is to ensure that
        the managed device is ready to be used by a worker to
        run a given model.

        :param device: The managed device
        """
        self._device = device
        """Device managed by this object"""

    def _load_model_on_device(
        self,
        worker: MachineLearningWorkerBase,
        batch: RequestBatch,
        feature_stores: dict[str, FeatureStore],
    ) -> None:
        """Load the model needed to execute a batch on the managed device.

        The model is loaded by the worker.

        :param worker: The worker that loads the model
        :param batch: The batch for which the model is needed
        :param feature_stores: Feature stores where the model could be stored
        """
        model_bytes = worker.fetch_model(batch, feature_stores)
        loaded_model = worker.load_model(batch, model_bytes, self._device.name)
        self._device.add_model(batch.model_id.key, loaded_model.model)

    def get_device(
        self,
        worker: MachineLearningWorkerBase,
        batch: RequestBatch,
        feature_stores: dict[str, FeatureStore],
    ) -> _GeneratorContextManager[WorkerDevice]:
        """Get the device managed by this object.

        The model needed to run the batch of requests is
        guaranteed to be available on the device.

        :param worker: The worker that wants to access the device
        :param batch: The batch of requests
        :param feature_stores: The feature stores on which part of the
        data needed by the request may be stored
        :returns: A generator yielding the device
        """
        model_in_request = batch.has_raw_model

        # Load model if not already loaded, or
        # because it is sent with the request
        if model_in_request or batch.model_id.key not in self._device:
            self._load_model_on_device(worker, batch, feature_stores)

        # models shipped inline with the request are transient: evict after use
        key_to_remove = batch.model_id.key if model_in_request else None
        return self._device.get(key_to_remove)
def build_failure_reply(status: "Status", message: str) -> ResponseBuilder:
    """
    Builds a failure response message.

    :param status: Status enum
    :param message: Status message
    :returns: Failure response
    """
    return MessageHandler.build_response(
        status=status,
        message=message,
        result=None,
        custom_attributes=None,
    )


def exception_handler(
    exc: Exception,
    reply_channel: t.Optional[CommChannelBase],
    failure_message: t.Optional[str],
) -> None:
    """
    Logs exceptions and sends a failure response.

    :param exc: The exception to be logged
    :param reply_channel: The channel used to send replies
    :param failure_message: Failure message to log and send back
    """
    logger.exception(exc)

    # guard clause: without a channel there is nobody to notify
    if not reply_channel:
        logger.warning("Unable to notify client of error without a reply channel")
        return

    message = failure_message if failure_message is not None else str(exc)
    reply = build_failure_reply("fail", message)
    reply_channel.send(MessageHandler.serialize_response(reply))
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# isort: off +# pylint: disable=import-error +# pylint: disable=unused-import +import socket +import dragon + +# pylint: enable=unused-import +# pylint: enable=import-error +# isort: on + +import argparse +import multiprocessing as mp +import os +import sys +import typing as t + +from smartsim._core.entrypoints.service import Service +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer +from smartsim._core.mli.infrastructure.comm.event import ( + EventBase, + OnCreateConsumer, + OnRemoveConsumer, + OnShutdownRequested, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class ConsumerRegistrationListener(Service): + """A long-running service that manages the list of consumers receiving + events that are broadcast. It hosts handlers for adding and removing consumers + """ + + def __init__( + self, + backbone: BackboneFeatureStore, + timeout: float, + batch_timeout: float, + as_service: bool = False, + cooldown: int = 0, + health_check_frequency: float = 60.0, + ) -> None: + """Initialize the EventListener. 
        :param backbone: The backbone feature store
        :param timeout: Maximum time (in seconds) to allow a single recv request to wait
        :param batch_timeout: Maximum time (in seconds) to allow a batch of receives to
        continue to build
        :param as_service: Specifies run-once or run-until-complete behavior of service
        :param cooldown: Number of seconds to wait before shutting down after
        shutdown criteria are met
        """
        super().__init__(
            as_service, cooldown, health_check_frequency=health_check_frequency
        )
        self._timeout = timeout
        """ Maximum time (in seconds) to allow a single recv request to wait"""
        self._batch_timeout = batch_timeout
        """Maximum time (in seconds) to allow a batch of receives to
        continue to build"""
        self._consumer: t.Optional[EventConsumer] = None
        """The event consumer that handles receiving events"""
        self._backbone = backbone
        """A standalone, system-created feature store used to share internal
        information among MLI components"""

    def _on_start(self) -> None:
        """Called on initial entry into Service `execute` event loop before
        `_on_iteration` is invoked."""
        super()._on_start()
        self._create_eventing()

    def _on_shutdown(self) -> None:
        """Release dragon resources. Called immediately after exiting
        the main event loop during automatic shutdown."""
        super()._on_shutdown()

        if not self._consumer:
            return

        # remove descriptor for this listener from the backbone if it's there
        if registered_consumer := self._backbone.backend_channel:
            # if there is a descriptor in the backbone and it's still this listener
            if registered_consumer == self._consumer.descriptor:
                logger.info(
                    f"Listener clearing backend consumer {self._consumer.name} "
                    "from backbone"
                )

                # unregister this listener in the backbone
                self._backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)

        # TODO: need the channel to be cleaned up
        # self._consumer._comm_channel._channel.destroy()

    def _on_iteration(self) -> None:
        """Executes calls to the machine learning worker implementation to complete
        the inference pipeline."""

        if self._consumer is None:
            logger.info("Unable to listen. No consumer available.")
            return

        self._consumer.listen_once(self._timeout, self._batch_timeout)

    def _can_shutdown(self) -> bool:
        """Determines if the event consumer is ready to stop listening.

        :returns: True when criteria to shutdown the service are met, False otherwise
        """

        if self._backbone is None:
            logger.info("Listener must shutdown. No backbone attached")
            return True

        if self._consumer is None:
            logger.info("Listener must shutdown. No consumer channel created")
            return True

        if not self._consumer.listening:
            logger.info(
                f"Listener can shutdown. Consumer `{self._consumer.name}` "
                "is not listening"
            )
            return True

        return False

    def _on_unregister(self, event: OnRemoveConsumer) -> None:
        """Event handler for updating the backbone when event consumers
        are un-registered.

        :param event: The event that was received
        """
        notify_list = set(self._backbone.notification_channels)

        # remove the descriptor specified in the event
        if event.descriptor in notify_list:
            logger.debug(f"Removing notify consumer: {event.descriptor}")
            notify_list.remove(event.descriptor)

        # push the updated list back into the backbone
        self._backbone.notification_channels = list(notify_list)

    def _on_register(self, event: OnCreateConsumer) -> None:
        """Event handler for updating the backbone when new event consumers
        are registered.

        :param event: The event that was received
        """
        notify_list = set(self._backbone.notification_channels)
        logger.debug(f"Adding notify consumer: {event.descriptor}")
        notify_list.add(event.descriptor)
        self._backbone.notification_channels = list(notify_list)

    def _on_event_received(self, event: EventBase) -> None:
        """Primary event handler for the listener. Distributes events to
        type-specific handlers.

        :param event: The event that was received
        """
        if self._backbone is None:
            logger.info("Unable to handle event. Backbone is missing.")
        # NOTE(review): control intentionally falls through to the handlers
        # even when the backbone is missing — confirm this best-effort
        # behavior is desired, since the handlers dereference self._backbone.

        if isinstance(event, OnCreateConsumer):
            self._on_register(event)
        elif isinstance(event, OnRemoveConsumer):
            self._on_unregister(event)
        else:
            logger.info(
                "Consumer registration listener received an "
                f"unexpected event: {event=}"
            )

    def _on_health_check(self) -> None:
        """Check if this consumer has been replaced by a new listener
        and automatically trigger a shutdown. Invoked based on the
        value of `self._health_check_frequency`."""
        super()._on_health_check()

        try:
            logger.debug("Retrieving registered listener descriptor")
            descriptor = self._backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
        except KeyError:
            # no registrar key in the backbone at all: stop listening
            descriptor = None
            if self._consumer:
                self._consumer.listening = False

        # a different descriptor means a newer listener has replaced this one
        if self._consumer and descriptor != self._consumer.descriptor:
            logger.warning(
                f"Consumer `{self._consumer.name}` for `ConsumerRegistrationListener` "
                "is no longer registered. It will automatically shut down."
            )
            self._consumer.listening = False

    def _publish_consumer(self) -> None:
        """Publish the registrar consumer descriptor to the backbone."""
        if self._consumer is None:
            logger.warning("No registrar consumer descriptor available to publisher")
            return

        logger.debug(f"Publishing {self._consumer.descriptor} to backbone")
        self._backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = (
            self._consumer.descriptor
        )

    def _create_eventing(self) -> EventConsumer:
        """
        Create an event publisher and event consumer for communicating with
        other MLI resources.

        NOTE: the backbone must be initialized before connecting eventing clients.

        :returns: The newly created EventConsumer instance
        :raises SmartSimError: If a listener channel cannot be created
        """

        if self._consumer:
            return self._consumer

        logger.info("Creating event consumer")

        dragon_channel = create_local(500)
        event_channel = DragonCommChannel(dragon_channel)

        if not event_channel.descriptor:
            raise SmartSimError(
                "Unable to generate the descriptor for the event channel"
            )

        self._consumer = EventConsumer(
            event_channel,
            self._backbone,
            [
                OnCreateConsumer.CONSUMER_CREATED,
                OnRemoveConsumer.CONSUMER_REMOVED,
                OnShutdownRequested.SHUTDOWN,
            ],
            name=f"ConsumerRegistrar.{socket.gethostname()}",
            event_handler=self._on_event_received,
        )
        self._publish_consumer()

        logger.info(
            f"Backend consumer `{self._consumer.name}` created: "
            f"{self._consumer.descriptor}"
        )

        return self._consumer


def _create_parser() -> argparse.ArgumentParser:
    """
    Create an argument parser that contains the arguments
    required to start the listener as a new process:

    --timeout
    --batch_timeout

    :returns: A configured parser
    """
    arg_parser = argparse.ArgumentParser(prog="ConsumerRegistrarEventListener")

    arg_parser.add_argument("--timeout", type=float, default=1.0)
    arg_parser.add_argument("--batch_timeout", type=float, default=1.0)

    return arg_parser


def _connect_backbone() -> t.Optional[BackboneFeatureStore]:
    """
    Load the backbone by retrieving the descriptor from environment variables.

    :returns: The backbone feature store
    :raises SmartSimError: if a descriptor is not found
    """
    descriptor = os.environ.get(BackboneFeatureStore.MLI_BACKBONE, "")

    if not descriptor:
        return None

    logger.info(f"Listener backbone descriptor: {descriptor}\n")

    # `from_writable_descriptor` ensures we can update the backbone
    return BackboneFeatureStore.from_writable_descriptor(descriptor)


if __name__ == "__main__":
    mp.set_start_method("dragon")

    parser = _create_parser()
    args = parser.parse_args()

    backbone_fs = _connect_backbone()

    if backbone_fs is None:
        logger.error(
            "Unable to attach to the backbone without the "
            f"`{BackboneFeatureStore.MLI_BACKBONE}` environment variable."
        )
        sys.exit(1)

    logger.debug(f"Listener attached to backbone: {backbone_fs.descriptor}")

    listener = ConsumerRegistrationListener(
        backbone_fs,
        float(args.timeout),
        float(args.batch_timeout),
        as_service=True,
    )

    logger.info(f"listener created? {listener}")

    try:
        listener.execute()
        sys.exit(0)
    except Exception:
        logger.exception("An error occurred in the event listener")
        sys.exit(1)
diff --git a/smartsim/_core/mli/infrastructure/control/request_dispatcher.py b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py
new file mode 100644
index 0000000000..e22a2c8f62
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py
@@ -0,0 +1,559 @@
# BSD 2-Clause License
#
# Copyright (c) 2021-2024, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2.
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# pylint: disable=import-error +# pylint: disable-next=unused-import +import dragon +import dragon.globalservices.pool as dragon_gs_pool +from dragon.managed_memory import MemoryPool +from dragon.mpbridge.queues import DragonQueue + +# pylint: enable=import-error + +# isort: off +# isort: on + +import multiprocessing as mp +import time +import typing as t +import uuid +from queue import Empty, Full, Queue + +from smartsim._core.entrypoints.service import Service + +from .....error import SmartSimError +from .....log import get_logger +from ....utils.timings import PerfTimer +from ..environment_loader import EnvironmentConfigLoader +from ..storage.feature_store import FeatureStore +from ..worker.worker import ( + InferenceRequest, + MachineLearningWorkerBase, + ModelIdentifier, + RequestBatch, +) +from .error_handling import exception_handler + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +logger = get_logger("Request Dispatcher") 


class BatchQueue(Queue[InferenceRequest]):
    def __init__(
        self, batch_timeout: float, batch_size: int, model_id: ModelIdentifier
    ) -> None:
        """Queue used to store inference requests waiting to be batched and
        sent to Worker Managers.

        :param batch_timeout: Time in seconds that has to be waited before flushing a
        non-full queue. The time of the first item put is 0 seconds.
        :param batch_size: Total capacity of the queue
        :param model_id: Key of the model which needs to be executed on the queued
        requests
        """
        super().__init__(maxsize=batch_size)
        self._batch_timeout = batch_timeout
        """Time in seconds that has to be waited before flushing a non-full queue.
        The time of the first item put is 0 seconds."""
        self._batch_size = batch_size
        """Total capacity of the queue"""
        self._first_put: t.Optional[float] = None
        """Time at which the first item was put on the queue"""
        self._disposable = False
        """Whether the queue will not be used again and can be deleted.
        A disposable queue is always full."""
        self._model_id: ModelIdentifier = model_id
        """Key of the model which needs to be executed on the queued requests"""
        self._uid = str(uuid.uuid4())
        """Unique ID of queue"""

    @property
    def uid(self) -> str:
        """ID of this queue.

        :returns: Queue ID
        """
        return self._uid

    @property
    def model_id(self) -> ModelIdentifier:
        """Key of the model which needs to be run on the queued requests.

        :returns: Model key
        """
        return self._model_id

    def put(
        self,
        item: InferenceRequest,
        block: bool = False,
        timeout: t.Optional[float] = 0.0,
    ) -> None:
        """Put an inference request in the queue.

        :param item: The request
        :param block: Whether to block when trying to put the item
        :param timeout: Time (in seconds) to wait if block==True
        :raises Full: If an item cannot be put on the queue
        """
        super().put(item, block=block, timeout=timeout)
        if self._first_put is None:
            # batch timer starts only after a put succeeds, so a rejected
            # (Full) put does not start the countdown
            self._first_put = time.time()

    @property
    def _elapsed_time(self) -> float:
        """Time elapsed since the first item was put on this queue.

        :returns: Time elapsed
        """
        if self.empty() or self._first_put is None:
            return 0
        return time.time() - self._first_put

    @property
    def ready(self) -> bool:
        """Check if the queue can be flushed.

        :returns: True if the queue can be flushed, False otherwise
        """
        if self.empty():
            logger.debug("Request dispatcher queue is empty")
            return False

        timed_out = False
        if self._batch_timeout >= 0:
            timed_out = self._elapsed_time >= self._batch_timeout

        if self.full():
            logger.debug("Request dispatcher ready to deliver full batch")
            return True

        if timed_out:
            logger.debug("Request dispatcher delivering partial batch")
            return True

        return False

    def make_disposable(self) -> None:
        """Set this queue as disposable, and never use it again after it gets
        flushed."""
        self._disposable = True

    @property
    def can_be_removed(self) -> bool:
        """Determine whether this queue can be deleted and garbage collected.

        :returns: True if queue can be removed, False otherwise
        """
        return self.empty() and self._disposable

    def flush(self) -> list[t.Any]:
        """Get all requests from queue.

        :returns: Requests waiting to be executed
        """
        num_items = self.qsize()
        self._first_put = None
        items = []
        for _ in range(num_items):
            try:
                items.append(self.get())
            except Empty:
                break

        return items

    def full(self) -> bool:
        """Check if the queue has reached its maximum capacity.

        :returns: True if the queue has reached its maximum capacity,
        False otherwise
        """
        if self._disposable:
            return True
        return self.qsize() >= self._batch_size

    def empty(self) -> bool:
        """Check if the queue is empty.

        :returns: True if the queue has 0 elements, False otherwise
        """
        return self.qsize() == 0


class RequestDispatcher(Service):
    def __init__(
        self,
        batch_timeout: float,
        batch_size: int,
        config_loader: EnvironmentConfigLoader,
        worker_type: t.Type[MachineLearningWorkerBase],
        mem_pool_size: int = 2 * 1024**3,
    ) -> None:
        """The RequestDispatcher intercepts inference requests, stages them in
        queues and batches them together before making them available to Worker
        Managers.

        :param batch_timeout: Maximum elapsed time before flushing a complete or
        incomplete batch
        :param batch_size: Total capacity of each batch queue
        :param mem_pool: Memory pool used to share batched input tensors with worker
        managers
        :param config_loader: Object to load configuration from environment
        :param worker_type: Type of worker to instantiate to batch inputs
        :param mem_pool_size: Size of the memory pool used to allocate tensors
        """
        super().__init__(as_service=True, cooldown=1)
        self._queues: dict[str, list[BatchQueue]] = {}
        """Dict of all batch queues available for a given model id"""
        self._active_queues: dict[str, BatchQueue] = {}
        """Mapping telling which queue is the recipient of requests for a given model
        key"""
        self._batch_timeout = batch_timeout
        """Time in seconds that has to be waited before flushing a non-full queue"""
        self._batch_size = batch_size
        """Total capacity of each batch queue"""
        incoming_channel = config_loader.get_queue()
        if incoming_channel is None:
            raise SmartSimError("No incoming channel for dispatcher")
        self._incoming_channel = incoming_channel
        """The channel the dispatcher monitors for new tasks"""
        self._outgoing_queue: DragonQueue = mp.Queue(maxsize=0)
        """The queue 
on which batched inference requests are placed""" + self._feature_stores: t.Dict[str, FeatureStore] = {} + """A collection of attached feature stores""" + self._featurestore_factory = config_loader._featurestore_factory + """A factory method to create a desired feature store client type""" + self._backbone: t.Optional[FeatureStore] = config_loader.get_backbone() + """A standalone, system-created feature store used to share internal + information among MLI components""" + self._callback_factory = config_loader._callback_factory + """The type of communication channel to construct for callbacks""" + self._worker = worker_type() + """The worker used to batch inputs""" + self._mem_pool = MemoryPool.attach(dragon_gs_pool.create(mem_pool_size).sdesc) + """Memory pool used to share batched input tensors with the Worker Managers""" + self._perf_timer = PerfTimer(prefix="r_", debug=False, timing_on=True) + """Performance timer""" + + @property + def has_featurestore_factory(self) -> bool: + """Check if the RequestDispatcher has a FeatureStore factory. + + :returns: True if there is a FeatureStore factory, False otherwise + """ + return self._featurestore_factory is not None + + def _check_feature_stores(self, request: InferenceRequest) -> bool: + """Ensures that all feature stores required by the request are available. 
+ + :param request: The request to validate + :returns: False if feature store validation fails for the request, True + otherwise + """ + # collect all feature stores required by the request + fs_model: t.Set[str] = set() + if request.model_key: + fs_model = {request.model_key.descriptor} + fs_inputs = {key.descriptor for key in request.input_keys} + fs_outputs = {key.descriptor for key in request.output_keys} + + # identify which feature stores are requested and unknown + fs_desired = fs_model.union(fs_inputs).union(fs_outputs) + fs_actual = {item.descriptor for item in self._feature_stores.values()} + fs_missing = fs_desired - fs_actual + + if not self.has_featurestore_factory: + logger.error("No feature store factory is configured. Unable to dispatch.") + return False + + # create the feature stores we need to service request + if fs_missing: + logger.debug(f"Adding feature store(s): {fs_missing}") + for descriptor in fs_missing: + feature_store = self._featurestore_factory(descriptor) + self._feature_stores[descriptor] = feature_store + + return True + + # pylint: disable-next=no-self-use + def _check_model(self, request: InferenceRequest) -> bool: + """Ensure that a model is available for the request. + + :param request: The request to validate + :returns: False if model validation fails for the request, True otherwise + """ + if request.has_model_key or request.has_raw_model: + return True + + logger.error("Unable to continue without model bytes or feature store key") + return False + + # pylint: disable-next=no-self-use + def _check_inputs(self, request: InferenceRequest) -> bool: + """Ensure that inputs are available for the request. 
+ + :param request: The request to validate + :returns: False if input validation fails for the request, True otherwise + """ + if request.has_input_keys or request.has_raw_inputs: + return True + + logger.error("Unable to continue without input bytes or feature store keys") + return False + + # pylint: disable-next=no-self-use + def _check_callback(self, request: InferenceRequest) -> bool: + """Ensure that a callback channel is available for the request. + + :param request: The request to validate + :returns: False if callback validation fails for the request, True otherwise + """ + if request.callback: + return True + + logger.error("No callback channel provided in request") + return False + + def _validate_request(self, request: InferenceRequest) -> bool: + """Ensure the request can be processed. + + :param request: The request to validate + :returns: False if the request fails any validation checks, True otherwise + """ + checks = [ + self._check_feature_stores(request), + self._check_model(request), + self._check_inputs(request), + self._check_callback(request), + ] + + return all(checks) + + def _on_iteration(self) -> None: + """This method is executed repeatedly until ``Service`` shutdown + conditions are satisfied and cooldown is elapsed.""" + try: + self._perf_timer.is_active = True + bytes_list: t.List[bytes] = self._incoming_channel.recv() + except Exception: + self._perf_timer.is_active = False + else: + if not bytes_list: + exception_handler( + ValueError("No request data found"), + None, + None, + ) + + logger.debug(f"Dispatcher is processing {len(bytes_list)} messages") + request_bytes = bytes_list[0] + tensor_bytes_list = bytes_list[1:] + self._perf_timer.start_timings() + + request = self._worker.deserialize_message( + request_bytes, self._callback_factory + ) + if request.has_input_meta and tensor_bytes_list: + request.raw_inputs = tensor_bytes_list + + self._perf_timer.measure_time("deserialize_message") + + if not 
self._validate_request(request): + exception_handler( + ValueError("Error validating the request"), + request.callback, + None, + ) + self._perf_timer.measure_time("validate_request") + else: + self._perf_timer.measure_time("validate_request") + self.dispatch(request) + self._perf_timer.measure_time("dispatch") + finally: + self.flush_requests() + self.remove_queues() + + self._perf_timer.end_timings() + + if self._perf_timer.max_length == 801 and self._perf_timer.is_active: + self._perf_timer.print_timings(True) + + def remove_queues(self) -> None: + """Remove references to queues that can be removed + and allow them to be garbage collected.""" + queue_lists_to_remove = [] + for key, queues in self._queues.items(): + queues_to_remove = [] + for queue in queues: + if queue.can_be_removed: + queues_to_remove.append(queue) + + for queue_to_remove in queues_to_remove: + queues.remove(queue_to_remove) + if ( + key in self._active_queues + and self._active_queues[key] == queue_to_remove + ): + del self._active_queues[key] + + if len(queues) == 0: + queue_lists_to_remove.append(key) + + for key in queue_lists_to_remove: + del self._queues[key] + + @property + def task_queue(self) -> DragonQueue: + """The queue on which batched requests are placed. + + :returns: The queue + """ + return self._outgoing_queue + + def _swap_queue(self, model_id: ModelIdentifier) -> None: + """Get an empty queue or create a new one + and make it the active one for a given model. 
+ + :param model_id: The id of the model for which the + queue has to be swapped + """ + if model_id.key in self._queues: + for queue in self._queues[model_id.key]: + if not queue.full(): + self._active_queues[model_id.key] = queue + return + + new_queue = BatchQueue(self._batch_timeout, self._batch_size, model_id) + if model_id.key in self._queues: + self._queues[model_id.key].append(new_queue) + else: + self._queues[model_id.key] = [new_queue] + self._active_queues[model_id.key] = new_queue + return + + def dispatch(self, request: InferenceRequest) -> None: + """Assign a request to a batch queue. + + :param request: The request to place + """ + if request.has_raw_model: + logger.debug("Direct inference requested, creating tmp queue") + tmp_id = f"_tmp_{str(uuid.uuid4())}" + tmp_queue: BatchQueue = BatchQueue( + batch_timeout=0, + batch_size=1, + model_id=ModelIdentifier(key=tmp_id, descriptor="TMP"), + ) + self._active_queues[tmp_id] = tmp_queue + self._queues[tmp_id] = [tmp_queue] + tmp_queue.put(request) + tmp_queue.make_disposable() + return + + if request.model_key: + success = False + while not success: + try: + self._active_queues[request.model_key.key].put_nowait(request) + success = True + except (Full, KeyError): + self._swap_queue(request.model_key) + + def flush_requests(self) -> None: + """Get all requests from queues which are ready to be flushed. 
Place all + available request batches in the outgoing queue.""" + for queue_list in self._queues.values(): + for queue in queue_list: + if queue.ready: + self._perf_timer.measure_time("find_queue") + try: + batch = RequestBatch( + requests=queue.flush(), + inputs=None, + model_id=queue.model_id, + ) + finally: + self._perf_timer.measure_time("flush_requests") + try: + fetch_results = self._worker.fetch_inputs( + batch=batch, feature_stores=self._feature_stores + ) + except Exception as exc: + exception_handler( + exc, + None, + "Error fetching input.", + ) + continue + self._perf_timer.measure_time("fetch_input") + try: + transformed_inputs = self._worker.transform_input( + batch=batch, + fetch_results=fetch_results, + mem_pool=self._mem_pool, + ) + except Exception as exc: + exception_handler( + exc, + None, + "Error transforming input.", + ) + continue + + self._perf_timer.measure_time("transform_input") + batch.inputs = transformed_inputs + for request in batch.requests: + request.raw_inputs = [] + request.input_meta = [] + + try: + self._outgoing_queue.put(batch) + except Exception as exc: + exception_handler( + exc, + None, + "Error placing batch on task queue.", + ) + continue + self._perf_timer.measure_time("put") + + def _can_shutdown(self) -> bool: + """Determine whether the Service can be shut down. + + :returns: False + """ + return False + + def __del__(self) -> None: + """Destroy allocated memory resources.""" + # pool may be null if a failure occurs prior to successful attach + pool: t.Optional[MemoryPool] = getattr(self, "_mem_pool", None) + + if pool: + pool.destroy() diff --git a/smartsim/_core/mli/infrastructure/control/worker_manager.py b/smartsim/_core/mli/infrastructure/control/worker_manager.py new file mode 100644 index 0000000000..bf6fddb81d --- /dev/null +++ b/smartsim/_core/mli/infrastructure/control/worker_manager.py @@ -0,0 +1,330 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +# pylint: disable=import-error +# pylint: disable-next=unused-import +import dragon + +# pylint: enable=import-error + +# isort: off +# isort: on + +import multiprocessing as mp +import time +import typing as t +from queue import Empty + +from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore + +from .....log import get_logger +from ....entrypoints.service import Service +from ....utils.timings import PerfTimer +from ...message_handler import MessageHandler +from ..environment_loader import EnvironmentConfigLoader +from ..worker.worker import ( + InferenceReply, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, +) +from .device_manager import DeviceManager, WorkerDevice +from .error_handling import build_failure_reply, exception_handler + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +logger = get_logger(__name__) + + +class WorkerManager(Service): + """An implementation of a service managing distribution of tasks to + machine learning workers.""" + + def __init__( + self, + config_loader: EnvironmentConfigLoader, + worker_type: t.Type[MachineLearningWorkerBase], + dispatcher_queue: "mp.Queue[RequestBatch]", + as_service: bool = False, + cooldown: int = 0, + device: t.Literal["cpu", "gpu"] = "cpu", + ) -> None: + """Initialize the WorkerManager. + + :param config_loader: Environment config loader for loading queues + and feature stores + :param worker_type: The type of worker to manage + :param dispatcher_queue: Queue from which the batched requests are pulled + :param as_service: Specifies run-once or run-until-complete behavior of service + :param cooldown: Number of seconds to wait before shutting down after + shutdown criteria are met + :param device: The device on which the Worker should run. Every worker manager + is assigned one single GPU (if available), thus the device should have no index. 
+ """ + super().__init__(as_service, cooldown) + + self._dispatcher_queue = dispatcher_queue + """The Dispatcher queue that the WorkerManager monitors for new batches""" + self._worker = worker_type() + """The ML Worker implementation""" + self._callback_factory = config_loader._callback_factory + """The type of communication channel to construct for callbacks""" + self._device = device + """Device on which workers need to run""" + self._cached_models: dict[str, t.Any] = {} + """Dictionary of previously loaded models""" + self._feature_stores: t.Dict[str, FeatureStore] = {} + """A collection of attached feature stores""" + self._featurestore_factory = config_loader._featurestore_factory + """A factory method to create a desired feature store client type""" + self._backbone: t.Optional[FeatureStore] = config_loader.get_backbone() + """A standalone, system-created feature store used to share internal + information among MLI components""" + self._device_manager: t.Optional[DeviceManager] = None + """Object responsible for model caching and device access""" + self._perf_timer = PerfTimer(prefix="w_", debug=False, timing_on=True) + """Performance timer""" + + @property + def has_featurestore_factory(self) -> bool: + """Check if the WorkerManager has a FeatureStore factory. + + :returns: True if there is a FeatureStore factory, False otherwise + """ + return self._featurestore_factory is not None + + def _on_start(self) -> None: + """Called on initial entry into Service `execute` event loop before + `_on_iteration` is invoked.""" + self._device_manager = DeviceManager(WorkerDevice(self._device)) + + def _check_feature_stores(self, batch: RequestBatch) -> bool: + """Ensures that all feature stores required by the request are available. 
+ + :param batch: The batch of requests to validate + :returns: False if feature store validation fails for the batch, True otherwise + """ + # collect all feature stores required by the request + fs_model: t.Set[str] = set() + if batch.model_id.key: + fs_model = {batch.model_id.descriptor} + fs_inputs = {key.descriptor for key in batch.input_keys} + fs_outputs = {key.descriptor for key in batch.output_keys} + + # identify which feature stores are requested and unknown + fs_desired = fs_model.union(fs_inputs).union(fs_outputs) + fs_actual = {item.descriptor for item in self._feature_stores.values()} + fs_missing = fs_desired - fs_actual + + if not self.has_featurestore_factory: + logger.error("No feature store factory configured") + return False + + # create the feature stores we need to service request + if fs_missing: + logger.debug(f"Adding feature store(s): {fs_missing}") + for descriptor in fs_missing: + feature_store = self._featurestore_factory(descriptor) + self._feature_stores[descriptor] = feature_store + + return True + + def _validate_batch(self, batch: RequestBatch) -> bool: + """Ensure the request can be processed. 
+ + :param batch: The batch of requests to validate + :returns: False if the request fails any validation checks, True otherwise + """ + if batch is None or not batch.has_valid_requests: + return False + + return self._check_feature_stores(batch) + + # remove this when we are done with time measurements + # pylint: disable-next=too-many-statements + def _on_iteration(self) -> None: + """Executes calls to the machine learning worker implementation to complete + the inference pipeline.""" + pre_batch_time = time.perf_counter() + try: + batch: RequestBatch = self._dispatcher_queue.get(timeout=0.0001) + except Empty: + return + + self._perf_timer.start_timings( + "flush_requests", time.perf_counter() - pre_batch_time + ) + + if not self._validate_batch(batch): + exception_handler( + ValueError("An invalid batch was received"), + None, + None, + ) + return + + if not self._device_manager: + for request in batch.requests: + msg = "No Device Manager found. WorkerManager._on_start() " + "must be called after initialization. If possible, " + "you should use `WorkerManager.execute()` instead of " + "directly calling `_on_iteration()`." + try: + self._dispatcher_queue.put(batch) + except Exception: + msg += "\nThe batch could not be put back in the queue " + "and will not be processed." 
+ exception_handler( + RuntimeError(msg), + request.callback, + "Error acquiring device manager", + ) + return + + try: + device_cm = self._device_manager.get_device( + worker=self._worker, + batch=batch, + feature_stores=self._feature_stores, + ) + except Exception as exc: + for request in batch.requests: + exception_handler( + exc, + request.callback, + "Error loading model on device or getting device.", + ) + return + self._perf_timer.measure_time("fetch_model") + + with device_cm as device: + + try: + model_result = LoadModelResult(device.get_model(batch.model_id.key)) + except Exception as exc: + for request in batch.requests: + exception_handler( + exc, request.callback, "Error getting model from device." + ) + return + self._perf_timer.measure_time("load_model") + + if not batch.inputs: + for request in batch.requests: + exception_handler( + ValueError("Error batching inputs"), + request.callback, + None, + ) + return + transformed_input = batch.inputs + + try: + execute_result = self._worker.execute( + batch, model_result, transformed_input, device.name + ) + except Exception as e: + for request in batch.requests: + exception_handler(e, request.callback, "Error while executing.") + return + self._perf_timer.measure_time("execute") + + try: + transformed_outputs = self._worker.transform_output( + batch, execute_result + ) + except Exception as e: + for request in batch.requests: + exception_handler( + e, request.callback, "Error while transforming the output." + ) + return + + for request, transformed_output in zip(batch.requests, transformed_outputs): + reply = InferenceReply() + if request.has_output_keys: + try: + reply.output_keys = self._worker.place_output( + request, + transformed_output, + self._feature_stores, + ) + except Exception as e: + exception_handler( + e, request.callback, "Error while placing the output." 
+ ) + continue + else: + reply.outputs = transformed_output.outputs + self._perf_timer.measure_time("assign_output") + + if not reply.has_outputs: + response = build_failure_reply("fail", "Outputs not found.") + else: + reply.status_enum = "complete" + reply.message = "Success" + + results = self._worker.prepare_outputs(reply) + response = MessageHandler.build_response( + status=reply.status_enum, + message=reply.message, + result=results, + custom_attributes=None, + ) + + self._perf_timer.measure_time("build_reply") + + serialized_resp = MessageHandler.serialize_response(response) + + self._perf_timer.measure_time("serialize_resp") + + if request.callback: + request.callback.send(serialized_resp) + if reply.has_outputs: + # send tensor data after response + for output in reply.outputs: + request.callback.send(output) + self._perf_timer.measure_time("send") + + self._perf_timer.end_timings() + + if self._perf_timer.max_length == 801: + self._perf_timer.print_timings(True) + + def _can_shutdown(self) -> bool: + """Determine if the service can be shutdown. + + :returns: True when criteria to shutdown the service are met, False otherwise + """ + # todo: determine shutdown criteria + # will we receive a completion message? + # will we let MLI mgr just kill this? + # time_diff = self._last_event - datetime.datetime.now() + # if time_diff.total_seconds() > self._cooldown: + # return True + # return False + return self._worker is None diff --git a/smartsim/_core/mli/infrastructure/environment_loader.py b/smartsim/_core/mli/infrastructure/environment_loader.py new file mode 100644 index 0000000000..5ba0fccc27 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/environment_loader.py @@ -0,0 +1,116 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import typing as t + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class EnvironmentConfigLoader: + """ + Facilitates the loading of a FeatureStore and Queue into the WorkerManager. 
+ """ + + REQUEST_QUEUE_ENV_VAR = "_SMARTSIM_REQUEST_QUEUE" + """The environment variable that holds the request queue descriptor""" + BACKBONE_ENV_VAR = "_SMARTSIM_INFRA_BACKBONE" + """The environment variable that holds the backbone descriptor""" + + def __init__( + self, + featurestore_factory: t.Callable[[str], FeatureStore], + callback_factory: t.Callable[[str], CommChannelBase], + queue_factory: t.Callable[[str], CommChannelBase], + ) -> None: + """Initialize the config loader instance with the factories necessary for + creating additional objects. + + :param featurestore_factory: A factory method that produces a feature store + given a descriptor + :param callback_factory: A factory method that produces a callback + channel given a descriptor + :param queue_factory: A factory method that produces a queue + channel given a descriptor + """ + self.queue: t.Optional[CommChannelBase] = None + """The attached incoming event queue channel""" + self.backbone: t.Optional[FeatureStore] = None + """The attached backbone feature store""" + self._featurestore_factory = featurestore_factory + """A factory method to instantiate a FeatureStore""" + self._callback_factory = callback_factory + """A factory method to instantiate a concrete CommChannelBase + for inference callbacks""" + self._queue_factory = queue_factory + """A factory method to instantiate a concrete CommChannelBase + for inference requests""" + + def get_backbone(self) -> t.Optional[FeatureStore]: + """Attach to the backbone feature store using the descriptor found in + the environment variable `_SMARTSIM_INFRA_BACKBONE`. The backbone is + a standalone, system-created feature store used to share internal + information among MLI components. 
+ + :returns: The attached feature store via `_SMARTSIM_INFRA_BACKBONE` + """ + descriptor = os.getenv(self.BACKBONE_ENV_VAR, "") + + if not descriptor: + logger.warning("No backbone descriptor is configured") + return None + + if self._featurestore_factory is None: + logger.warning( + "No feature store factory is configured. Backbone not created." + ) + return None + + self.backbone = self._featurestore_factory(descriptor) + return self.backbone + + def get_queue(self) -> t.Optional[CommChannelBase]: + """Attach to a queue-like communication channel using the descriptor + found in the environment variable `_SMARTSIM_REQUEST_QUEUE`. + + :returns: The attached queue specified via `_SMARTSIM_REQUEST_QUEUE` + """ + descriptor = os.getenv(self.REQUEST_QUEUE_ENV_VAR, "") + + if not descriptor: + logger.warning("No queue descriptor is configured") + return None + + if self._queue_factory is None: + logger.warning("No queue factory is configured") + return None + + self.queue = self._queue_factory(descriptor) + return self.queue diff --git a/smartsim/_core/mli/infrastructure/storage/__init__.py b/smartsim/_core/mli/infrastructure/storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py b/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py new file mode 100644 index 0000000000..b12d7b11b4 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py @@ -0,0 +1,259 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import itertools +import os +import time +import typing as t + +# pylint: disable=import-error +# isort: off +import dragon.data.ddict.ddict as dragon_ddict + +# isort: on + +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class BackboneFeatureStore(DragonFeatureStore): + """A DragonFeatureStore wrapper with utility methods for accessing shared + information stored in the MLI backbone feature store.""" + + MLI_NOTIFY_CONSUMERS = "_SMARTSIM_MLI_NOTIFY_CONSUMERS" + """Unique key used in the backbone to locate the consumer list""" + MLI_REGISTRAR_CONSUMER = "_SMARTIM_MLI_REGISTRAR_CONSUMER" + """Unique key used in the backbone to locate the registration consumer""" + MLI_WORKER_QUEUE = "_SMARTSIM_REQUEST_QUEUE" + """Unique key used in the backbone to locate MLI work queue""" + MLI_BACKBONE = 
"_SMARTSIM_INFRA_BACKBONE" + """Unique key used in the backbone to locate the backbone feature store""" + _CREATED_ON = "creation" + """Unique key used in the backbone to locate the creation date of the + feature store""" + _DEFAULT_WAIT_TIMEOUT = 1.0 + """The default wait time (in seconds) for blocking requests to + the feature store""" + + def __init__( + self, + storage: dragon_ddict.DDict, + allow_reserved_writes: bool = False, + ) -> None: + """Initialize the DragonFeatureStore instance. + + :param storage: A distributed dictionary to be used as the underlying + storage mechanism of the feature store + :param allow_reserved_writes: Whether reserved writes are allowed + """ + super().__init__(storage) + self._enable_reserved_writes = allow_reserved_writes + + self._record_creation_data() + + @property + def wait_timeout(self) -> float: + """Retrieve the wait timeout for this feature store. The wait timeout is + applied to all calls to `wait_for`. + + :returns: The wait timeout (in seconds). + """ + return self._wait_timeout + + @wait_timeout.setter + def wait_timeout(self, value: float) -> None: + """Set the wait timeout (in seconds) for this feature store. The wait + timeout is applied to all calls to `wait_for`. + + :param value: The new value to set + """ + self._wait_timeout = value + + @property + def notification_channels(self) -> t.Sequence[str]: + """Retrieve descriptors for all registered MLI notification channels. + + :returns: The list of channel descriptors + """ + if self.MLI_NOTIFY_CONSUMERS in self: + stored_consumers = self[self.MLI_NOTIFY_CONSUMERS] + return str(stored_consumers).split(",") + return [] + + @notification_channels.setter + def notification_channels(self, values: t.Sequence[str]) -> None: + """Set the notification channels to be sent events. 
+ + :param values: The list of channel descriptors to save + """ + self[self.MLI_NOTIFY_CONSUMERS] = ",".join( + [str(value) for value in values if value] + ) + + @property + def backend_channel(self) -> t.Optional[str]: + """Retrieve the channel descriptor used to register event consumers. + + :returns: The channel descriptor""" + if self.MLI_REGISTRAR_CONSUMER in self: + return str(self[self.MLI_REGISTRAR_CONSUMER]) + return None + + @backend_channel.setter + def backend_channel(self, value: str) -> None: + """Set the channel used to register event consumers. + + :param value: The stringified channel descriptor""" + self[self.MLI_REGISTRAR_CONSUMER] = value + + @property + def worker_queue(self) -> t.Optional[str]: + """Retrieve the channel descriptor used to send work to MLI worker managers. + + :returns: The channel descriptor, if found. Otherwise, `None`""" + if self.MLI_WORKER_QUEUE in self: + return str(self[self.MLI_WORKER_QUEUE]) + return None + + @worker_queue.setter + def worker_queue(self, value: str) -> None: + """Set the channel descriptor used to send work to MLI worker managers. + + :param value: The channel descriptor""" + self[self.MLI_WORKER_QUEUE] = value + + @property + def creation_date(self) -> str: + """Return the creation date for the backbone feature store. + + :returns: The string-formatted date when feature store was created""" + return str(self[self._CREATED_ON]) + + def _record_creation_data(self) -> None: + """Write the creation timestamp to the feature store.""" + if self._CREATED_ON not in self: + if not self._allow_reserved_writes: + logger.warning( + "Recorded creation from a write-protected backbone instance" + ) + self[self._CREATED_ON] = str(time.time()) + + os.environ[self.MLI_BACKBONE] = self.descriptor + + @classmethod + def from_writable_descriptor( + cls, + descriptor: str, + ) -> "BackboneFeatureStore": + """A factory method that creates an instance from a descriptor string. 
+ + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached DragonFeatureStore + :raises SmartSimError: if attachment to DragonFeatureStore fails + """ + try: + return BackboneFeatureStore(dragon_ddict.DDict.attach(descriptor), True) + except Exception as ex: + raise SmartSimError( + f"Error creating backbone feature store: {descriptor}" + ) from ex + + def _check_wait_timeout( + self, start_time: float, timeout: float, indicators: t.Dict[str, bool] + ) -> None: + """Perform timeout verification. + + :param start_time: the start time to use for elapsed calculation + :param timeout: the timeout (in seconds) + :param indicators: latest retrieval status for requested keys + :raises SmartSimError: If the timeout elapses before all values are + retrieved + """ + elapsed = time.time() - start_time + if timeout and elapsed > timeout: + raise SmartSimError( + f"Backbone {self.descriptor=} timeout after {elapsed} " + f"seconds retrieving keys: {indicators}" + ) + + def wait_for( + self, keys: t.List[str], timeout: float = _DEFAULT_WAIT_TIMEOUT + ) -> t.Dict[str, t.Union[str, bytes, None]]: + """Perform a blocking wait until all specified keys have been found + in the backbone. 
+ + :param keys: The required collection of keys to retrieve + :param timeout: The maximum wait time in seconds + :returns: Dictionary containing the keys and values requested + :raises SmartSimError: If the timeout elapses without retrieving + all requested keys + """ + if timeout < 0: + timeout = self._DEFAULT_WAIT_TIMEOUT + logger.info(f"Using default wait_for timeout: {timeout}s") + + if not keys: + return {} + + values: t.Dict[str, t.Union[str, bytes, None]] = {k: None for k in set(keys)} + is_found = {k: False for k in values.keys()} + + backoff = (0.1, 0.2, 0.4, 0.8) + backoff_iter = itertools.cycle(backoff) + start_time = time.time() + + while not all(is_found.values()): + delay = next(backoff_iter) + + for key in [k for k, v in is_found.items() if not v]: + try: + values[key] = self[key] + is_found[key] = True + except Exception: + if delay == backoff[-1]: + logger.debug(f"Re-attempting `{key}` retrieval in {delay}s") + + if all(is_found.values()): + logger.debug(f"wait_for({keys}) retrieved all keys") + continue + + self._check_wait_timeout(start_time, timeout, is_found) + time.sleep(delay) + + return values + + def get_env(self) -> t.Dict[str, str]: + """Returns a dictionary populated with environment variables necessary to + connect a process to the existing backbone instance. + + :returns: The dictionary populated with env vars + """ + return {self.MLI_BACKBONE: self.descriptor} diff --git a/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py b/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py new file mode 100644 index 0000000000..24f2221c87 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py @@ -0,0 +1,126 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +# pylint: disable=import-error +# isort: off +import dragon.data.ddict.ddict as dragon_ddict + +# isort: on + +from smartsim._core.mli.infrastructure.storage.dragon_util import ( + ddict_to_descriptor, + descriptor_to_ddict, +) +from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore +from smartsim.error import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class DragonFeatureStore(FeatureStore): + """A feature store backed by a dragon distributed dictionary.""" + + def __init__(self, storage: "dragon_ddict.DDict") -> None: + """Initialize the DragonFeatureStore instance. 
+ + :param storage: A distributed dictionary to be used as the underlying + storage mechanism of the feature store""" + if storage is None: + raise ValueError( + "Storage is required when instantiating a DragonFeatureStore." + ) + + descriptor = "" + if isinstance(storage, dragon_ddict.DDict): + descriptor = ddict_to_descriptor(storage) + + super().__init__(descriptor) + self._storage: t.Dict[str, t.Union[str, bytes]] = storage + """The underlying storage mechanism of the DragonFeatureStore; a + distributed, in-memory key-value store""" + + def _get(self, key: str) -> t.Union[str, bytes]: + """Retrieve a value from the underlying storage mechanism. + + :param key: The unique key that identifies the resource + :returns: The value identified by the key + :raises KeyError: If the key has not been used to store a value + """ + try: + return self._storage[key] + except dragon_ddict.DDictError as e: + raise KeyError(f"Key not found in FeatureStore: {key}") from e + + def _set(self, key: str, value: t.Union[str, bytes]) -> None: + """Store a value into the underlying storage mechanism. + + :param key: The unique key that identifies the resource + :param value: The value to store + :returns: The value identified by the key + """ + self._storage[key] = value + + def _contains(self, key: str) -> bool: + """Determine if the storage mechanism contains a given key. + + :param key: The unique key that identifies the resource + :returns: True if the key is defined, False otherwise + """ + return key in self._storage + + def pop(self, key: str) -> t.Union[str, bytes, None]: + """Remove the value from the dictionary and return the value. 
+
+        :param key: Dictionary key to retrieve
+        :returns: The value held at the key if it exists, otherwise `None`
+        """
+        try:
+            return self._storage.pop(key)
+        except dragon_ddict.DDictError:
+            return None
+
+    @classmethod
+    def from_descriptor(
+        cls,
+        descriptor: str,
+    ) -> "DragonFeatureStore":
+        """A factory method that creates an instance from a descriptor string.
+
+        :param descriptor: The descriptor that uniquely identifies the resource
+        :returns: An attached DragonFeatureStore
+        :raises SmartSimError: If attachment to DragonFeatureStore fails
+        """
+        try:
+            logger.debug(f"Attaching to FeatureStore with descriptor: {descriptor}")
+            storage = descriptor_to_ddict(descriptor)
+            return cls(storage)
+        except Exception as ex:
+            raise SmartSimError(
+                f"Error creating dragon feature store from descriptor: {descriptor}"
+            ) from ex
diff --git a/smartsim/_core/mli/infrastructure/storage/dragon_util.py b/smartsim/_core/mli/infrastructure/storage/dragon_util.py
new file mode 100644
index 0000000000..50d15664c0
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/storage/dragon_util.py
@@ -0,0 +1,101 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# pylint: disable=import-error +# isort: off +import dragon.data.ddict.ddict as dragon_ddict + +# isort: on + +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +def ddict_to_descriptor(ddict: dragon_ddict.DDict) -> str: + """Convert a DDict to a descriptor string. + + :param ddict: The dragon dictionary to convert + :returns: The descriptor string + :raises ValueError: If a ddict is not provided + """ + if ddict is None: + raise ValueError("DDict is not available to create a descriptor") + + # unlike other dragon objects, the dictionary serializes to a string + # instead of bytes + return str(ddict.serialize()) + + +def descriptor_to_ddict(descriptor: str) -> dragon_ddict.DDict: + """Create and attach a new DDict instance given + the string-encoded descriptor. + + :param descriptor: The descriptor of a dictionary to attach to + :returns: The attached dragon dictionary""" + return dragon_ddict.DDict.attach(descriptor) + + +def create_ddict( + num_nodes: int, mgr_per_node: int, mem_per_node: int +) -> dragon_ddict.DDict: + """Create a distributed dragon dictionary. + + :param num_nodes: The number of distributed nodes to distribute the dictionary to. + At least one node is required. + :param mgr_per_node: The number of manager processes per node + :param mem_per_node: The amount of memory (in megabytes) to allocate per node. 
Total
+        memory available will be calculated as `num_nodes * node_mem`
+
+    :returns: The instantiated dragon dictionary
+    :raises ValueError: If invalid num_nodes is supplied
+    :raises ValueError: If invalid mem_per_node is supplied
+    :raises ValueError: If invalid mgr_per_node is supplied
+    """
+    if num_nodes < 1:
+        raise ValueError("A dragon dictionary must have at least 1 node")
+
+    if mgr_per_node < 1:
+        raise ValueError("A dragon dict requires at least 1 manager per node")
+
+    if mem_per_node < dragon_ddict.DDICT_MIN_SIZE:
+        raise ValueError(
+            "A dragon dictionary requires at least "
+            f"{dragon_ddict.DDICT_MIN_SIZE / 1024} MB"
+        )
+
+    mem_total = num_nodes * mem_per_node
+
+    logger.debug(
+        f"Creating dragon dictionary with {num_nodes} nodes, {mem_total} MB memory"
+    )
+
+    distributed_dict = dragon_ddict.DDict(num_nodes, mgr_per_node, total_mem=mem_total)
+    logger.debug(
+        "Successfully created dragon dictionary with "
+        f"{num_nodes} nodes, {mem_total} MB total memory"
+    )
+    return distributed_dict
diff --git a/smartsim/_core/mli/infrastructure/storage/feature_store.py b/smartsim/_core/mli/infrastructure/storage/feature_store.py
new file mode 100644
index 0000000000..ebca07ed4e
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/storage/feature_store.py
@@ -0,0 +1,224 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import enum +import typing as t +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class ReservedKeys(str, enum.Enum): + """Contains constants used to identify all featurestore keys that + may not be to used by users. Avoids overwriting system data.""" + + MLI_NOTIFY_CONSUMERS = "_SMARTSIM_MLI_NOTIFY_CONSUMERS" + """Storage location for the list of registered consumers that will receive + events from an EventBroadcaster""" + + MLI_REGISTRAR_CONSUMER = "_SMARTIM_MLI_REGISTRAR_CONSUMER" + """Storage location for the channel used to send messages directly to + the MLI backend""" + + MLI_WORKER_QUEUE = "_SMARTSIM_REQUEST_QUEUE" + """Storage location for the channel used to send work requests + to the available worker managers""" + + @classmethod + def contains(cls, value: str) -> bool: + """Convert a string representation into an enumeration member. 
+
+        :param value: The string to check against the reserved keys
+        :returns: True if the value is a reserved key, False otherwise
+        """
+        try:
+            cls(value)
+        except ValueError:
+            return False
+
+        return True
+
+
+@dataclass(frozen=True)
+class TensorKey:
+    """A key,descriptor pair enabling retrieval of an item from a feature store."""
+
+    key: str
+    """The unique key of an item in a feature store"""
+    descriptor: str
+    """The unique identifier of the feature store containing the key"""
+
+    def __post_init__(self) -> None:
+        """Ensure the key and descriptor have at least one character.
+
+        :raises ValueError: If key or descriptor are empty strings
+        """
+        if len(self.key) < 1:
+            raise ValueError("Key must have at least one character.")
+        if len(self.descriptor) < 1:
+            raise ValueError("Descriptor must have at least one character.")
+
+
+@dataclass(frozen=True)
+class ModelKey:
+    """A key,descriptor pair enabling retrieval of an item from a feature store."""
+
+    key: str
+    """The unique key of an item in a feature store"""
+    descriptor: str
+    """The unique identifier of the feature store containing the key"""
+
+    def __post_init__(self) -> None:
+        """Ensure the key and descriptor have at least one character.
+
+        :raises ValueError: If key or descriptor are empty strings
+        """
+        if len(self.key) < 1:
+            raise ValueError("Key must have at least one character.")
+        if len(self.descriptor) < 1:
+            raise ValueError("Descriptor must have at least one character.")
+
+
+class FeatureStore(ABC):
+    """Abstract base class providing the common interface for retrieving
+    values from a feature store implementation."""
+
+    def __init__(self, descriptor: str, allow_reserved_writes: bool = False) -> None:
+        """Initialize the feature store.
+ + :param descriptor: The stringified version of a storage descriptor + :param allow_reserved_writes: Override the default behavior of blocking + writes to reserved keys + """ + self._enable_reserved_writes = allow_reserved_writes + """Flag used to ensure that any keys written by the system to a feature store + are not overwritten by user code. Disabled by default. Subclasses must set the + value intentionally.""" + self._descriptor = descriptor + """Stringified version of the unique ID enabling a client to connect + to the feature store""" + + def _check_reserved(self, key: str) -> None: + """A utility method used to verify access to write to a reserved key + in the FeatureStore. Used by subclasses in __setitem___ implementations. + + :param key: A key to compare to the reserved keys + :raises SmartSimError: If the key is reserved + """ + if not self._enable_reserved_writes and ReservedKeys.contains(key): + raise SmartSimError( + "Use of reserved key denied. " + "Unable to overwrite system configuration" + ) + + def __getitem__(self, key: str) -> t.Union[str, bytes]: + """Retrieve an item using key. + + :param key: Unique key of an item to retrieve from the feature store + :returns: An item in the FeatureStore + :raises SmartSimError: If retrieving fails + """ + try: + return self._get(key) + except KeyError: + raise + except Exception as ex: + # note: explicitly avoid round-trip to check for key existence + raise SmartSimError( + f"Could not get value for existing key {key}, error:\n{ex}" + ) from ex + + def __setitem__(self, key: str, value: t.Union[str, bytes]) -> None: + """Assign a value using key. + + :param key: Unique key of an item to set in the feature store + :param value: Value to persist in the feature store + """ + self._check_reserved(key) + self._set(key, value) + + def __contains__(self, key: str) -> bool: + """Membership operator to test for a key existing within the feature store. 
+ + :param key: Unique key of an item to retrieve from the feature store + :returns: `True` if the key is found, `False` otherwise + """ + return self._contains(key) + + @abstractmethod + def _get(self, key: str) -> t.Union[str, bytes]: + """Retrieve a value from the underlying storage mechanism. + + :param key: The unique key that identifies the resource + :returns: The value identified by the key + :raises KeyError: If the key has not been used to store a value + """ + + @abstractmethod + def _set(self, key: str, value: t.Union[str, bytes]) -> None: + """Store a value into the underlying storage mechanism. + + :param key: The unique key that identifies the resource + :param value: The value to store + """ + + @abstractmethod + def _contains(self, key: str) -> bool: + """Determine if the storage mechanism contains a given key. + + :param key: The unique key that identifies the resource + :returns: `True` if the key is defined, `False` otherwise + """ + + @property + def _allow_reserved_writes(self) -> bool: + """Return the boolean flag indicating if writing to reserved keys is + enabled for this feature store. + + :returns: `True` if enabled, `False` otherwise + """ + return self._enable_reserved_writes + + @_allow_reserved_writes.setter + def _allow_reserved_writes(self, value: bool) -> None: + """Modify the boolean flag indicating if writing to reserved keys is + enabled for this feature store. + + :param value: The new value to set for the flag + """ + self._enable_reserved_writes = value + + @property + def descriptor(self) -> str: + """Unique identifier enabling a client to connect to the feature store. 
+ + :returns: A descriptor encoded as a string + """ + return self._descriptor diff --git a/smartsim/_core/mli/infrastructure/worker/__init__.py b/smartsim/_core/mli/infrastructure/worker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/infrastructure/worker/torch_worker.py b/smartsim/_core/mli/infrastructure/worker/torch_worker.py new file mode 100644 index 0000000000..64e94e5eb6 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/worker/torch_worker.py @@ -0,0 +1,276 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import io + +import numpy as np +import torch + +# pylint: disable=import-error +from dragon.managed_memory import MemoryAlloc, MemoryPool + +from .....error import SmartSimError +from .....log import get_logger +from ...mli_schemas.tensor import tensor_capnp +from .worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, + TransformInputResult, + TransformOutputResult, +) + +# pylint: enable=import-error + + +torch.set_num_threads(1) +torch.set_num_interop_threads(4) +logger = get_logger(__name__) + + +class TorchWorker(MachineLearningWorkerBase): + """A worker that executes a PyTorch model.""" + + @staticmethod + def load_model( + batch: RequestBatch, fetch_result: FetchModelResult, device: str + ) -> LoadModelResult: + """Given a loaded MachineLearningModel, ensure it is loaded into + device memory. + + :param request: The request that triggered the pipeline + :param device: The device on which the model must be placed + :returns: LoadModelResult wrapping the model loaded for the request + :raises ValueError: If model reference object is not found + :raises RuntimeError: If loading and evaluating the model failed + """ + if fetch_result.model_bytes: + model_bytes = fetch_result.model_bytes + elif batch.raw_model and batch.raw_model.data: + model_bytes = batch.raw_model.data + else: + raise ValueError("Unable to load model without reference object") + + device_to_torch = {"cpu": "cpu", "gpu": "cuda"} + for old, new in device_to_torch.items(): + device = device.replace(old, new) + + buffer = io.BytesIO(initial_bytes=model_bytes) + try: + with torch.no_grad(): + model = torch.jit.load(buffer, map_location=device) # type: ignore + model.eval() + except Exception as e: + raise RuntimeError( + "Failed to load and evaluate the model: " + f"Model key {batch.model_id.key}, Device {device}" + ) from e + result = LoadModelResult(model) + return result + + @staticmethod + def transform_input( 
+        batch: RequestBatch, +        fetch_results: list[FetchInputResult], +        mem_pool: MemoryPool, +    ) -> TransformInputResult: +        """Given a collection of data, perform a transformation on the data and put +        the raw tensor data on a MemoryPool allocation. + +        :param batch: The batch of requests that triggered the pipeline +        :param fetch_results: Raw outputs from fetching inputs out of a feature store +        :param mem_pool: The memory pool used to access batched input tensors +        :returns: The transformed inputs wrapped in a TransformInputResult +        :raises ValueError: If tensors cannot be reconstructed +        :raises IndexError: If index out of range +        """ +        results: list[torch.Tensor] = [] +        total_samples = 0 +        slices: list[slice] = [] + +        all_dims: list[list[int]] = [] +        all_dtypes: list[str] = [] +        if fetch_results[0].meta is None: +            raise ValueError("Cannot reconstruct tensor without meta information") +        # Traverse inputs to get total number of samples and compute slices +        # Assumption: first dimension is samples, all tensors in the same input +        # have same number of samples +        # thus we only look at the first tensor for each input +        for res_idx, fetch_result in enumerate(fetch_results): +            if fetch_result.meta is None or any( +                item_meta is None for item_meta in fetch_result.meta +            ): +                raise ValueError("Cannot reconstruct tensor without meta information") +            first_tensor_desc: tensor_capnp.TensorDescriptor = fetch_result.meta[0] +            num_samples = first_tensor_desc.dimensions[0] +            slices.append(slice(total_samples, total_samples + num_samples)) +            total_samples = total_samples + num_samples + +            if res_idx == len(fetch_results) - 1: +                # For each tensor in the last input, get remaining dimensions +                # Assumptions: all inputs have the same number of tensors and +                # last N-1 dimensions match across inputs for corresponding tensors +                # thus: resulting array will be of size (num_samples, all_other_dims) +                for item_meta in fetch_result.meta: +                    tensor_desc: tensor_capnp.TensorDescriptor = item_meta + 
tensor_dims = list(tensor_desc.dimensions) + all_dims.append([total_samples, *tensor_dims[1:]]) + all_dtypes.append(str(tensor_desc.dataType)) + + for result_tensor_idx, (dims, dtype) in enumerate(zip(all_dims, all_dtypes)): + itemsize = np.empty((1), dtype=dtype).itemsize + alloc_size = int(np.prod(dims) * itemsize) + mem_alloc = mem_pool.alloc(alloc_size) + mem_view = mem_alloc.get_memview() + try: + mem_view[:alloc_size] = b"".join( + [ + fetch_result.inputs[result_tensor_idx] + for fetch_result in fetch_results + ] + ) + except IndexError as e: + raise IndexError( + "Error accessing elements in fetch_result.inputs " + f"with index {result_tensor_idx}" + ) from e + + results.append(mem_alloc.serialize()) + + return TransformInputResult(results, slices, all_dims, all_dtypes) + + # pylint: disable-next=unused-argument + @staticmethod + def execute( + batch: RequestBatch, + load_result: LoadModelResult, + transform_result: TransformInputResult, + device: str, + ) -> ExecuteResult: + """Execute an ML model on inputs transformed for use by the model. 
+ + :param batch: The batch of requests that triggered the pipeline + :param load_result: The result of loading the model onto device memory + :param transform_result: The result of transforming inputs for model consumption + :param device: The device on which the model will be executed + :returns: The result of inference wrapped in an ExecuteResult + :raises SmartSimError: If model is not loaded + :raises IndexError: If memory slicing is out of range + :raises ValueError: If tensor creation fails or is unable to evaluate the model + """ + if not load_result.model: + raise SmartSimError("Model must be loaded to execute") + device_to_torch = {"cpu": "cpu", "gpu": "cuda"} + for old, new in device_to_torch.items(): + device = device.replace(old, new) + + tensors = [] + mem_allocs = [] + for transformed, dims, dtype in zip( + transform_result.transformed, transform_result.dims, transform_result.dtypes + ): + mem_alloc = MemoryAlloc.attach(transformed) + mem_allocs.append(mem_alloc) + itemsize = np.empty((1), dtype=dtype).itemsize + try: + tensors.append( + torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[0 : np.prod(dims) * itemsize], + dtype=dtype, + ).reshape(dims) + ) + ) + except IndexError as e: + raise IndexError("Error during memory slicing") from e + except Exception as e: + raise ValueError("Error during tensor creation") from e + + model: torch.nn.Module = load_result.model + try: + with torch.no_grad(): + model.eval() + results = [ + model( + *[ + tensor.to(device, non_blocking=True).detach() + for tensor in tensors + ] + ) + ] + except Exception as e: + raise ValueError( + f"Error while evaluating the model: Model {batch.model_id.key}" + ) from e + + transform_result.transformed = [] + + execute_result = ExecuteResult(results, transform_result.slices) + for mem_alloc in mem_allocs: + mem_alloc.free() + return execute_result + + @staticmethod + def transform_output( + batch: RequestBatch, + execute_result: ExecuteResult, + ) -> 
list[TransformOutputResult]: + """Given inference results, perform transformations required to + transmit results to the requestor. + + :param batch: The batch of requests that triggered the pipeline + :param execute_result: The result of inference wrapped in an ExecuteResult + :returns: A list of transformed outputs + :raises IndexError: If indexing is out of range + :raises ValueError: If transforming output fails + """ + transformed_list: list[TransformOutputResult] = [] + cpu_predictions = [ + prediction.cpu() for prediction in execute_result.predictions + ] + for result_slice in execute_result.slices: + transformed = [] + for cpu_item in cpu_predictions: + try: + transformed.append(cpu_item[result_slice].numpy().tobytes()) + + # todo: need the shape from latest schemas added here. + transformed_list.append( + TransformOutputResult(transformed, None, "c", "float32") + ) # fixme + except IndexError as e: + raise IndexError( + f"Error accessing elements: result_slice {result_slice}" + ) from e + except Exception as e: + raise ValueError("Error transforming output") from e + + execute_result.predictions = [] + + return transformed_list diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py new file mode 100644 index 0000000000..9556b8e438 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -0,0 +1,646 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# pylint: disable=import-error +from dragon.managed_memory import MemoryPool + +# isort: off +# isort: on + +import typing as t +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from .....error import SmartSimError +from .....log import get_logger +from ...comm.channel.channel import CommChannelBase +from ...message_handler import MessageHandler +from ...mli_schemas.model.model_capnp import Model +from ..storage.feature_store import FeatureStore, ModelKey, TensorKey + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + from smartsim._core.mli.mli_schemas.tensor.tensor_capnp import TensorDescriptor + +logger = get_logger(__name__) + +# Placeholder +ModelIdentifier = ModelKey + + +class InferenceRequest: + """Internal representation of an inference request from a client.""" + + def __init__( + self, + model_key: t.Optional[ModelKey] = None, + callback: t.Optional[CommChannelBase] = None, + raw_inputs: 
t.Optional[t.List[bytes]] = None, + input_keys: t.Optional[t.List[TensorKey]] = None, + input_meta: t.Optional[t.List[t.Any]] = None, + output_keys: t.Optional[t.List[TensorKey]] = None, + raw_model: t.Optional[Model] = None, + batch_size: int = 0, + ): + """Initialize the InferenceRequest. + + :param model_key: A tuple containing a (key, descriptor) pair + :param callback: The channel used for notification of inference completion + :param raw_inputs: Raw bytes of tensor inputs + :param input_keys: A list of tuples containing a (key, descriptor) pair + :param input_meta: Metadata about the input data + :param output_keys: A list of tuples containing a (key, descriptor) pair + :param raw_model: Raw bytes of an ML model + :param batch_size: The batch size to apply when batching + """ + self.model_key = model_key + """A tuple containing a (key, descriptor) pair""" + self.raw_model = raw_model + """Raw bytes of an ML model""" + self.callback = callback + """The channel used for notification of inference completion""" + self.raw_inputs = raw_inputs or [] + """Raw bytes of tensor inputs""" + self.input_keys = input_keys or [] + """A list of tuples containing a (key, descriptor) pair""" + self.input_meta = input_meta or [] + """Metadata about the input data""" + self.output_keys = output_keys or [] + """A list of tuples containing a (key, descriptor) pair""" + self.batch_size = batch_size + """The batch size to apply when batching""" + + @property + def has_raw_model(self) -> bool: + """Check if the InferenceRequest contains a raw_model. + + :returns: True if raw_model is not None, False otherwise + """ + return self.raw_model is not None + + @property + def has_model_key(self) -> bool: + """Check if the InferenceRequest contains a model_key. + + :returns: True if model_key is not None, False otherwise + """ + return self.model_key is not None + + @property + def has_raw_inputs(self) -> bool: + """Check if the InferenceRequest contains raw_inputs. 
+ +        :returns: True if raw_inputs is not None and is not an empty list, +            False otherwise +        """ +        return self.raw_inputs is not None and bool(self.raw_inputs) + +    @property +    def has_input_keys(self) -> bool: +        """Check if the InferenceRequest contains input_keys. + +        :returns: True if input_keys is not None and is not an empty list, +            False otherwise +        """ +        return self.input_keys is not None and bool(self.input_keys) + +    @property +    def has_output_keys(self) -> bool: +        """Check if the InferenceRequest contains output_keys. + +        :returns: True if output_keys is not None and is not an empty list, +            False otherwise +        """ +        return self.output_keys is not None and bool(self.output_keys) + +    @property +    def has_input_meta(self) -> bool: +        """Check if the InferenceRequest contains input_meta. + +        :returns: True if input_meta is not None and is not an empty list, +            False otherwise +        """ +        return self.input_meta is not None and bool(self.input_meta) + + +class InferenceReply: +    """Internal representation of the reply to a client request for inference.""" + +    def __init__( +        self, +        outputs: t.Optional[t.Collection[t.Any]] = None, +        output_keys: t.Optional[t.Collection[TensorKey]] = None, +        status_enum: "Status" = "running", +        message: str = "In progress", +    ) -> None: +        """Initialize the InferenceReply. + +        :param outputs: List of output data +        :param output_keys: List of keys used for output data +        :param status_enum: Status of the reply +        :param message: Status message that corresponds with the status enum +        """ +        self.outputs: t.Collection[t.Any] = outputs or [] +        """List of output data""" +        self.output_keys: t.Collection[t.Optional[TensorKey]] = output_keys or [] +        """List of keys used for output data""" +        self.status_enum = status_enum +        """Status of the reply""" +        self.message = message +        """Status message that corresponds with the status enum""" + +    @property +    def has_outputs(self) -> bool: +        """Check if the InferenceReply contains outputs. 
+ + :returns: True if outputs is not None and is not an empty list, + False otherwise + """ + return self.outputs is not None and bool(self.outputs) + + @property + def has_output_keys(self) -> bool: + """Check if the InferenceReply contains output_keys. + + :returns: True if output_keys is not None and is not an empty list, + False otherwise + """ + return self.output_keys is not None and bool(self.output_keys) + + +class LoadModelResult: + """A wrapper around a loaded model.""" + + def __init__(self, model: t.Any) -> None: + """Initialize the LoadModelResult. + + :param model: The loaded model + """ + self.model = model + """The loaded model (e.g. a TensorFlow, PyTorch, ONNX, etc. model)""" + + +class TransformInputResult: + """A wrapper around a transformed batch of input tensors""" + + def __init__( + self, + result: t.Any, + slices: list[slice], + dims: list[list[int]], + dtypes: list[str], + ) -> None: + """Initialize the TransformInputResult. + + :param result: List of Dragon MemoryAlloc objects on which + the tensors are stored + :param slices: The slices that represent which portion of the + input tensors belongs to which request + :param dims: Dimension of the transformed tensors + :param dtypes: Data type of transformed tensors + """ + self.transformed = result + """List of Dragon MemoryAlloc objects on which the tensors are stored""" + self.slices = slices + """Each slice represents which portion of the input tensors belongs to + which request""" + self.dims = dims + """Dimension of the transformed tensors""" + self.dtypes = dtypes + """Data type of transformed tensors""" + + +class ExecuteResult: + """A wrapper around inference results.""" + + def __init__(self, result: t.Any, slices: list[slice]) -> None: + """Initialize the ExecuteResult. 
+ + :param result: Result of the execution + :param slices: The slices that represent which portion of the input + tensors belongs to which request + """ + self.predictions = result + """Result of the execution""" + self.slices = slices + """The slices that represent which portion of the input + tensors belongs to which request""" + + +class FetchInputResult: + """A wrapper around fetched inputs.""" + + def __init__(self, result: t.List[bytes], meta: t.Optional[t.List[t.Any]]) -> None: + """Initialize the FetchInputResult. + + :param result: List of input tensor bytes + :param meta: List of metadata that corresponds with the inputs + """ + self.inputs = result + """List of input tensor bytes""" + self.meta = meta + """List of metadata that corresponds with the inputs""" + + +class TransformOutputResult: + """A wrapper around inference results transformed for transmission.""" + + def __init__( + self, result: t.Any, shape: t.Optional[t.List[int]], order: str, dtype: str + ) -> None: + """Initialize the TransformOutputResult. + + :param result: Transformed output results + :param shape: Shape of output results + :param order: Order of output results + :param dtype: Datatype of output results + """ + self.outputs = result + """Transformed output results""" + self.shape = shape + """Shape of output results""" + self.order = order + """Order of output results""" + self.dtype = dtype + """Datatype of output results""" + + +class CreateInputBatchResult: + """A wrapper around inputs batched into a single request.""" + + def __init__(self, result: t.Any) -> None: + """Initialize the CreateInputBatchResult. + + :param result: Inputs batched into a single request + """ + self.batch = result + """Inputs batched into a single request""" + + +class FetchModelResult: + """A wrapper around raw fetched models.""" + + def __init__(self, result: bytes) -> None: + """Initialize the FetchModelResult. 
+ + :param result: The raw fetched model + """ + self.model_bytes: bytes = result + """The raw fetched model""" + + +@dataclass +class RequestBatch: + """A batch of aggregated inference requests.""" + + requests: list[InferenceRequest] + """List of InferenceRequests in the batch""" + inputs: t.Optional[TransformInputResult] + """Transformed batch of input tensors""" + model_id: "ModelIdentifier" + """Model (key, descriptor) tuple""" + + @property + def has_valid_requests(self) -> bool: + """Returns whether the batch contains at least one request. + + :returns: True if at least one request is available + """ + return len(self.requests) > 0 + + @property + def has_raw_model(self) -> bool: + """Returns whether the batch has a raw model. + + :returns: True if the batch has a raw model + """ + return self.raw_model is not None + + @property + def raw_model(self) -> t.Optional[t.Any]: + """Returns the raw model to use to execute for this batch + if it is available. + + :returns: A model if available, otherwise None""" + if self.has_valid_requests: + return self.requests[0].raw_model + return None + + @property + def input_keys(self) -> t.List[TensorKey]: + """All input keys available in this batch's requests. + + :returns: All input keys belonging to requests in this batch""" + keys = [] + for request in self.requests: + keys.extend(request.input_keys) + + return keys + + @property + def output_keys(self) -> t.List[TensorKey]: + """All output keys available in this batch's requests. 
+ + :returns: All output keys belonging to requests in this batch""" + keys = [] + for request in self.requests: + keys.extend(request.output_keys) + + return keys + + +class MachineLearningWorkerCore: + """Basic functionality of ML worker that is shared across all worker types.""" + + @staticmethod + def deserialize_message( + data_blob: bytes, + callback_factory: t.Callable[[str], CommChannelBase], + ) -> InferenceRequest: + """Deserialize a message from a byte stream into an InferenceRequest. + + :param data_blob: The byte stream to deserialize + :param callback_factory: A factory method that can create an instance + of the desired concrete comm channel type + :returns: The raw input message deserialized into an InferenceRequest + """ + request = MessageHandler.deserialize_request(data_blob) + model_key: t.Optional[ModelKey] = None + model_bytes: t.Optional[Model] = None + + if request.model.which() == "key": + model_key = ModelKey( + key=request.model.key.key, + descriptor=request.model.key.descriptor, + ) + elif request.model.which() == "data": + model_bytes = request.model.data + + callback_key = request.replyChannel.descriptor + comm_channel = callback_factory(callback_key) + input_keys: t.Optional[t.List[TensorKey]] = None + input_bytes: t.Optional[t.List[bytes]] = None + output_keys: t.Optional[t.List[TensorKey]] = None + input_meta: t.Optional[t.List[TensorDescriptor]] = None + + if request.input.which() == "keys": + input_keys = [ + TensorKey(key=value.key, descriptor=value.descriptor) + for value in request.input.keys + ] + elif request.input.which() == "descriptors": + input_meta = request.input.descriptors # type: ignore + + if request.output: + output_keys = [ + TensorKey(key=value.key, descriptor=value.descriptor) + for value in request.output + ] + + inference_request = InferenceRequest( + model_key=model_key, + callback=comm_channel, + raw_inputs=input_bytes, + input_meta=input_meta, + input_keys=input_keys, + output_keys=output_keys, + 
raw_model=model_bytes, + batch_size=0, + ) + return inference_request + + @staticmethod + def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]: + """Assemble the output information based on whether the output + information will be in the form of TensorKeys or TensorDescriptors. + + :param reply: The reply that the output belongs to + :returns: The list of prepared outputs, depending on the output + information needed in the reply + """ + prepared_outputs: t.List[t.Any] = [] + if reply.has_output_keys: + for value in reply.output_keys: + if not value: + continue + msg_key = MessageHandler.build_tensor_key(value.key, value.descriptor) + prepared_outputs.append(msg_key) + elif reply.has_outputs: + for _ in reply.outputs: + msg_tensor_desc = MessageHandler.build_tensor_descriptor( + "c", + "float32", + [1], + ) + prepared_outputs.append(msg_tensor_desc) + return prepared_outputs + + @staticmethod + def fetch_model( + batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] + ) -> FetchModelResult: + """Given a resource key, retrieve the raw model from a feature store. 
+ + :param batch: The batch of requests that triggered the pipeline + :param feature_stores: Available feature stores used for persistence + :returns: Raw bytes of the model + :raises SmartSimError: If neither a key or a model are provided or the + model cannot be retrieved from the feature store + :raises ValueError: If a feature store is not available and a raw + model is not provided + """ + # All requests in the same batch share the model + if batch.raw_model: + return FetchModelResult(batch.raw_model.data) + + if not feature_stores: + raise ValueError("Feature store is required for model retrieval") + + if batch.model_id is None: + raise SmartSimError( + "Key must be provided to retrieve model from feature store" + ) + + key, fsd = batch.model_id.key, batch.model_id.descriptor + + try: + feature_store = feature_stores[fsd] + raw_bytes: bytes = t.cast(bytes, feature_store[key]) + return FetchModelResult(raw_bytes) + except (FileNotFoundError, KeyError) as ex: + logger.exception(ex) + raise SmartSimError(f"Model could not be retrieved with key {key}") from ex + + @staticmethod + def fetch_inputs( + batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] + ) -> t.List[FetchInputResult]: + """Given a collection of ResourceKeys, identify the physical location + and input metadata. 
+ + :param batch: The batch of requests that triggered the pipeline + :param feature_stores: Available feature stores used for persistence + :returns: The fetched input + :raises ValueError: If neither an input key or an input tensor are provided + :raises SmartSimError: If a tensor for a given key cannot be retrieved + """ + fetch_results = [] + for request in batch.requests: + if request.raw_inputs: + fetch_results.append( + FetchInputResult(request.raw_inputs, request.input_meta) + ) + continue + + if not feature_stores: + raise ValueError("No input and no feature store provided") + + if request.has_input_keys: + data: t.List[bytes] = [] + + for fs_key in request.input_keys: + try: + feature_store = feature_stores[fs_key.descriptor] + tensor_bytes = t.cast(bytes, feature_store[fs_key.key]) + data.append(tensor_bytes) + except KeyError as ex: + logger.exception(ex) + raise SmartSimError( + f"Tensor could not be retrieved with key {fs_key.key}" + ) from ex + fetch_results.append( + FetchInputResult(data, meta=None) + ) # fixme: need to get both tensor and descriptor + continue + + raise ValueError("No input source") + + return fetch_results + + @staticmethod + def place_output( + request: InferenceRequest, + transform_result: TransformOutputResult, + feature_stores: t.Dict[str, FeatureStore], + ) -> t.Collection[t.Optional[TensorKey]]: + """Given a collection of data, make it available as a shared resource in the + feature store. 
+ + :param request: The request that triggered the pipeline + :param transform_result: Transformed version of the inference result + :param feature_stores: Available feature stores used for persistence + :returns: A collection of keys that were placed in the feature store + :raises ValueError: If a feature store is not provided + """ + if not feature_stores: + raise ValueError("Feature store is required for output persistence") + + keys: t.List[t.Optional[TensorKey]] = [] + # need to decide how to get back to original sub-batch inputs so they can be + # accurately placed, datum might need to include this. + + # Consider parallelizing all PUT feature_store operations + for fs_key, v in zip(request.output_keys, transform_result.outputs): + feature_store = feature_stores[fs_key.descriptor] + feature_store[fs_key.key] = v + keys.append(fs_key) + + return keys + + +class MachineLearningWorkerBase(MachineLearningWorkerCore, ABC): + """Abstract base class providing contract for a machine learning + worker implementation.""" + + @staticmethod + @abstractmethod + def load_model( + batch: RequestBatch, fetch_result: FetchModelResult, device: str + ) -> LoadModelResult: + """Given the raw bytes of an ML model that were fetched, ensure + it is loaded into device memory. + + :param request: The request that triggered the pipeline + :param fetch_result: The result of a fetch-model operation; contains + the raw bytes of the ML model. 
+ :param device: The device on which the model must be placed + :returns: LoadModelResult wrapping the model loaded for the request + :raises ValueError: If model reference object is not found + :raises RuntimeError: If loading and evaluating the model failed + """ + + @staticmethod + @abstractmethod + def transform_input( + batch: RequestBatch, + fetch_results: list[FetchInputResult], + mem_pool: MemoryPool, + ) -> TransformInputResult: + """Given a collection of data, perform a transformation on the data and put + the raw tensor data on a MemoryPool allocation. + + :param batch: The request that triggered the pipeline + :param fetch_result: Raw outputs from fetching inputs out of a feature store + :param mem_pool: The memory pool used to access batched input tensors + :returns: The transformed inputs wrapped in a TransformInputResult + :raises ValueError: If tensors cannot be reconstructed + :raises IndexError: If index out of range + """ + + @staticmethod + @abstractmethod + def execute( + batch: RequestBatch, + load_result: LoadModelResult, + transform_result: TransformInputResult, + device: str, + ) -> ExecuteResult: + """Execute an ML model on inputs transformed for use by the model. 
+ + :param batch: The batch of requests that triggered the pipeline + :param load_result: The result of loading the model onto device memory + :param transform_result: The result of transforming inputs for model consumption + :param device: The device on which the model will be executed + :returns: The result of inference wrapped in an ExecuteResult + :raises SmartSimError: If model is not loaded + :raises IndexError: If memory slicing is out of range + :raises ValueError: If tensor creation fails or is unable to evaluate the model + """ + + @staticmethod + @abstractmethod + def transform_output( + batch: RequestBatch, execute_result: ExecuteResult + ) -> t.List[TransformOutputResult]: + """Given inference results, perform transformations required to + transmit results to the requestor. + + :param batch: The batch of requests that triggered the pipeline + :param execute_result: The result of inference wrapped in an ExecuteResult + :returns: A list of transformed outputs + :raises IndexError: If indexing is out of range + :raises ValueError: If transforming output fails + """ diff --git a/smartsim/_core/mli/message_handler.py b/smartsim/_core/mli/message_handler.py new file mode 100644 index 0000000000..e3d46a7ab3 --- /dev/null +++ b/smartsim/_core/mli/message_handler.py @@ -0,0 +1,602 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import typing as t + +from .mli_schemas.data import data_references_capnp +from .mli_schemas.model import model_capnp +from .mli_schemas.request import request_capnp +from .mli_schemas.request.request_attributes import request_attributes_capnp +from .mli_schemas.response import response_capnp +from .mli_schemas.response.response_attributes import response_attributes_capnp +from .mli_schemas.tensor import tensor_capnp + + +class MessageHandler: + """Utility methods for transforming capnproto messages to and from + internal representations. + """ + + @staticmethod + def build_tensor_descriptor( + order: "tensor_capnp.Order", + data_type: "tensor_capnp.NumericalType", + dimensions: t.List[int], + ) -> tensor_capnp.TensorDescriptor: + """ + Builds a TensorDescriptor message using the provided + order, data type, and dimensions. 
+ + :param order: Order of the tensor, such as row-major (c) or column-major (f) + :param data_type: Data type of the tensor + :param dimensions: Dimensions of the tensor + :returns: The TensorDescriptor + :raises ValueError: If building fails + """ + try: + description = tensor_capnp.TensorDescriptor.new_message() + description.order = order + description.dataType = data_type + description.dimensions = dimensions + except Exception as e: + raise ValueError("Error building tensor descriptor.") from e + + return description + + @staticmethod + def build_output_tensor_descriptor( + order: "tensor_capnp.Order", + keys: t.List["data_references_capnp.TensorKey"], + data_type: "tensor_capnp.ReturnNumericalType", + dimensions: t.List[int], + ) -> tensor_capnp.OutputDescriptor: + """ + Builds an OutputDescriptor message using the provided + order, data type, and dimensions. + + :param order: Order of the tensor, such as row-major (c) or column-major (f) + :param keys: List of TensorKey to apply transform descriptor to + :param data_type: Transform data type of the tensor + :param dimensions: Transform dimensions of the tensor + :returns: The OutputDescriptor + :raises ValueError: If building fails + """ + try: + description = tensor_capnp.OutputDescriptor.new_message() + description.order = order + description.optionalKeys = keys + description.optionalDatatype = data_type + description.optionalDimension = dimensions + + except Exception as e: + raise ValueError("Error building output tensor descriptor.") from e + + return description + + @staticmethod + def build_tensor_key(key: str, descriptor: str) -> data_references_capnp.TensorKey: + """ + Builds a new TensorKey message with the provided key. 
+ + :param key: String to set the TensorKey + :param descriptor: A descriptor identifying the feature store + containing the key + :returns: The TensorKey + :raises ValueError: If building fails + """ + try: + tensor_key = data_references_capnp.TensorKey.new_message() + tensor_key.key = key + tensor_key.descriptor = descriptor + except Exception as e: + raise ValueError("Error building tensor key.") from e + return tensor_key + + @staticmethod + def build_model(data: bytes, name: str, version: str) -> model_capnp.Model: + """ + Builds a new Model message with the provided data, name, and version. + + :param data: Model data + :param name: Model name + :param version: Model version + :returns: The Model + :raises ValueError: If building fails + """ + try: + model = model_capnp.Model.new_message() + model.data = data + model.name = name + model.version = version + except Exception as e: + raise ValueError("Error building model.") from e + return model + + @staticmethod + def build_model_key(key: str, descriptor: str) -> data_references_capnp.ModelKey: + """ + Builds a new ModelKey message with the provided key. + + :param key: String to set the ModelKey + :param descriptor: A descriptor identifying the feature store + containing the key + :returns: The ModelKey + :raises ValueError: If building fails + """ + try: + model_key = data_references_capnp.ModelKey.new_message() + model_key.key = key + model_key.descriptor = descriptor + except Exception as e: + raise ValueError("Error building model key.") from e + return model_key + + @staticmethod + def build_torch_request_attributes( + tensor_type: "request_attributes_capnp.TorchTensorType", + ) -> request_attributes_capnp.TorchRequestAttributes: + """ + Builds a new TorchRequestAttributes message with the provided tensor type. 
+ + :param tensor_type: Type of the tensor passed in + :returns: The TorchRequestAttributes + :raises ValueError: If building fails + """ + try: + attributes = request_attributes_capnp.TorchRequestAttributes.new_message() + attributes.tensorType = tensor_type + except Exception as e: + raise ValueError("Error building Torch request attributes.") from e + return attributes + + @staticmethod + def build_tf_request_attributes( + name: str, tensor_type: "request_attributes_capnp.TFTensorType" + ) -> request_attributes_capnp.TensorFlowRequestAttributes: + """ + Builds a new TensorFlowRequestAttributes message with + the provided name and tensor type. + + :param name: Name of the tensor + :param tensor_type: Type of the tensor passed in + :returns: The TensorFlowRequestAttributes + :raises ValueError: If building fails + """ + try: + attributes = ( + request_attributes_capnp.TensorFlowRequestAttributes.new_message() + ) + attributes.name = name + attributes.tensorType = tensor_type + except Exception as e: + raise ValueError("Error building TensorFlow request attributes.") from e + return attributes + + @staticmethod + def build_torch_response_attributes() -> ( + response_attributes_capnp.TorchResponseAttributes + ): + """ + Builds a new TorchResponseAttributes message. + + :returns: The TorchResponseAttributes + """ + return response_attributes_capnp.TorchResponseAttributes.new_message() + + @staticmethod + def build_tf_response_attributes() -> ( + response_attributes_capnp.TensorFlowResponseAttributes + ): + """ + Builds a new TensorFlowResponseAttributes message. + + :returns: The TensorFlowResponseAttributes + """ + return response_attributes_capnp.TensorFlowResponseAttributes.new_message() + + @staticmethod + def _assign_model( + request: request_capnp.Request, + model: t.Union[data_references_capnp.ModelKey, model_capnp.Model], + ) -> None: + """ + Assigns a model to the supplied request. 
+ + :param request: Request being built + :param model: Model to be assigned + :raises ValueError: If building fails + """ + try: + class_name = model.schema.node.displayName.split(":")[-1] # type: ignore + if class_name == "Model": + request.model.data = model # type: ignore + elif class_name == "ModelKey": + request.model.key = model # type: ignore + else: + raise ValueError("""Invalid model class name. + Expected 'Model' or 'ModelKey'.""") + except Exception as e: + raise ValueError("Error building model portion of request.") from e + + @staticmethod + def _assign_reply_channel( + request: request_capnp.Request, reply_channel: str + ) -> None: + """ + Assigns a reply channel to the supplied request. + + :param request: Request being built + :param reply_channel: Reply channel to be assigned + :raises ValueError: If building fails + """ + try: + request.replyChannel.descriptor = reply_channel + except Exception as e: + raise ValueError("Error building reply channel portion of request.") from e + + @staticmethod + def _assign_inputs( + request: request_capnp.Request, + inputs: t.Union[ + t.List[data_references_capnp.TensorKey], + t.List[tensor_capnp.TensorDescriptor], + ], + ) -> None: + """ + Assigns inputs to the supplied request. + + :param request: Request being built + :param inputs: Inputs to be assigned + :raises ValueError: If building fails + """ + try: + if inputs: + display_name = inputs[0].schema.node.displayName # type: ignore + input_class_name = display_name.split(":")[-1] + if input_class_name == "TensorDescriptor": + request.input.descriptors = inputs # type: ignore + elif input_class_name == "TensorKey": + request.input.keys = inputs # type: ignore + else: + raise ValueError("""Invalid input class name. 
Expected + 'TensorDescriptor' or 'TensorKey'.""") + except Exception as e: + raise ValueError("Error building inputs portion of request.") from e + + @staticmethod + def _assign_outputs( + request: request_capnp.Request, + outputs: t.List[data_references_capnp.TensorKey], + ) -> None: + """ + Assigns outputs to the supplied request. + + :param request: Request being built + :param outputs: Outputs to be assigned + :raises ValueError: If building fails + """ + try: + request.output = outputs + + except Exception as e: + raise ValueError("Error building outputs portion of request.") from e + + @staticmethod + def _assign_output_descriptors( + request: request_capnp.Request, + output_descriptors: t.List[tensor_capnp.OutputDescriptor], + ) -> None: + """ + Assigns a list of output tensor descriptors to the supplied request. + + :param request: Request being built + :param output_descriptors: Output descriptors to be assigned + :raises ValueError: If building fails + """ + try: + request.outputDescriptors = output_descriptors + except Exception as e: + raise ValueError( + "Error building the output descriptors portion of request." + ) from e + + @staticmethod + def _assign_custom_request_attributes( + request: request_capnp.Request, + custom_attrs: t.Union[ + request_attributes_capnp.TorchRequestAttributes, + request_attributes_capnp.TensorFlowRequestAttributes, + None, + ], + ) -> None: + """ + Assigns request attributes to the supplied request. 
+ + :param request: Request being built + :param custom_attrs: Custom attributes to be assigned + :raises ValueError: If building fails + """ + try: + if custom_attrs is None: + request.customAttributes.none = custom_attrs + else: + custom_attribute_class_name = ( + custom_attrs.schema.node.displayName.split(":")[-1] # type: ignore + ) + if custom_attribute_class_name == "TorchRequestAttributes": + request.customAttributes.torch = custom_attrs # type: ignore + elif custom_attribute_class_name == "TensorFlowRequestAttributes": + request.customAttributes.tf = custom_attrs # type: ignore + else: + raise ValueError("""Invalid custom attribute class name. + Expected 'TensorFlowRequestAttributes' or + 'TorchRequestAttributes'.""") + except Exception as e: + raise ValueError( + "Error building custom attributes portion of request." + ) from e + + @staticmethod + def build_request( + reply_channel: str, + model: t.Union[data_references_capnp.ModelKey, model_capnp.Model], + inputs: t.Union[ + t.List[data_references_capnp.TensorKey], + t.List[tensor_capnp.TensorDescriptor], + ], + outputs: t.List[data_references_capnp.TensorKey], + output_descriptors: t.List[tensor_capnp.OutputDescriptor], + custom_attributes: t.Union[ + request_attributes_capnp.TorchRequestAttributes, + request_attributes_capnp.TensorFlowRequestAttributes, + None, + ], + ) -> request_capnp.RequestBuilder: + """ + Builds the request message. 
+ + :param reply_channel: Reply channel to be assigned to request + :param model: Model to be assigned to request + :param inputs: Inputs to be assigned to request + :param outputs: Outputs to be assigned to request + :param output_descriptors: Output descriptors to be assigned to request + :param custom_attributes: Custom attributes to be assigned to request + :returns: The Request + """ + request = request_capnp.Request.new_message() + MessageHandler._assign_reply_channel(request, reply_channel) + MessageHandler._assign_model(request, model) + MessageHandler._assign_inputs(request, inputs) + MessageHandler._assign_outputs(request, outputs) + MessageHandler._assign_output_descriptors(request, output_descriptors) + MessageHandler._assign_custom_request_attributes(request, custom_attributes) + return request + + @staticmethod + def serialize_request(request: request_capnp.RequestBuilder) -> bytes: + """ + Serializes a built request message. + + :param request: Request to be serialized + :returns: Serialized request bytes + :raises ValueError: If serialization fails + """ + display_name = request.schema.node.displayName # type: ignore + class_name = display_name.split(":")[-1] + if class_name != "Request": + raise ValueError( + "Error serializing the request. Value passed in is not " + f"a request: {class_name}" + ) + try: + return request.to_bytes() + except Exception as e: + raise ValueError("Error serializing the request") from e + + @staticmethod + def deserialize_request(request_bytes: bytes) -> request_capnp.Request: + """ + Deserializes a serialized request message. 
+ + :param request_bytes: Bytes to be deserialized into a request + :returns: Deserialized request + :raises ValueError: If deserialization fails + """ + try: + bytes_message = request_capnp.Request.from_bytes( + request_bytes, traversal_limit_in_words=2**63 + ) + + with bytes_message as message: + return message + except Exception as e: + raise ValueError("Error deserializing the request") from e + + @staticmethod + def _assign_status( + response: response_capnp.Response, status: "response_capnp.Status" + ) -> None: + """ + Assigns a status to the supplied response. + + :param response: Response being built + :param status: Status to be assigned + :raises ValueError: If building fails + """ + try: + response.status = status + except Exception as e: + raise ValueError("Error assigning status to response.") from e + + @staticmethod + def _assign_message(response: response_capnp.Response, message: str) -> None: + """ + Assigns a message to the supplied response. + + :param response: Response being built + :param message: Message to be assigned + :raises ValueError: If building fails + """ + try: + response.message = message + except Exception as e: + raise ValueError("Error assigning message to response.") from e + + @staticmethod + def _assign_result( + response: response_capnp.Response, + result: t.Union[ + t.List[tensor_capnp.TensorDescriptor], + t.List[data_references_capnp.TensorKey], + None, + ], + ) -> None: + """ + Assigns a result to the supplied response. 
+ + :param response: Response being built + :param result: Result to be assigned + :raises ValueError: If building fails + """ + try: + if result: + first_result = result[0] + display_name = first_result.schema.node.displayName # type: ignore + result_class_name = display_name.split(":")[-1] + if result_class_name == "TensorDescriptor": + response.result.descriptors = result # type: ignore + elif result_class_name == "TensorKey": + response.result.keys = result # type: ignore + else: + raise ValueError("""Invalid result class name. + Expected 'TensorDescriptor' or 'TensorKey'.""") + except Exception as e: + raise ValueError("Error assigning result to response.") from e + + @staticmethod + def _assign_custom_response_attributes( + response: response_capnp.Response, + custom_attrs: t.Union[ + response_attributes_capnp.TorchResponseAttributes, + response_attributes_capnp.TensorFlowResponseAttributes, + None, + ], + ) -> None: + """ + Assigns custom attributes to the supplied response. + + :param response: Response being built + :param custom_attrs: Custom attributes to be assigned + :raises ValueError: If building fails + """ + try: + if custom_attrs is None: + response.customAttributes.none = custom_attrs + else: + custom_attribute_class_name = ( + custom_attrs.schema.node.displayName.split(":")[-1] # type: ignore + ) + if custom_attribute_class_name == "TorchResponseAttributes": + response.customAttributes.torch = custom_attrs # type: ignore + elif custom_attribute_class_name == "TensorFlowResponseAttributes": + response.customAttributes.tf = custom_attrs # type: ignore + else: + raise ValueError("""Invalid custom attribute class name. 
+ Expected 'TensorFlowResponseAttributes' or + 'TorchResponseAttributes'.""") + except Exception as e: + raise ValueError("Error assigning custom attributes to response.") from e + + @staticmethod + def build_response( + status: "response_capnp.Status", + message: str, + result: t.Union[ + t.List[tensor_capnp.TensorDescriptor], + t.List[data_references_capnp.TensorKey], + None, + ], + custom_attributes: t.Union[ + response_attributes_capnp.TorchResponseAttributes, + response_attributes_capnp.TensorFlowResponseAttributes, + None, + ], + ) -> response_capnp.ResponseBuilder: + """ + Builds the response message. + + :param status: Status to be assigned to response + :param message: Message to be assigned to response + :param result: Result to be assigned to response + :param custom_attributes: Custom attributes to be assigned to response + :returns: The Response + """ + response = response_capnp.Response.new_message() + MessageHandler._assign_status(response, status) + MessageHandler._assign_message(response, message) + MessageHandler._assign_result(response, result) + MessageHandler._assign_custom_response_attributes(response, custom_attributes) + return response + + @staticmethod + def serialize_response(response: response_capnp.ResponseBuilder) -> bytes: + """ + Serializes a built response message. + + :param response: Response to be serialized + :returns: Serialized response bytes + :raises ValueError: If serialization fails + """ + display_name = response.schema.node.displayName # type: ignore + class_name = display_name.split(":")[-1] + if class_name != "Response": + raise ValueError( + "Error serializing the response. Value passed in is not " + f"a response: {class_name}" + ) + try: + return response.to_bytes() + except Exception as e: + raise ValueError("Error serializing the response") from e + + @staticmethod + def deserialize_response(response_bytes: bytes) -> response_capnp.Response: + """ + Deserializes a serialized response message. 
+ + :param response_bytes: Bytes to be deserialized into a response + :returns: Deserialized response + :raises ValueError: If deserialization fails + """ + try: + bytes_message = response_capnp.Response.from_bytes( + response_bytes, traversal_limit_in_words=2**63 + ) + + with bytes_message as message: + return message + + except Exception as e: + raise ValueError("Error deserializing the response") from e diff --git a/smartsim/_core/mli/mli_schemas/data/data_references.capnp b/smartsim/_core/mli/mli_schemas/data/data_references.capnp new file mode 100644 index 0000000000..65293be7b2 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/data/data_references.capnp @@ -0,0 +1,37 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@0x8ca69fd1aacb6668; + +struct ModelKey { + key @0 :Text; + descriptor @1 :Text; +} + +struct TensorKey { + key @0 :Text; + descriptor @1 :Text; +} diff --git a/smartsim/_core/mli/mli_schemas/data/data_references_capnp.py b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.py new file mode 100644 index 0000000000..099d10c438 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.py @@ -0,0 +1,41 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `data_references.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "data_references.capnp")) +ModelKey = capnp.load(module_file).ModelKey +ModelKeyBuilder = ModelKey +ModelKeyReader = ModelKey +TensorKey = capnp.load(module_file).TensorKey +TensorKeyBuilder = TensorKey +TensorKeyReader = TensorKey diff --git a/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi new file mode 100644 index 0000000000..a5e318a556 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi @@ -0,0 +1,107 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+ +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `data_references.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator + +class ModelKey: + key: str + descriptor: str + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[ModelKeyReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> ModelKeyReader: ... + @staticmethod + def new_message() -> ModelKeyBuilder: ... + def to_dict(self) -> dict: ... + +class ModelKeyReader(ModelKey): + def as_builder(self) -> ModelKeyBuilder: ... + +class ModelKeyBuilder(ModelKey): + @staticmethod + def from_dict(dictionary: dict) -> ModelKeyBuilder: ... + def copy(self) -> ModelKeyBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> ModelKeyReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... 
+ @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + +class TensorKey: + key: str + descriptor: str + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TensorKeyReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TensorKeyReader: ... + @staticmethod + def new_message() -> TensorKeyBuilder: ... + def to_dict(self) -> dict: ... + +class TensorKeyReader(TensorKey): + def as_builder(self) -> TensorKeyBuilder: ... + +class TensorKeyBuilder(TensorKey): + @staticmethod + def from_dict(dictionary: dict) -> TensorKeyBuilder: ... + def copy(self) -> TensorKeyBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TensorKeyReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/mli/mli_schemas/model/__init__.py b/smartsim/_core/mli/mli_schemas/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/mli_schemas/model/model.capnp b/smartsim/_core/mli/mli_schemas/model/model.capnp new file mode 100644 index 0000000000..fc9ed73663 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/model/model.capnp @@ -0,0 +1,33 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@0xaefb9301e14ba4bd; + +struct Model { + data @0 :Data; + name @1 :Text; + version @2 :Text; +} diff --git a/smartsim/_core/mli/mli_schemas/model/model_capnp.py b/smartsim/_core/mli/mli_schemas/model/model_capnp.py new file mode 100644 index 0000000000..be2c276c23 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/model/model_capnp.py @@ -0,0 +1,38 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+ +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `model.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "model.capnp")) +Model = capnp.load(module_file).Model +ModelBuilder = Model +ModelReader = Model diff --git a/smartsim/_core/mli/mli_schemas/model/model_capnp.pyi b/smartsim/_core/mli/mli_schemas/model/model_capnp.pyi new file mode 100644 index 0000000000..6ca53a3579 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/model/model_capnp.pyi @@ -0,0 +1,72 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+ +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `model.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator + +class Model: + data: bytes + name: str + version: str + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[ModelReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> ModelReader: ... + @staticmethod + def new_message() -> ModelBuilder: ... + def to_dict(self) -> dict: ... + +class ModelReader(Model): + def as_builder(self) -> ModelBuilder: ... + +class ModelBuilder(Model): + @staticmethod + def from_dict(dictionary: dict) -> ModelBuilder: ... + def copy(self) -> ModelBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> ModelReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... 
+ @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/mli/mli_schemas/request/request.capnp b/smartsim/_core/mli/mli_schemas/request/request.capnp new file mode 100644 index 0000000000..26d9542d9f --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request.capnp @@ -0,0 +1,55 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +@0xa27f0152c7bb299e; + +using Tensors = import "../tensor/tensor.capnp"; +using RequestAttributes = import "request_attributes/request_attributes.capnp"; +using DataRef = import "../data/data_references.capnp"; +using Models = import "../model/model.capnp"; + +struct ChannelDescriptor { + descriptor @0 :Text; +} + +struct Request { + replyChannel @0 :ChannelDescriptor; + model :union { + key @1 :DataRef.ModelKey; + data @2 :Models.Model; + } + input :union { + keys @3 :List(DataRef.TensorKey); + descriptors @4 :List(Tensors.TensorDescriptor); + } + output @5 :List(DataRef.TensorKey); + outputDescriptors @6 :List(Tensors.OutputDescriptor); + customAttributes :union { + torch @7 :RequestAttributes.TorchRequestAttributes; + tf @8 :RequestAttributes.TensorFlowRequestAttributes; + none @9 :Void; + } +} diff --git a/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp new file mode 100644 index 0000000000..f0a319f0a3 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp @@ -0,0 +1,49 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+ +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@0xdd14d8ba5c06743f; + +enum TorchTensorType { + nested @0; # ragged + sparse @1; + tensor @2; # "normal" tensor +} + +enum TFTensorType { + ragged @0; + sparse @1; + variable @2; + constant @3; +} + +struct TorchRequestAttributes { + tensorType @0 :TorchTensorType; +} + +struct TensorFlowRequestAttributes { + name @0 :Text; + tensorType @1 :TFTensorType; +} diff --git a/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py new file mode 100644 index 0000000000..8969f38457 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py @@ -0,0 +1,41 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `request_attributes.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "request_attributes.capnp")) +TorchRequestAttributes = capnp.load(module_file).TorchRequestAttributes +TorchRequestAttributesBuilder = TorchRequestAttributes +TorchRequestAttributesReader = TorchRequestAttributes +TensorFlowRequestAttributes = capnp.load(module_file).TensorFlowRequestAttributes +TensorFlowRequestAttributesBuilder = TensorFlowRequestAttributes +TensorFlowRequestAttributesReader = TensorFlowRequestAttributes diff --git a/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi new file mode 100644 index 0000000000..c474de4b4f --- /dev/null +++ 
b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi @@ -0,0 +1,109 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +"""This is an automatically generated stub for `request_attributes.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator, Literal + +TorchTensorType = Literal["nested", "sparse", "tensor"] +TFTensorType = Literal["ragged", "sparse", "variable", "constant"] + +class TorchRequestAttributes: + tensorType: TorchTensorType + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TorchRequestAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TorchRequestAttributesReader: ... + @staticmethod + def new_message() -> TorchRequestAttributesBuilder: ... + def to_dict(self) -> dict: ... + +class TorchRequestAttributesReader(TorchRequestAttributes): + def as_builder(self) -> TorchRequestAttributesBuilder: ... + +class TorchRequestAttributesBuilder(TorchRequestAttributes): + @staticmethod + def from_dict(dictionary: dict) -> TorchRequestAttributesBuilder: ... + def copy(self) -> TorchRequestAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TorchRequestAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + +class TensorFlowRequestAttributes: + name: str + tensorType: TFTensorType + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TensorFlowRequestAttributesReader]: ... 
+ @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TensorFlowRequestAttributesReader: ... + @staticmethod + def new_message() -> TensorFlowRequestAttributesBuilder: ... + def to_dict(self) -> dict: ... + +class TensorFlowRequestAttributesReader(TensorFlowRequestAttributes): + def as_builder(self) -> TensorFlowRequestAttributesBuilder: ... + +class TensorFlowRequestAttributesBuilder(TensorFlowRequestAttributes): + @staticmethod + def from_dict(dictionary: dict) -> TensorFlowRequestAttributesBuilder: ... + def copy(self) -> TensorFlowRequestAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TensorFlowRequestAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/mli/mli_schemas/request/request_capnp.py b/smartsim/_core/mli/mli_schemas/request/request_capnp.py new file mode 100644 index 0000000000..90b8ce194e --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request_capnp.py @@ -0,0 +1,41 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+ +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `request.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "request.capnp")) +ChannelDescriptor = capnp.load(module_file).ChannelDescriptor +ChannelDescriptorBuilder = ChannelDescriptor +ChannelDescriptorReader = ChannelDescriptor +Request = capnp.load(module_file).Request +RequestBuilder = Request +RequestReader = Request diff --git a/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi b/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi new file mode 100644 index 0000000000..2aab80b1d0 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi @@ -0,0 +1,319 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `request.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator, Literal, Sequence, overload + +from ..data.data_references_capnp import ( + ModelKey, + ModelKeyBuilder, + ModelKeyReader, + TensorKey, + TensorKeyBuilder, + TensorKeyReader, +) +from ..model.model_capnp import Model, ModelBuilder, ModelReader +from ..tensor.tensor_capnp import ( + OutputDescriptor, + OutputDescriptorBuilder, + OutputDescriptorReader, + TensorDescriptor, + TensorDescriptorBuilder, + TensorDescriptorReader, +) +from .request_attributes.request_attributes_capnp import ( + TensorFlowRequestAttributes, + TensorFlowRequestAttributesBuilder, + TensorFlowRequestAttributesReader, + TorchRequestAttributes, + TorchRequestAttributesBuilder, + TorchRequestAttributesReader, +) + +class ChannelDescriptor: + descriptor: str + @staticmethod + 
@contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[ChannelDescriptorReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> ChannelDescriptorReader: ... + @staticmethod + def new_message() -> ChannelDescriptorBuilder: ... + def to_dict(self) -> dict: ... + +class ChannelDescriptorReader(ChannelDescriptor): + def as_builder(self) -> ChannelDescriptorBuilder: ... + +class ChannelDescriptorBuilder(ChannelDescriptor): + @staticmethod + def from_dict(dictionary: dict) -> ChannelDescriptorBuilder: ... + def copy(self) -> ChannelDescriptorBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> ChannelDescriptorReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + +class Request: + class Model: + key: ModelKey | ModelKeyBuilder | ModelKeyReader + data: Model | ModelBuilder | ModelReader + def which(self) -> Literal["key", "data"]: ... + @overload + def init(self, name: Literal["key"]) -> ModelKey: ... + @overload + def init(self, name: Literal["data"]) -> Model: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[Request.ModelReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Request.ModelReader: ... + @staticmethod + def new_message() -> Request.ModelBuilder: ... + def to_dict(self) -> dict: ... + + class ModelReader(Request.Model): + key: ModelKeyReader + data: ModelReader + def as_builder(self) -> Request.ModelBuilder: ... 
+ + class ModelBuilder(Request.Model): + key: ModelKey | ModelKeyBuilder | ModelKeyReader + data: Model | ModelBuilder | ModelReader + @staticmethod + def from_dict(dictionary: dict) -> Request.ModelBuilder: ... + def copy(self) -> Request.ModelBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> Request.ModelReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + + class Input: + keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] + def which(self) -> Literal["keys", "descriptors"]: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[Request.InputReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Request.InputReader: ... + @staticmethod + def new_message() -> Request.InputBuilder: ... + def to_dict(self) -> dict: ... + + class InputReader(Request.Input): + keys: Sequence[TensorKeyReader] + descriptors: Sequence[TensorDescriptorReader] + def as_builder(self) -> Request.InputBuilder: ... + + class InputBuilder(Request.Input): + keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] + @staticmethod + def from_dict(dictionary: dict) -> Request.InputBuilder: ... + def copy(self) -> Request.InputBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> Request.InputReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... 
+ @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + + class CustomAttributes: + torch: ( + TorchRequestAttributes + | TorchRequestAttributesBuilder + | TorchRequestAttributesReader + ) + tf: ( + TensorFlowRequestAttributes + | TensorFlowRequestAttributesBuilder + | TensorFlowRequestAttributesReader + ) + none: None + def which(self) -> Literal["torch", "tf", "none"]: ... + @overload + def init(self, name: Literal["torch"]) -> TorchRequestAttributes: ... + @overload + def init(self, name: Literal["tf"]) -> TensorFlowRequestAttributes: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[Request.CustomAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Request.CustomAttributesReader: ... + @staticmethod + def new_message() -> Request.CustomAttributesBuilder: ... + def to_dict(self) -> dict: ... + + class CustomAttributesReader(Request.CustomAttributes): + torch: TorchRequestAttributesReader + tf: TensorFlowRequestAttributesReader + def as_builder(self) -> Request.CustomAttributesBuilder: ... + + class CustomAttributesBuilder(Request.CustomAttributes): + torch: ( + TorchRequestAttributes + | TorchRequestAttributesBuilder + | TorchRequestAttributesReader + ) + tf: ( + TensorFlowRequestAttributes + | TensorFlowRequestAttributesBuilder + | TensorFlowRequestAttributesReader + ) + @staticmethod + def from_dict(dictionary: dict) -> Request.CustomAttributesBuilder: ... + def copy(self) -> Request.CustomAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> Request.CustomAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... 
+ @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + replyChannel: ChannelDescriptor | ChannelDescriptorBuilder | ChannelDescriptorReader + model: Request.Model | Request.ModelBuilder | Request.ModelReader + input: Request.Input | Request.InputBuilder | Request.InputReader + output: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + outputDescriptors: Sequence[ + OutputDescriptor | OutputDescriptorBuilder | OutputDescriptorReader + ] + customAttributes: ( + Request.CustomAttributes + | Request.CustomAttributesBuilder + | Request.CustomAttributesReader + ) + @overload + def init(self, name: Literal["replyChannel"]) -> ChannelDescriptor: ... + @overload + def init(self, name: Literal["model"]) -> Model: ... + @overload + def init(self, name: Literal["input"]) -> Input: ... + @overload + def init(self, name: Literal["customAttributes"]) -> CustomAttributes: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[RequestReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> RequestReader: ... + @staticmethod + def new_message() -> RequestBuilder: ... + def to_dict(self) -> dict: ... + +class RequestReader(Request): + replyChannel: ChannelDescriptorReader + model: Request.ModelReader + input: Request.InputReader + output: Sequence[TensorKeyReader] + outputDescriptors: Sequence[OutputDescriptorReader] + customAttributes: Request.CustomAttributesReader + def as_builder(self) -> RequestBuilder: ... 
+ +class RequestBuilder(Request): + replyChannel: ChannelDescriptor | ChannelDescriptorBuilder | ChannelDescriptorReader + model: Request.Model | Request.ModelBuilder | Request.ModelReader + input: Request.Input | Request.InputBuilder | Request.InputReader + output: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + outputDescriptors: Sequence[ + OutputDescriptor | OutputDescriptorBuilder | OutputDescriptorReader + ] + customAttributes: ( + Request.CustomAttributes + | Request.CustomAttributesBuilder + | Request.CustomAttributesReader + ) + @staticmethod + def from_dict(dictionary: dict) -> RequestBuilder: ... + def copy(self) -> RequestBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> RequestReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/mli/mli_schemas/response/response.capnp b/smartsim/_core/mli/mli_schemas/response/response.capnp new file mode 100644 index 0000000000..7194524cd0 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response.capnp @@ -0,0 +1,52 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+ +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@0xa05dcb4444780705; + +using Tensors = import "../tensor/tensor.capnp"; +using ResponseAttributes = import "response_attributes/response_attributes.capnp"; +using DataRef = import "../data/data_references.capnp"; + +enum Status { + complete @0; + fail @1; + timeout @2; + running @3; +} + +struct Response { + status @0 :Status; + message @1 :Text; + result :union { + keys @2 :List(DataRef.TensorKey); + descriptors @3 :List(Tensors.TensorDescriptor); + } + customAttributes :union { + torch @4 :ResponseAttributes.TorchResponseAttributes; + tf @5 :ResponseAttributes.TensorFlowResponseAttributes; + none @6 :Void; + } +} diff --git a/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp new file mode 100644 index 0000000000..b4dcf18e88 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp @@ -0,0 +1,33 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+ +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@0xee59c60fccbb1bf9; + +struct TorchResponseAttributes { +} + +struct TensorFlowResponseAttributes { +} diff --git a/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py new file mode 100644 index 0000000000..4839334d52 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py @@ -0,0 +1,41 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+ +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +"""This is an automatically generated stub for `response_attributes.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "response_attributes.capnp")) +TorchResponseAttributes = capnp.load(module_file).TorchResponseAttributes +TorchResponseAttributesBuilder = TorchResponseAttributes +TorchResponseAttributesReader = TorchResponseAttributes +TensorFlowResponseAttributes = capnp.load(module_file).TensorFlowResponseAttributes +TensorFlowResponseAttributesBuilder = TensorFlowResponseAttributes +TensorFlowResponseAttributesReader = TensorFlowResponseAttributes diff --git a/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi new file mode 100644 index 0000000000..f40688d74a --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi @@ -0,0 +1,103 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `response_attributes.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator + +class TorchResponseAttributes: + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TorchResponseAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TorchResponseAttributesReader: ... + @staticmethod + def new_message() -> TorchResponseAttributesBuilder: ... + def to_dict(self) -> dict: ... + +class TorchResponseAttributesReader(TorchResponseAttributes): + def as_builder(self) -> TorchResponseAttributesBuilder: ... + +class TorchResponseAttributesBuilder(TorchResponseAttributes): + @staticmethod + def from_dict(dictionary: dict) -> TorchResponseAttributesBuilder: ... + def copy(self) -> TorchResponseAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TorchResponseAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... 
+ +class TensorFlowResponseAttributes: + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TensorFlowResponseAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TensorFlowResponseAttributesReader: ... + @staticmethod + def new_message() -> TensorFlowResponseAttributesBuilder: ... + def to_dict(self) -> dict: ... + +class TensorFlowResponseAttributesReader(TensorFlowResponseAttributes): + def as_builder(self) -> TensorFlowResponseAttributesBuilder: ... + +class TensorFlowResponseAttributesBuilder(TensorFlowResponseAttributes): + @staticmethod + def from_dict(dictionary: dict) -> TensorFlowResponseAttributesBuilder: ... + def copy(self) -> TensorFlowResponseAttributesBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TensorFlowResponseAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/mli/mli_schemas/response/response_capnp.py b/smartsim/_core/mli/mli_schemas/response/response_capnp.py new file mode 100644 index 0000000000..eaa3451045 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response_capnp.py @@ -0,0 +1,38 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `response.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "response.capnp")) +Response = capnp.load(module_file).Response +ResponseBuilder = Response +ResponseReader = Response diff --git a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi new file mode 100644 index 0000000000..6b4c50fd05 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi @@ -0,0 +1,212 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +"""This is an automatically generated stub for `response.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator, Literal, Sequence, overload + +from ..data.data_references_capnp import TensorKey, TensorKeyBuilder, TensorKeyReader +from ..tensor.tensor_capnp import ( + TensorDescriptor, + TensorDescriptorBuilder, + TensorDescriptorReader, +) +from .response_attributes.response_attributes_capnp import ( + TensorFlowResponseAttributes, + TensorFlowResponseAttributesBuilder, + TensorFlowResponseAttributesReader, + TorchResponseAttributes, + TorchResponseAttributesBuilder, + TorchResponseAttributesReader, +) + +Status = Literal["complete", "fail", "timeout", "running"] + +class Response: + class Result: + keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] + def which(self) -> Literal["keys", "descriptors"]: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[Response.ResultReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Response.ResultReader: ... + @staticmethod + def new_message() -> Response.ResultBuilder: ... + def to_dict(self) -> dict: ... + + class ResultReader(Response.Result): + keys: Sequence[TensorKeyReader] + descriptors: Sequence[TensorDescriptorReader] + def as_builder(self) -> Response.ResultBuilder: ... + + class ResultBuilder(Response.Result): + keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] + @staticmethod + def from_dict(dictionary: dict) -> Response.ResultBuilder: ... 
+ def copy(self) -> Response.ResultBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> Response.ResultReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + + class CustomAttributes: + torch: ( + TorchResponseAttributes + | TorchResponseAttributesBuilder + | TorchResponseAttributesReader + ) + tf: ( + TensorFlowResponseAttributes + | TensorFlowResponseAttributesBuilder + | TensorFlowResponseAttributesReader + ) + none: None + def which(self) -> Literal["torch", "tf", "none"]: ... + @overload + def init(self, name: Literal["torch"]) -> TorchResponseAttributes: ... + @overload + def init(self, name: Literal["tf"]) -> TensorFlowResponseAttributes: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[Response.CustomAttributesReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Response.CustomAttributesReader: ... + @staticmethod + def new_message() -> Response.CustomAttributesBuilder: ... + def to_dict(self) -> dict: ... + + class CustomAttributesReader(Response.CustomAttributes): + torch: TorchResponseAttributesReader + tf: TensorFlowResponseAttributesReader + def as_builder(self) -> Response.CustomAttributesBuilder: ... + + class CustomAttributesBuilder(Response.CustomAttributes): + torch: ( + TorchResponseAttributes + | TorchResponseAttributesBuilder + | TorchResponseAttributesReader + ) + tf: ( + TensorFlowResponseAttributes + | TensorFlowResponseAttributesBuilder + | TensorFlowResponseAttributesReader + ) + @staticmethod + def from_dict(dictionary: dict) -> Response.CustomAttributesBuilder: ... + def copy(self) -> Response.CustomAttributesBuilder: ... 
+ def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> Response.CustomAttributesReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + status: Status + message: str + result: Response.Result | Response.ResultBuilder | Response.ResultReader + customAttributes: ( + Response.CustomAttributes + | Response.CustomAttributesBuilder + | Response.CustomAttributesReader + ) + @overload + def init(self, name: Literal["result"]) -> Result: ... + @overload + def init(self, name: Literal["customAttributes"]) -> CustomAttributes: ... + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[ResponseReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> ResponseReader: ... + @staticmethod + def new_message() -> ResponseBuilder: ... + def to_dict(self) -> dict: ... + +class ResponseReader(Response): + result: Response.ResultReader + customAttributes: Response.CustomAttributesReader + def as_builder(self) -> ResponseBuilder: ... + +class ResponseBuilder(Response): + result: Response.Result | Response.ResultBuilder | Response.ResultReader + customAttributes: ( + Response.CustomAttributes + | Response.CustomAttributesBuilder + | Response.CustomAttributesReader + ) + @staticmethod + def from_dict(dictionary: dict) -> ResponseBuilder: ... + def copy(self) -> ResponseBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> ResponseReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... 
diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp b/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp new file mode 100644 index 0000000000..4b2218b166 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp @@ -0,0 +1,75 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +@0x9a0aeb2e04838fb1; + +using DataRef = import "../data/data_references.capnp"; + +enum Order { + c @0; # row major (contiguous layout) + f @1; # column major (fortran contiguous layout) +} + +enum NumericalType { + int8 @0; + int16 @1; + int32 @2; + int64 @3; + uInt8 @4; + uInt16 @5; + uInt32 @6; + uInt64 @7; + float32 @8; + float64 @9; +} + +enum ReturnNumericalType { + int8 @0; + int16 @1; + int32 @2; + int64 @3; + uInt8 @4; + uInt16 @5; + uInt32 @6; + uInt64 @7; + float32 @8; + float64 @9; + none @10; + auto @11; +} + +struct TensorDescriptor { + dimensions @0 :List(Int32); + order @1 :Order; + dataType @2 :NumericalType; +} + +struct OutputDescriptor { + order @0 :Order; + optionalKeys @1 :List(DataRef.TensorKey); + optionalDimension @2 :List(Int32); + optionalDatatype @3 :ReturnNumericalType; +} diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py new file mode 100644 index 0000000000..8c9d6c9029 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py @@ -0,0 +1,41 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `tensor.capnp`.""" + +import os + +import capnp # type: ignore + +capnp.remove_import_hook() +here = os.path.dirname(os.path.abspath(__file__)) +module_file = os.path.abspath(os.path.join(here, "tensor.capnp")) +TensorDescriptor = capnp.load(module_file).TensorDescriptor +TensorDescriptorBuilder = TensorDescriptor +TensorDescriptorReader = TensorDescriptor +OutputDescriptor = capnp.load(module_file).OutputDescriptor +OutputDescriptorBuilder = OutputDescriptor +OutputDescriptorReader = OutputDescriptor diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi new file mode 100644 index 0000000000..b55f26b452 --- /dev/null +++ b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi @@ -0,0 +1,142 @@ +# BSD 2-Clause License + +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+ +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""This is an automatically generated stub for `tensor.capnp`.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from contextlib import contextmanager +from io import BufferedWriter +from typing import Iterator, Literal, Sequence + +from ..data.data_references_capnp import TensorKey, TensorKeyBuilder, TensorKeyReader + +Order = Literal["c", "f"] +NumericalType = Literal[ + "int8", + "int16", + "int32", + "int64", + "uInt8", + "uInt16", + "uInt32", + "uInt64", + "float32", + "float64", +] +ReturnNumericalType = Literal[ + "int8", + "int16", + "int32", + "int64", + "uInt8", + "uInt16", + "uInt32", + "uInt64", + "float32", + "float64", + "none", + "auto", +] + +class TensorDescriptor: + dimensions: Sequence[int] + order: Order + dataType: NumericalType + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[TensorDescriptorReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> TensorDescriptorReader: ... 
+ @staticmethod + def new_message() -> TensorDescriptorBuilder: ... + def to_dict(self) -> dict: ... + +class TensorDescriptorReader(TensorDescriptor): + def as_builder(self) -> TensorDescriptorBuilder: ... + +class TensorDescriptorBuilder(TensorDescriptor): + @staticmethod + def from_dict(dictionary: dict) -> TensorDescriptorBuilder: ... + def copy(self) -> TensorDescriptorBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> TensorDescriptorReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... + @staticmethod + def write_packed(file: BufferedWriter) -> None: ... + +class OutputDescriptor: + order: Order + optionalKeys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + optionalDimension: Sequence[int] + optionalDatatype: ReturnNumericalType + @staticmethod + @contextmanager + def from_bytes( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> Iterator[OutputDescriptorReader]: ... + @staticmethod + def from_bytes_packed( + data: bytes, + traversal_limit_in_words: int | None = ..., + nesting_limit: int | None = ..., + ) -> OutputDescriptorReader: ... + @staticmethod + def new_message() -> OutputDescriptorBuilder: ... + def to_dict(self) -> dict: ... + +class OutputDescriptorReader(OutputDescriptor): + optionalKeys: Sequence[TensorKeyReader] + def as_builder(self) -> OutputDescriptorBuilder: ... + +class OutputDescriptorBuilder(OutputDescriptor): + optionalKeys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] + @staticmethod + def from_dict(dictionary: dict) -> OutputDescriptorBuilder: ... + def copy(self) -> OutputDescriptorBuilder: ... + def to_bytes(self) -> bytes: ... + def to_bytes_packed(self) -> bytes: ... + def to_segments(self) -> list[bytes]: ... + def as_reader(self) -> OutputDescriptorReader: ... + @staticmethod + def write(file: BufferedWriter) -> None: ... 
+ @staticmethod + def write_packed(file: BufferedWriter) -> None: ... diff --git a/smartsim/_core/schemas/utils.py b/smartsim/_core/schemas/utils.py index 9cb36bcf57..905fe8955c 100644 --- a/smartsim/_core/schemas/utils.py +++ b/smartsim/_core/schemas/utils.py @@ -48,7 +48,7 @@ class _Message(t.Generic[_SchemaT]): delimiter: str = pydantic.Field(min_length=1, default=_DEFAULT_MSG_DELIM) def __str__(self) -> str: - return self.delimiter.join((self.header, self.payload.json())) + return self.delimiter.join((self.header, self.payload.model_dump_json())) @classmethod def from_str( @@ -58,7 +58,7 @@ def from_str( delimiter: str = _DEFAULT_MSG_DELIM, ) -> "_Message[_SchemaT]": header, payload = str_.split(delimiter, 1) - return cls(payload_type.parse_raw(payload), header, delimiter) + return cls(payload_type.model_validate_json(payload), header, delimiter) class SchemaRegistry(t.Generic[_SchemaT]): diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index b17be763b4..bf5838928e 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -220,7 +220,7 @@ def _installed(base_path: Path, backend: str) -> bool: """ backend_key = f"redisai_{backend}" backend_path = base_path / backend_key / f"{backend_key}.so" - backend_so = Path(os.environ.get("RAI_PATH", backend_path)).resolve() + backend_so = Path(os.environ.get("SMARTSIM_RAI_LIB", backend_path)).resolve() return backend_so.is_file() diff --git a/smartsim/_core/utils/timings.py b/smartsim/_core/utils/timings.py new file mode 100644 index 0000000000..f99950739e --- /dev/null +++ b/smartsim/_core/utils/timings.py @@ -0,0 +1,175 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import time +import typing as t +from collections import OrderedDict + +import numpy as np + +from ...log import get_logger + +logger = get_logger("PerfTimer") + + +class PerfTimer: + def __init__( + self, + filename: str = "timings", + prefix: str = "", + timing_on: bool = True, + debug: bool = False, + ): + self._start: t.Optional[float] = None + self._interm: t.Optional[float] = None + self._timings: OrderedDict[str, list[t.Union[float, int, str]]] = OrderedDict() + self._timing_on = timing_on + self._filename = filename + self._prefix = prefix + self._debug = debug + + def _add_label_to_timings(self, label: str) -> None: + if label not in self._timings: + self._timings[label] = [] + + @staticmethod + def _format_number(number: t.Union[float, int]) -> str: + """Formats the input value with a fixed precision appropriate for logging""" + return f"{number:0.4e}" + + def start_timings( + self, + first_label: t.Optional[str] = None, + first_value: t.Optional[t.Union[float, int]] = None, + ) -> None: + """Start a recording session by recording + + :param first_label: a label for an event that will be manually prepended + to the timing information before starting timers + :param first_label: a value for an event that will be manually prepended + to the timing information before starting timers""" + if self._timing_on: + if first_label is not None and first_value is not None: + mod_label = self._make_label(first_label) + value = self._format_number(first_value) + self._log(f"Started timing: {first_label}: {value}") + self._add_label_to_timings(mod_label) + self._timings[mod_label].append(value) + self._start = time.perf_counter() + self._interm = time.perf_counter() + + def end_timings(self) -> None: + """Record a timing event and clear the last checkpoint""" + if self._timing_on and self._start is not None: + mod_label = self._make_label("total_time") + self._add_label_to_timings(mod_label) + delta = self._format_number(time.perf_counter() - self._start) + 
self._timings[self._make_label("total_time")].append(delta) + self._log(f"Finished timing: {mod_label}: {delta}") + self._interm = None + + def _make_label(self, label: str) -> str: + """Return a label formatted with the current label prefix + + :param label: the original label + :returns: the adjusted label value""" + return self._prefix + label + + def _get_delta(self) -> float: + """Calculates the offset from the last intermediate checkpoint time + + :returns: the number of seconds elapsed""" + if self._interm is None: + return 0 + return time.perf_counter() - self._interm + + def get_last(self, label: str) -> str: + """Return the last timing value collected for the given label in + the format `{label}: {value}`. If no timing value has been collected + with the label, returns `Not measured yet`""" + mod_label = self._make_label(label) + if mod_label in self._timings: + value = self._timings[mod_label][-1] + if value: + return f"{label}: {value}" + + return "Not measured yet" + + def measure_time(self, label: str) -> None: + """Record a new time event if timing is enabled + + :param label: the label to record a timing event for""" + if self._timing_on and self._interm is not None: + mod_label = self._make_label(label) + self._add_label_to_timings(mod_label) + delta = self._format_number(self._get_delta()) + self._timings[mod_label].append(delta) + self._log(f"{mod_label}: {delta}") + self._interm = time.perf_counter() + + def _log(self, msg: str) -> None: + """Conditionally logs a message when the debug flag is enabled + + :param msg: the message to be logged""" + if self._debug: + logger.info(msg) + + @property + def max_length(self) -> int: + """Returns the number of records contained in the largest timing set""" + if len(self._timings) == 0: + return 0 + return max(len(value) for value in self._timings.values()) + + def print_timings(self, to_file: bool = False) -> None: + """Print timing information to standard output. 
If `to_file` + is `True`, also write results to a file. + + :param to_file: If `True`, also saves timing information + to the files `timings.npy` and `timings.txt` + """ + print(" ".join(self._timings.keys())) + try: + value_array = np.array(list(self._timings.values()), dtype=float) + except Exception as e: + logger.exception(e) + return + value_array = np.transpose(value_array) + if self._debug: + for i in range(value_array.shape[0]): + print(" ".join(self._format_number(value) for value in value_array[i])) + if to_file: + np.save(self._prefix + self._filename + ".npy", value_array) + + @property + def is_active(self) -> bool: + """Return `True` if timer is recording, `False` otherwise""" + return self._timing_on + + @is_active.setter + def is_active(self, active: bool) -> None: + """Set to `True` to record timing information, `False` otherwise""" + self._timing_on = active diff --git a/smartsim/database/orchestrator.py b/smartsim/database/orchestrator.py index e2549891af..e5e99c8932 100644 --- a/smartsim/database/orchestrator.py +++ b/smartsim/database/orchestrator.py @@ -265,8 +265,8 @@ def __init__( raise SSConfigError( "SmartSim not installed with pre-built extensions (Redis)\n" "Use the `smart` cli tool to install needed extensions\n" - "or set REDIS_PATH and REDIS_CLI_PATH in your environment\n" - "See documentation for more information" + "or set SMARTSIM_REDIS_SERVER_EXE and SMARTSIM_REDIS_CLI_EXE " + "in your environment\nSee documentation for more information" ) from e if self.launcher != "local": diff --git a/smartsim/experiment.py b/smartsim/experiment.py index 607a90ae16..9a14eecdc8 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -178,7 +178,7 @@ def __init__( def _set_dragon_server_path(self) -> None: """Set path for dragon server through environment varialbes""" if not "SMARTSIM_DRAGON_SERVER_PATH" in environ: - environ["SMARTSIM_DRAGON_SERVER_PATH_EXP"] = osp.join( + environ["_SMARTSIM_DRAGON_SERVER_PATH_EXP"] = osp.join( 
self.exp_path, CONFIG.dragon_default_subdir ) diff --git a/smartsim/log.py b/smartsim/log.py index 3d6c0860ee..c8fed9329f 100644 --- a/smartsim/log.py +++ b/smartsim/log.py @@ -252,16 +252,21 @@ def filter(self, record: logging.LogRecord) -> bool: return record.levelno <= level_no -def log_to_file(filename: str, log_level: str = "debug") -> None: +def log_to_file( + filename: str, log_level: str = "debug", logger: t.Optional[logging.Logger] = None +) -> None: """Installs a second filestream handler to the root logger, allowing subsequent logging calls to be sent to filename. - :param filename: the name of the desired log file. - :param log_level: as defined in get_logger. Can be specified + :param filename: The name of the desired log file. + :param log_level: As defined in get_logger. Can be specified to allow the file to store more or less verbose logging information. + :param logger: If supplied, a logger to add the file stream logging + behavior to. By default, a new logger is instantiated. 
""" - logger = logging.getLogger("SmartSim") + if logger is None: + logger = logging.getLogger("SmartSim") stream = open( # pylint: disable=consider-using-with filename, "w+", encoding="utf-8" ) diff --git a/smartsim/settings/dragonRunSettings.py b/smartsim/settings/dragonRunSettings.py index 69a91547e7..15e5855448 100644 --- a/smartsim/settings/dragonRunSettings.py +++ b/smartsim/settings/dragonRunSettings.py @@ -95,6 +95,26 @@ def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: self.run_args["node-feature"] = ",".join(feature_list) + @override + def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + """Specify the hostlist for this job + + :param host_list: hosts to launch on + :raises ValueError: if an empty host list is supplied + """ + if not host_list: + raise ValueError("empty hostlist provided") + + if isinstance(host_list, str): + host_list = host_list.replace(" ", "").split(",") + + # strip out all whitespace-only values + cleaned_list = [host.strip() for host in host_list if host and host.strip()] + if not len(cleaned_list) == len(host_list): + raise ValueError(f"invalid names found in hostlist: {host_list}") + + self.run_args["host-list"] = ",".join(cleaned_list) + def set_cpu_affinity(self, devices: t.List[int]) -> None: """Set the CPU affinity for this job diff --git a/tests/dragon/__init__.py b/tests/dragon/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dragon/channel.py b/tests/dragon/channel.py new file mode 100644 index 0000000000..4c46359c2d --- /dev/null +++ b/tests/dragon/channel.py @@ -0,0 +1,125 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import pathlib +import threading +import typing as t + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class FileSystemCommChannel(CommChannelBase): + """Passes messages by writing to a file""" + + def __init__(self, key: pathlib.Path) -> None: + """Initialize the FileSystemCommChannel instance. 
+
+        :param key: a path to the root directory of the feature store
+        """
+        self._lock = threading.RLock()
+
+        super().__init__(key.as_posix())
+        self._file_path = key
+
+        if not self._file_path.parent.exists():
+            self._file_path.parent.mkdir(parents=True)
+
+        self._file_path.touch()
+
+    def send(self, value: bytes, timeout: float = 0) -> None:
+        """Send a message through the underlying communication channel.
+
+        :param value: The value to send
+        :param timeout: maximum time to wait (in seconds) for messages to send
+        """
+        with self._lock:
+            # write as text so we can add newlines as delimiters
+            with open(self._file_path, "a") as fp:
+                encoded_value = base64.b64encode(value).decode("utf-8")
+                fp.write(f"{encoded_value}\n")
+            logger.debug(f"FileSystemCommChannel {self._file_path} sent message")
+
+    def recv(self, timeout: float = 0) -> t.List[bytes]:
+        """Receives message(s) through the underlying communication channel.
+
+        :param timeout: maximum time to wait (in seconds) for messages to arrive
+        :returns: the received message
+        :raises SmartSimError: if the descriptor points to a missing file
+        """
+        with self._lock:
+            messages: t.List[bytes] = []
+            if not self._file_path.exists():
+                raise SmartSimError("Empty channel")
+
+            # read as text so we can split on newlines
+            with open(self._file_path, "r") as fp:
+                lines = fp.readlines()
+
+            if lines:
+                line = lines.pop(0)
+                event_bytes = base64.b64decode(line.encode("utf-8"))
+                messages.append(event_bytes)
+
+                self.clear()
+
+                # remove the first message only, write remainder back...
+ if len(lines) > 0: + with open(self._file_path, "w") as fp: + fp.writelines(lines) + + logger.debug( + f"FileSystemCommChannel {self._file_path} received message" + ) + + return messages + + def clear(self) -> None: + """Create an empty file for events.""" + if self._file_path.exists(): + self._file_path.unlink() + self._file_path.touch() + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "FileSystemCommChannel": + """A factory method that creates an instance from a descriptor string. + + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemCommChannel + """ + try: + path = pathlib.Path(descriptor) + return FileSystemCommChannel(path) + except: + logger.warning(f"failed to create fs comm channel: {descriptor}") + raise diff --git a/tests/dragon/conftest.py b/tests/dragon/conftest.py new file mode 100644 index 0000000000..d542700175 --- /dev/null +++ b/tests/dragon/conftest.py @@ -0,0 +1,129 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import os +import pathlib +import socket +import subprocess +import sys +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +# isort: off +import dragon.data.ddict.ddict as dragon_ddict +import dragon.infrastructure.policy as dragon_policy +import dragon.infrastructure.process_desc as dragon_process_desc +import dragon.native.process as dragon_process + +from dragon.fli import FLInterface + +# isort: on + +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.storage import dragon_util +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +@pytest.fixture(scope="module") +def the_storage() -> dragon_ddict.DDict: + """Fixture to instantiate a dragon distributed dictionary.""" + return dragon_util.create_ddict(1, 2, 32 * 1024**2) + + +@pytest.fixture(scope="module") +def the_worker_channel() -> DragonFLIChannel: + """Fixture to create a valid descriptor for a worker channel + that can be attached to.""" + channel_ = create_local() + fli_ = FLInterface(main_ch=channel_, manager_ch=None) + comm_channel = DragonFLIChannel(fli_) + return comm_channel + + +@pytest.fixture(scope="module") +def the_backbone( + the_storage: t.Any, 
the_worker_channel: DragonFLIChannel +) -> BackboneFeatureStore: + """Fixture to create a distributed dragon dictionary and wrap it + in a BackboneFeatureStore. + + :param the_storage: The dragon storage engine to use + :param the_worker_channel: Pre-configured worker channel + """ + + backbone = BackboneFeatureStore(the_storage, allow_reserved_writes=True) + backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] = the_worker_channel.descriptor + + return backbone + + +@pytest.fixture(scope="module") +def backbone_descriptor(the_backbone: BackboneFeatureStore) -> str: + # create a shared backbone featurestore + return the_backbone.descriptor + + +def function_as_dragon_proc( + entrypoint_fn: t.Callable[[t.Any], None], + args: t.List[t.Any], + cpu_affinity: t.List[int], + gpu_affinity: t.List[int], +) -> dragon_process.Process: + """Execute a function as an independent dragon process. + + :param entrypoint_fn: The function to execute + :param args: The arguments for the entrypoint function + :param cpu_affinity: The cpu affinity for the process + :param gpu_affinity: The gpu affinity for the process + :returns: The dragon process handle + """ + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + return dragon_process.Process( + target=entrypoint_fn, + args=args, + cwd=os.getcwd(), + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) diff --git a/tests/dragon/feature_store.py b/tests/dragon/feature_store.py new file mode 100644 index 0000000000..d06b0b334e --- /dev/null +++ b/tests/dragon/feature_store.py @@ -0,0 +1,152 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pathlib +import typing as t + +import smartsim.error as sse +from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class MemoryFeatureStore(FeatureStore): + """A feature store with values persisted only in local memory""" + + def __init__( + self, storage: t.Optional[t.Dict[str, t.Union[str, bytes]]] = None + ) -> None: + """Initialize the MemoryFeatureStore instance""" + super().__init__("in-memory-fs") + if storage is None: + storage = {"_": "abc"} + self._storage = storage + + def _get(self, key: str) -> t.Union[str, bytes]: + """Retrieve a value from the underlying storage mechanism + + :param key: The unique key that identifies the resource + :returns: the value identified by the key + :raises KeyError: if the key has not been used to store a value""" + return self._storage[key] + + def _set(self, key: str, value: t.Union[str, bytes]) -> None: + """Store a value into the underlying storage mechanism + + :param key: The unique key that identifies the resource + :param value: The value to store + :returns: the value identified by the key + :raises KeyError: if the key has not been used to store a value""" + self._storage[key] = value + + def _contains(self, key: str) -> bool: + """Determine if the storage mechanism contains a given key + + :param key: The unique key that identifies the resource + :returns: True if the key is defined, False otherwise""" + return key in self._storage + + +class FileSystemFeatureStore(FeatureStore): + """Alternative feature store implementation for testing. 
Stores all + data on the file system""" + + def __init__(self, storage_dir: t.Union[pathlib.Path, str]) -> None: + """Initialize the FileSystemFeatureStore instance + + :param storage_dir: (optional) root directory to store all data relative to""" + if isinstance(storage_dir, str): + storage_dir = pathlib.Path(storage_dir) + self._storage_dir = storage_dir + super().__init__(storage_dir.as_posix()) + + def _get(self, key: str) -> t.Union[str, bytes]: + """Retrieve a value from the underlying storage mechanism + + :param key: The unique key that identifies the resource + :returns: the value identified by the key + :raises KeyError: if the key has not been used to store a value""" + path = self._key_path(key) + if not path.exists(): + raise sse.SmartSimError(f"{path} not found in feature store") + return path.read_bytes() + + def _set(self, key: str, value: t.Union[str, bytes]) -> None: + """Store a value into the underlying storage mechanism + + :param key: The unique key that identifies the resource + :param value: The value to store + :returns: the value identified by the key + :raises KeyError: if the key has not been used to store a value""" + path = self._key_path(key, create=True) + if isinstance(value, str): + value = value.encode("utf-8") + path.write_bytes(value) + + def _contains(self, key: str) -> bool: + """Determine if the storage mechanism contains a given key + + :param key: The unique key that identifies the resource + :returns: True if the key is defined, False otherwise""" + path = self._key_path(key) + return path.exists() + + def _key_path(self, key: str, create: bool = False) -> pathlib.Path: + """Given a key, return a path that is optionally combined with a base + directory used by the FileSystemFeatureStore. 
+ + :param key: Unique key of an item to retrieve from the feature store""" + value = pathlib.Path(key) + + if self._storage_dir is not None: + value = self._storage_dir / key + + if create: + value.parent.mkdir(parents=True, exist_ok=True) + + return value + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "FileSystemFeatureStore": + """A factory method that creates an instance from a descriptor string + + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemFeatureStore""" + try: + path = pathlib.Path(descriptor) + path.mkdir(parents=True, exist_ok=True) + if not path.is_dir(): + raise ValueError("FileSystemFeatureStore requires a directory path") + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + return FileSystemFeatureStore(path) + except: + logger.error(f"Error while creating FileSystemFeatureStore: {descriptor}") + raise diff --git a/tests/dragon/test_core_machine_learning_worker.py b/tests/dragon/test_core_machine_learning_worker.py new file mode 100644 index 0000000000..e9c356b4e0 --- /dev/null +++ b/tests/dragon/test_core_machine_learning_worker.py @@ -0,0 +1,377 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pathlib +import time + +import pytest + +dragon = pytest.importorskip("dragon") + +import torch + +import smartsim.error as sse +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey, TensorKey +from smartsim._core.mli.infrastructure.worker.worker import ( + InferenceRequest, + MachineLearningWorkerCore, + RequestBatch, + TransformInputResult, + TransformOutputResult, +) +from smartsim._core.utils import installed_redisai_backends + +from .feature_store import FileSystemFeatureStore, MemoryFeatureStore + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +# retrieved from pytest fixtures +is_dragon = ( + pytest.test_launcher == "dragon" if hasattr(pytest, "test_launcher") else False +) +torch_available = "torch" in installed_redisai_backends() + + +@pytest.fixture +def persist_torch_model(test_dir: str) -> pathlib.Path: + ts_start = time.time_ns() + print("Starting model file creation...") + test_path = pathlib.Path(test_dir) + model_path = test_path / "basic.pt" + + model = torch.nn.Linear(2, 1) + torch.save(model, model_path) + ts_end = time.time_ns() + + ts_elapsed = (ts_end - ts_start) / 1000000000 + print(f"Model file 
creation took {ts_elapsed} seconds")
+    return model_path
+
+
+@pytest.fixture
+def persist_torch_tensor(test_dir: str) -> pathlib.Path:
+    ts_start = time.time_ns()
+    print("Starting model file creation...")
+    test_path = pathlib.Path(test_dir)
+    file_path = test_path / "tensor.pt"
+
+    tensor = torch.randn((100, 100, 2))
+    torch.save(tensor, file_path)
+    ts_end = time.time_ns()
+
+    ts_elapsed = (ts_end - ts_start) / 1000000000
+    print(f"Tensor file creation took {ts_elapsed} seconds")
+    return file_path
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_model_disk(persist_torch_model: pathlib.Path, test_dir: str) -> None:
+    """Verify that the ML worker successfully retrieves a model
+    when given a valid (file system) key"""
+    worker = MachineLearningWorkerCore
+    key = str(persist_torch_model)
+    feature_store = FileSystemFeatureStore(test_dir)
+    fsd = feature_store.descriptor
+    feature_store[str(persist_torch_model)] = persist_torch_model.read_bytes()
+
+    model_key = ModelKey(key=key, descriptor=fsd)
+    request = InferenceRequest(model_key=model_key)
+    batch = RequestBatch([request], None, model_key)
+
+    fetch_result = worker.fetch_model(batch, {fsd: feature_store})
+    assert fetch_result.model_bytes
+    assert fetch_result.model_bytes == persist_torch_model.read_bytes()
+
+
+def test_fetch_model_disk_missing() -> None:
+    """Verify that the ML worker fails to retrieve a model
+    when given an invalid (file system) key"""
+    worker = MachineLearningWorkerCore
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+
+    key = "/path/that/doesnt/exist"
+
+    model_key = ModelKey(key=key, descriptor=fsd)
+    request = InferenceRequest(model_key=model_key)
+    batch = RequestBatch([request], None, model_key)
+
+    with pytest.raises(sse.SmartSimError) as ex:
+        worker.fetch_model(batch, {fsd: feature_store})
+
+    # ensure the error message includes key-identifying information
+    assert key in ex.value.args[0]
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_model_feature_store(persist_torch_model: pathlib.Path) -> None:
+    """Verify that the ML worker successfully retrieves a model
+    when given a valid (file system) key"""
+    worker = MachineLearningWorkerCore
+
+    # create a key to retrieve from the feature store
+    key = "test-model"
+
+    # put model bytes into the feature store
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+    feature_store[key] = persist_torch_model.read_bytes()
+
+    model_key = ModelKey(key=key, descriptor=feature_store.descriptor)
+    request = InferenceRequest(model_key=model_key)
+    batch = RequestBatch([request], None, model_key)
+
+    fetch_result = worker.fetch_model(batch, {fsd: feature_store})
+    assert fetch_result.model_bytes
+    assert fetch_result.model_bytes == persist_torch_model.read_bytes()
+
+
+def test_fetch_model_feature_store_missing() -> None:
+    """Verify that the ML worker fails to retrieve a model
+    when given an invalid (feature store) key"""
+    worker = MachineLearningWorkerCore
+
+    key = "some-key"
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+
+    model_key = ModelKey(key=key, descriptor=feature_store.descriptor)
+    request = InferenceRequest(model_key=model_key)
+    batch = RequestBatch([request], None, model_key)
+
+    # todo: consider that raising this exception shows impl. replace...
+    with pytest.raises(sse.SmartSimError) as ex:
+        worker.fetch_model(batch, {fsd: feature_store})
+
+    # ensure the error message includes key-identifying information
+    assert key in ex.value.args[0]
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_model_memory(persist_torch_model: pathlib.Path) -> None:
+    """Verify that the ML worker successfully retrieves a model
+    when given a valid (file system) key"""
+    worker = MachineLearningWorkerCore
+
+    key = "test-model"
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+    feature_store[key] = persist_torch_model.read_bytes()
+
+    model_key = ModelKey(key=key, descriptor=feature_store.descriptor)
+    request = InferenceRequest(model_key=model_key)
+    batch = RequestBatch([request], None, model_key)
+
+    fetch_result = worker.fetch_model(batch, {fsd: feature_store})
+    assert fetch_result.model_bytes
+    assert fetch_result.model_bytes == persist_torch_model.read_bytes()
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_input_disk(persist_torch_tensor: pathlib.Path) -> None:
+    """Verify that the ML worker successfully retrieves a tensor/input
+    when given a valid (file system) key"""
+    tensor_name = str(persist_torch_tensor)
+
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+    request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)])
+
+    model_key = ModelKey(key="test-model", descriptor=fsd)
+    batch = RequestBatch([request], None, model_key)
+
+    worker = MachineLearningWorkerCore
+
+    feature_store[tensor_name] = persist_torch_tensor.read_bytes()
+
+    fetch_result = worker.fetch_inputs(batch, {fsd: feature_store})
+    assert fetch_result[0].inputs is not None
+
+
+def test_fetch_input_disk_missing() -> None:
+    """Verify that the ML worker fails to retrieve a tensor/input
+    when given an invalid (file system) key"""
+    worker = MachineLearningWorkerCore
+
+    
feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + key = "/path/that/doesnt/exist" + + request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) + + model_key = ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + with pytest.raises(sse.SmartSimError) as ex: + worker.fetch_inputs(batch, {fsd: feature_store}) + + # ensure the error message includes key-identifying information + assert key[0] in ex.value.args[0] + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_input_feature_store(persist_torch_tensor: pathlib.Path) -> None: + """Verify that the ML worker successfully retrieves a tensor/input + when given a valid (feature store) key""" + worker = MachineLearningWorkerCore + + tensor_name = "test-tensor" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)]) + + # put model bytes into the feature store + feature_store[tensor_name] = persist_torch_tensor.read_bytes() + + model_key = ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + assert fetch_result[0].inputs + assert ( + list(fetch_result[0].inputs)[0][:10] == persist_torch_tensor.read_bytes()[:10] + ) + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_multi_input_feature_store(persist_torch_tensor: pathlib.Path) -> None: + """Verify that the ML worker successfully retrieves multiple tensor/input + when given a valid collection of (feature store) keys""" + worker = MachineLearningWorkerCore + + tensor_name = "test-tensor" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + # put model bytes into the feature store + body1 = persist_torch_tensor.read_bytes() + feature_store[tensor_name + "1"] = 
body1
+
+    body2 = b"abcdefghijklmnopqrstuvwxyz"
+    feature_store[tensor_name + "2"] = body2
+
+    body3 = b"mnopqrstuvwxyzabcdefghijkl"
+    feature_store[tensor_name + "3"] = body3
+
+    request = InferenceRequest(
+        input_keys=[
+            TensorKey(key=tensor_name + "1", descriptor=fsd),
+            TensorKey(key=tensor_name + "2", descriptor=fsd),
+            TensorKey(key=tensor_name + "3", descriptor=fsd),
+        ]
+    )
+
+    model_key = ModelKey(key="test-model", descriptor=fsd)
+    batch = RequestBatch([request], None, model_key)
+
+    fetch_result = worker.fetch_inputs(batch, {fsd: feature_store})
+
+    raw_bytes = list(fetch_result[0].inputs)
+    assert raw_bytes
+    assert raw_bytes[0][:10] == persist_torch_tensor.read_bytes()[:10]
+    assert raw_bytes[1][:10] == body2[:10]
+    assert raw_bytes[2][:10] == body3[:10]
+
+
+def test_fetch_input_feature_store_missing() -> None:
+    """Verify that the ML worker fails to retrieve a tensor/input
+    when given an invalid (feature store) key"""
+    worker = MachineLearningWorkerCore
+
+    key = "bad-key"
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+    request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)])
+
+    model_key = ModelKey(key="test-model", descriptor=fsd)
+    batch = RequestBatch([request], None, model_key)
+
+    with pytest.raises(sse.SmartSimError) as ex:
+        worker.fetch_inputs(batch, {fsd: feature_store})
+
+    # ensure the error message includes key-identifying information
+    assert key in ex.value.args[0]
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_input_memory(persist_torch_tensor: pathlib.Path) -> None:
+    """Verify that the ML worker successfully retrieves a tensor/input
+    when given a valid (file system) key"""
+    worker = MachineLearningWorkerCore
+    feature_store = MemoryFeatureStore()
+    fsd = feature_store.descriptor
+
+    key = "test-model"
+    feature_store[key] = persist_torch_tensor.read_bytes()
+    request = InferenceRequest(input_keys=[TensorKey(key=key, 
descriptor=fsd)]) + + model_key = ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + assert fetch_result[0].inputs is not None + + +def test_place_outputs() -> None: + """Verify outputs are shared using the feature store""" + worker = MachineLearningWorkerCore + + key_name = "test-model" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + # create a key to retrieve from the feature store + keys = [ + TensorKey(key=key_name + "1", descriptor=fsd), + TensorKey(key=key_name + "2", descriptor=fsd), + TensorKey(key=key_name + "3", descriptor=fsd), + ] + data = [b"abcdef", b"ghijkl", b"mnopqr"] + + for fsk, v in zip(keys, data): + feature_store[fsk.key] = v + + request = InferenceRequest(output_keys=keys) + transform_result = TransformOutputResult(data, [1], "c", "float32") + + worker.place_output(request, transform_result, {fsd: feature_store}) + + for i in range(3): + assert feature_store[keys[i].key] == data[i] + + +@pytest.mark.parametrize( + "key, descriptor", + [ + pytest.param("", "desc", id="invalid key"), + pytest.param("key", "", id="invalid descriptor"), + ], +) +def test_invalid_tensorkey(key, descriptor) -> None: + with pytest.raises(ValueError): + fsk = TensorKey(key, descriptor) diff --git a/tests/dragon/test_device_manager.py b/tests/dragon/test_device_manager.py new file mode 100644 index 0000000000..d270e921cb --- /dev/null +++ b/tests/dragon/test_device_manager.py @@ -0,0 +1,186 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.control.device_manager import ( + DeviceManager, + WorkerDevice, +) +from smartsim._core.mli.infrastructure.storage.feature_store import ( + FeatureStore, + ModelKey, + TensorKey, +) +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, + TransformInputResult, + TransformOutputResult, +) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +class MockWorker(MachineLearningWorkerBase): + @staticmethod + def fetch_model( + batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] + ) -> FetchModelResult: + if batch.has_raw_model: + return FetchModelResult(batch.raw_model) + return FetchModelResult(b"fetched_model") + + @staticmethod + def load_model( + batch: RequestBatch, 
fetch_result: FetchModelResult, device: str + ) -> LoadModelResult: + return LoadModelResult(fetch_result.model_bytes) + + @staticmethod + def transform_input( + batch: RequestBatch, + fetch_results: list[FetchInputResult], + mem_pool: "MemoryPool", + ) -> TransformInputResult: + return TransformInputResult(b"result", [slice(0, 1)], [[1, 2]], ["float32"]) + + @staticmethod + def execute( + batch: RequestBatch, + load_result: LoadModelResult, + transform_result: TransformInputResult, + device: str, + ) -> ExecuteResult: + return ExecuteResult(b"result", [slice(0, 1)]) + + @staticmethod + def transform_output( + batch: RequestBatch, execute_result: ExecuteResult + ) -> t.List[TransformOutputResult]: + return [TransformOutputResult(b"result", None, "c", "float32")] + + +def test_worker_device(): + worker_device = WorkerDevice("gpu:0") + assert worker_device.name == "gpu:0" + + model_key = "my_model_key" + model = b"the model" + + worker_device.add_model(model_key, model) + + assert model_key in worker_device + assert worker_device.get_model(model_key) == model + worker_device.remove_model(model_key) + + assert model_key not in worker_device + + +def test_device_manager_model_in_request(): + + worker_device = WorkerDevice("gpu:0") + device_manager = DeviceManager(worker_device) + + worker = MockWorker() + + tensor_key = TensorKey(key="key", descriptor="desc") + output_key = TensorKey(key="key", descriptor="desc") + model_key = ModelKey(key="model key", descriptor="desc") + + request = InferenceRequest( + model_key=model_key, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"raw model", + batch_size=0, + ) + + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_key, + ) + + with device_manager.get_device( + worker=worker, batch=request_batch, feature_stores={} + ) as returned_device: + + assert returned_device == 
worker_device + assert worker_device.get_model(model_key.key) == b"raw model" + + assert model_key.key not in worker_device + + +def test_device_manager_model_key(): + + worker_device = WorkerDevice("gpu:0") + device_manager = DeviceManager(worker_device) + + worker = MockWorker() + + tensor_key = TensorKey(key="key", descriptor="desc") + output_key = TensorKey(key="key", descriptor="desc") + model_key = ModelKey(key="model key", descriptor="desc") + + request = InferenceRequest( + model_key=model_key, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=None, + batch_size=0, + ) + + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_key, + ) + + with device_manager.get_device( + worker=worker, batch=request_batch, feature_stores={} + ) as returned_device: + + assert returned_device == worker_device + assert worker_device.get_model(model_key.key) == b"fetched_model" + + assert model_key.key in worker_device diff --git a/tests/dragon/test_dragon_backend.py b/tests/dragon/test_dragon_backend.py new file mode 100644 index 0000000000..0e64c358df --- /dev/null +++ b/tests/dragon/test_dragon_backend.py @@ -0,0 +1,308 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import time +import uuid + +import pytest + +dragon = pytest.importorskip("dragon") + + +from smartsim._core.launcher.dragon.dragonBackend import DragonBackend +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.infrastructure.comm.event import ( + OnCreateConsumer, + OnShutdownRequested, +) +from smartsim._core.mli.infrastructure.control.listener import ( + ConsumerRegistrationListener, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.log import get_logger + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +logger = get_logger(__name__) + + +@pytest.fixture(scope="module") +def the_backend() -> DragonBackend: + return DragonBackend(pid=9999) + + +@pytest.mark.skip("Test is unreliable on build agent and may hang. 
TODO: Fix") +def test_dragonbackend_start_listener(the_backend: DragonBackend): + """Verify the background process listening to consumer registration events + is up and processing messages as expected.""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor + + with pytest.raises(KeyError) as ex: + # we expect the value of the consumer to be empty until + # the listener start-up completes. + backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + + assert "not found" in ex.value.args[0] + + drg_process = the_backend.start_event_listener(cpu_affinity=[], gpu_affinity=[]) + + # # confirm there is a process still running + logger.info(f"Dragon process started: {drg_process}") + assert drg_process is not None, "Backend was unable to start event listener" + assert drg_process.puid != 0, "Process unique ID is empty" + assert drg_process.returncode is None, "Listener terminated early" + + # wait for the event listener to come up + try: + config = backbone.wait_for( + [BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], timeout=30 + ) + # verify result was in the returned configuration map + assert config[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + except Exception: + raise KeyError( + f"Unable to locate {BackboneFeatureStore.MLI_REGISTRAR_CONSUMER}" + "in the backbone" + ) + + # wait_for ensures the normal retrieval will now work, error-free + descriptor = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + assert descriptor is not None + + # register a new listener channel + comm_channel = DragonCommChannel.from_descriptor(descriptor) + mock_descriptor = str(uuid.uuid4()) + event = OnCreateConsumer("test_dragonbackend_start_listener", mock_descriptor, []) + + event_bytes = bytes(event) + comm_channel.send(event_bytes) + + 
subscriber_list = []
+
+    # Give the channel time to write the message and the listener time to handle it
+    for i in range(20):
+        time.sleep(1)
+        # Retrieve the subscriber list from the backbone and verify it is updated
+        if subscriber_list := backbone.notification_channels:
+            logger.debug(f"The subscriber list was populated after {i} iterations")
+            break
+
+    assert mock_descriptor in subscriber_list
+
+    # now send a shutdown message to terminate the listener
+    return_code = drg_process.returncode
+
+    # clean up if the OnShutdownRequested wasn't properly handled
+    if return_code is None and drg_process.is_alive:
+        drg_process.kill()
+        drg_process.join()
+
+
+def test_dragonbackend_backend_consumer(the_backend: DragonBackend):
+    """Verify the listener background process updates the appropriate
+    value in the backbone."""
+
+    # We need to let the backend create the backbone to continue
+    backbone = the_backend._create_backbone()
+    backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS)
+    backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+    assert backbone._allow_reserved_writes
+
+    # create listener with `as_service=False` to perform a single loop iteration
+    listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=False)
+
+    logger.debug(f"backbone loaded? {listener._backbone}")
+    logger.debug(f"listener created? {listener}")
+
+    try:
+        # call the service execute method directly to trigger
+        # the entire service lifecycle
+        listener.execute()
+
+        consumer_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+        logger.debug(f"MLI_REGISTRAR_CONSUMER: {consumer_desc}")
+
+        assert consumer_desc
+    except Exception as ex:
+        pytest.fail(f"Unexpected exception in listener lifecycle: {ex}")
+    finally:
+        listener._on_shutdown()
+
+
+def test_dragonbackend_event_handled(the_backend: DragonBackend):
+    """Verify the event listener process updates the appropriate
+    value in the backbone when an event is received and again on shutdown.
+    """
+    # We need to let the backend create the backbone to continue
+    backbone = the_backend._create_backbone()
+    backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS)
+    backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+    # create the listener to be tested
+    listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=False)
+
+    assert listener._backbone, "The listener is not attached to a backbone"
+
+    try:
+        # set up the listener but don't let the service event loop start
+        listener._create_eventing() # listener.execute()
+
+        # grab the channel descriptor so we can simulate registrations
+        channel_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+        comm_channel = DragonCommChannel.from_descriptor(channel_desc)
+
+        num_events = 5
+        events = []
+        for i in range(num_events):
+            # register some mock consumers using the backend channel
+            event = OnCreateConsumer(
+                "test_dragonbackend_event_handled",
+                f"mock-consumer-descriptor-{uuid.uuid4()}",
+                [],
+            )
+            event_bytes = bytes(event)
+            comm_channel.send(event_bytes)
+            events.append(event)
+
+        # run a few iterations of the event loop in case it takes a few cycles to write
+        for i in range(20):
+            listener._on_iteration()
+            # Grab the value that should be getting updated
+            notify_consumers = set(backbone.notification_channels)
+            if len(notify_consumers) == len(events):
+                logger.info(f"Retrieved all consumers after {i} listen cycles")
+                break
+
+        # ...
and confirm that all the mock consumer descriptors are registered + assert set([e.descriptor for e in events]) == set(notify_consumers) + logger.info(f"Number of registered consumers: {len(notify_consumers)}") + + except Exception as ex: + logger.exception(f"test_dragonbackend_event_handled - exception occurred: {ex}") + assert False + finally: + # shutdown should unregister a registration listener + listener._on_shutdown() + + for i in range(10): + if BackboneFeatureStore.MLI_REGISTRAR_CONSUMER not in backbone: + logger.debug(f"The listener was removed after {i} iterations") + channel_desc = None + break + + # we should see that there is no listener registered + assert not channel_desc, "Listener shutdown failed to clean up the backbone" + + +def test_dragonbackend_shutdown_event(the_backend: DragonBackend): + """Verify the background process shuts down when it receives a + shutdown request.""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=True) + + # set up the listener but don't let the listener loop start + listener._create_eventing() # listener.execute() + + # grab the channel descriptor so we can publish to it + channel_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + comm_channel = DragonCommChannel.from_descriptor(channel_desc) + + assert listener._consumer.listening, "Listener isn't ready to listen" + + # send a shutdown request... 
+ event = OnShutdownRequested("test_dragonbackend_shutdown_event") + event_bytes = bytes(event) + comm_channel.send(event_bytes, 0.1) + + # execute should encounter the shutdown and exit + listener.execute() + + # ...and confirm the listener is now cancelled + assert not listener._consumer.listening + + +@pytest.mark.parametrize("health_check_frequency", [10, 20]) +def test_dragonbackend_shutdown_on_health_check( + the_backend: DragonBackend, + health_check_frequency: float, +): + """Verify that the event listener automatically shuts down when + a new listener is registered in its place. + + :param health_check_frequency: The expected frequency of service health check + invocations""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + listener = ConsumerRegistrationListener( + backbone, + 1.0, + 1.0, + as_service=True, # allow service to run long enough to health check + health_check_frequency=health_check_frequency, + ) + + # set up the listener but don't let the listener loop start + listener._create_eventing() # listener.execute() + assert listener._consumer.listening, "Listener wasn't ready to listen" + + # Replace the consumer descriptor in the backbone to trigger + # an automatic shutdown + backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = str(uuid.uuid4()) + + # set the last health check manually to verify the duration + start_at = time.time() + listener._last_health_check = time.time() + + # run execute to let the service trigger health checks + listener.execute() + elapsed = time.time() - start_at + + # confirm the frequency of the health check was honored + assert elapsed >= health_check_frequency + + # ...and confirm the listener is now cancelled + assert ( + not listener._consumer.listening + ), "Listener was not automatically shutdown by the health check" diff --git 
a/tests/dragon/test_dragon_ddict_utils.py b/tests/dragon/test_dragon_ddict_utils.py new file mode 100644 index 0000000000..c8bf687ef1 --- /dev/null +++ b/tests/dragon/test_dragon_ddict_utils.py @@ -0,0 +1,117 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +dragon = pytest.importorskip("dragon") + +# isort: off +import dragon.data.ddict.ddict as dragon_ddict + +# isort: on + +from smartsim._core.mli.infrastructure.storage import dragon_util +from smartsim.log import get_logger + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +logger = get_logger(__name__) + + +@pytest.mark.parametrize( + "num_nodes, num_managers, mem_per_node", + [ + pytest.param(1, 1, 3 * 1024**2, id="3MB, Bare minimum allocation"), + pytest.param(2, 2, 128 * 1024**2, id="128 MB allocation, 2 nodes, 2 mgr"), + pytest.param(2, 1, 512 * 1024**2, id="512 MB allocation, 2 nodes, 1 mgr"), + ], +) +def test_dragon_storage_util_create_ddict( + num_nodes: int, + num_managers: int, + mem_per_node: int, +): + """Verify that a dragon dictionary is successfully created. + + :param num_nodes: Number of ddict nodes to attempt to create + :param num_managers: Number of managers per node to request + :param num_managers: Memory to allocate per node + """ + ddict = dragon_util.create_ddict(num_nodes, num_managers, mem_per_node) + + assert ddict is not None + + +@pytest.mark.parametrize( + "num_nodes, num_managers, mem_per_node", + [ + pytest.param(-1, 1, 3 * 1024**2, id="Negative Node Count"), + pytest.param(0, 1, 3 * 1024**2, id="Invalid Node Count"), + pytest.param(1, -1, 3 * 1024**2, id="Negative Mgr Count"), + pytest.param(1, 0, 3 * 1024**2, id="Invalid Mgr Count"), + pytest.param(1, 1, -3 * 1024**2, id="Negative Mem Per Node"), + pytest.param(1, 1, (3 * 1024**2) - 1, id="Invalid Mem Per Node"), + pytest.param(1, 1, 0 * 1024**2, id="No Mem Per Node"), + ], +) +def test_dragon_storage_util_create_ddict_validators( + num_nodes: int, + num_managers: int, + mem_per_node: int, +): + """Verify that a dragon dictionary is successfully created. 
+ + :param num_nodes: Number of ddict nodes to attempt to create + :param num_managers: Number of managers per node to request + :param num_managers: Memory to allocate per node + """ + with pytest.raises(ValueError): + dragon_util.create_ddict(num_nodes, num_managers, mem_per_node) + + +def test_dragon_storage_util_get_ddict_descriptor(the_storage: dragon_ddict.DDict): + """Verify that a descriptor is created. + + :param the_storage: A pre-allocated ddict + """ + value = dragon_util.ddict_to_descriptor(the_storage) + + assert isinstance(value, str) + assert len(value) > 0 + + +def test_dragon_storage_util_get_ddict_from_descriptor(the_storage: dragon_ddict.DDict): + """Verify that a ddict is created from a descriptor. + + :param the_storage: A pre-allocated ddict + """ + descriptor = dragon_util.ddict_to_descriptor(the_storage) + + value = dragon_util.descriptor_to_ddict(descriptor) + + assert value is not None + assert isinstance(value, dragon_ddict.DDict) + assert dragon_util.ddict_to_descriptor(value) == descriptor diff --git a/tests/dragon/test_environment_loader.py b/tests/dragon/test_environment_loader.py new file mode 100644 index 0000000000..07b2a45c1c --- /dev/null +++ b/tests/dragon/test_environment_loader.py @@ -0,0 +1,147 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +dragon = pytest.importorskip("dragon") + +import dragon.data.ddict.ddict as dragon_ddict +import dragon.utils as du +from dragon.fli import FLInterface + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + DragonFeatureStore, +) +from smartsim.error.errors import SmartSimError + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.mark.parametrize( + "content", + [ + pytest.param(b"a"), + pytest.param(b"new byte string"), + ], +) +def test_environment_loader_attach_fli(content: bytes, monkeypatch: pytest.MonkeyPatch): + """A descriptor can be stored, loaded, and reattached.""" + chan = create_local() + queue = FLInterface(main_ch=chan) + monkeypatch.setenv( + EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, + du.B64.bytes_to_str(queue.serialize()), + ) + + config = EnvironmentConfigLoader( + 
featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + config_queue = config.get_queue() + + _ = config_queue.send(content) + + old_recv = queue.recvh() + result, _ = old_recv.recv_bytes() + assert result == content + + +def test_environment_loader_serialize_fli(monkeypatch: pytest.MonkeyPatch): + """The serialized descriptors of a loaded and unloaded + queue are the same.""" + chan = create_local() + queue = FLInterface(main_ch=chan) + monkeypatch.setenv( + EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, + du.B64.bytes_to_str(queue.serialize()), + ) + + config = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + config_queue = config.get_queue() + assert config_queue._fli.serialize() == queue.serialize() + + +def test_environment_loader_flifails(monkeypatch: pytest.MonkeyPatch): + """An incorrect serialized descriptor will fails to attach.""" + + monkeypatch.setenv(EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, "randomstring") + + config = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=None, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + with pytest.raises(SmartSimError): + config.get_queue() + + +def test_environment_loader_backbone_load_dfs( + monkeypatch: pytest.MonkeyPatch, the_storage: dragon_ddict.DDict +): + """Verify the dragon feature store is loaded correctly by the + EnvironmentConfigLoader to demonstrate featurestore_factory correctness.""" + feature_store = DragonFeatureStore(the_storage) + monkeypatch.setenv( + EnvironmentConfigLoader.BACKBONE_ENV_VAR, feature_store.descriptor + ) + + config = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=None, + queue_factory=None, + ) + + 
print(f"calling config.get_backbone: `{feature_store.descriptor}`") + + backbone = config.get_backbone() + assert backbone is not None + + +def test_environment_variables_not_set(monkeypatch: pytest.MonkeyPatch): + """EnvironmentConfigLoader getters return None when environment + variables are not set.""" + with monkeypatch.context() as patch: + patch.setenv(EnvironmentConfigLoader.BACKBONE_ENV_VAR, "") + patch.setenv(EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, "") + + config = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonCommChannel.from_descriptor, + ) + assert config.get_backbone() is None + assert config.get_queue() is None diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py new file mode 100644 index 0000000000..aacd47b556 --- /dev/null +++ b/tests/dragon/test_error_handling.py @@ -0,0 +1,511 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t +from unittest.mock import MagicMock + +import pytest + +dragon = pytest.importorskip("dragon") + +import multiprocessing as mp + +from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict +from dragon.fli import FLInterface +from dragon.mpbridge.queues import DragonQueue + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.infrastructure.control.request_dispatcher import ( + RequestDispatcher, +) +from smartsim._core.mli.infrastructure.control.worker_manager import ( + WorkerManager, + exception_handler, +) +from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.feature_store import ( + FeatureStore, + ModelKey, + TensorKey, +) +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, + TransformInputResult, + TransformOutputResult, +) +from smartsim._core.mli.message_handler import MessageHandler +from 
smartsim._core.mli.mli_schemas.response.response_capnp import ResponseBuilder + +from .utils.channel import FileSystemCommChannel +from .utils.worker import IntegratedTorchWorker + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.fixture(scope="module") +def app_feature_store(the_storage) -> FeatureStore: + # create a standalone feature store to mimic a user application putting + # data into an application-owned resource (app should not access backbone) + app_fs = DragonFeatureStore(the_storage) + return app_fs + + +@pytest.fixture +def setup_worker_manager_model_bytes( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0) + + worker_manager = WorkerManager( + config_loader=config_loader, + worker_type=integrated_worker_type, + dispatcher_queue=dispatcher_task_queue, + as_service=False, + cooldown=3, + ) + + tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + + inf_request = InferenceRequest( + model_key=None, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"model", + batch_size=0, + ) + + model_id = ModelKey(key="key", 
descriptor=app_feature_store.descriptor) + + request_batch = RequestBatch( + [inf_request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_id, + ) + + dispatcher_task_queue.put(request_batch) + return worker_manager, integrated_worker_type + + +@pytest.fixture +def setup_worker_manager_model_key( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0) + + worker_manager = WorkerManager( + config_loader=config_loader, + worker_type=integrated_worker_type, + dispatcher_queue=dispatcher_task_queue, + as_service=False, + cooldown=3, + ) + + tensor_key = TensorKey(key="key", descriptor=app_feature_store.descriptor) + output_key = TensorKey(key="key", descriptor=app_feature_store.descriptor) + model_id = ModelKey(key="model key", descriptor=app_feature_store.descriptor) + + request = InferenceRequest( + model_key=model_id, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"model", + batch_size=0, + ) + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_id, + ) + + dispatcher_task_queue.put(request_batch) + return worker_manager, integrated_worker_type + + +@pytest.fixture +def 
setup_request_dispatcher_model_bytes( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=0, + batch_size=0, + config_loader=config_loader, + worker_type=integrated_worker_type, + ) + request_dispatcher._on_start() + + tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") + request = MessageHandler.build_request( + test_dir, model, [tensor_key], [output_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + + request_dispatcher._incoming_channel.send(ser_request) + + return request_dispatcher, integrated_worker_type + + +@pytest.fixture +def setup_request_dispatcher_model_key( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + 
featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=0, + batch_size=0, + config_loader=config_loader, + worker_type=integrated_worker_type, + ) + request_dispatcher._on_start() + + tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + model_key = MessageHandler.build_model_key( + key="model key", descriptor=app_feature_store.descriptor + ) + request = MessageHandler.build_request( + test_dir, model_key, [tensor_key], [output_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + + request_dispatcher._incoming_channel.send(ser_request) + + return request_dispatcher, integrated_worker_type + + +def mock_pipeline_stage( + monkeypatch: pytest.MonkeyPatch, + integrated_worker: MachineLearningWorkerBase, + stage: str, +) -> t.Callable[[t.Any], ResponseBuilder]: + def mock_stage(*args: t.Any, **kwargs: t.Any) -> None: + raise ValueError(f"Simulated error in {stage}") + + monkeypatch.setattr(integrated_worker, stage, mock_stage) + mock_reply_fn = MagicMock() + mock_response = MagicMock() + mock_response.schema.node.displayName = "Response" + mock_reply_fn.return_value = mock_response + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.error_handling.build_failure_reply", + mock_reply_fn, + ) + + mock_reply_channel = MagicMock() + mock_reply_channel.send = MagicMock() + + def mock_exception_handler( + exc: Exception, reply_channel: CommChannelBase, failure_message: str + ) -> None: + exception_handler(exc, mock_reply_channel, failure_message) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.worker_manager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + 
"smartsim._core.mli.infrastructure.control.request_dispatcher.exception_handler", + mock_exception_handler, + ) + + return mock_reply_fn + + +@pytest.mark.parametrize( + "setup_worker_manager", + [ + pytest.param("setup_worker_manager_model_bytes"), + pytest.param("setup_worker_manager_model_key"), + ], +) +@pytest.mark.parametrize( + "stage, error_message", + [ + pytest.param( + "fetch_model", + "Error loading model on device or getting device.", + id="fetch model", + ), + pytest.param( + "load_model", + "Error loading model on device or getting device.", + id="load model", + ), + pytest.param("execute", "Error while executing.", id="execute"), + pytest.param( + "transform_output", + "Error while transforming the output.", + id="transform output", + ), + pytest.param( + "place_output", "Error while placing the output.", id="place output" + ), + ], +) +def test_wm_pipeline_stage_errors_handled( + request: pytest.FixtureRequest, + setup_worker_manager: str, + monkeypatch: pytest.MonkeyPatch, + stage: str, + error_message: str, +) -> None: + """Ensures that the worker manager does not crash after a failure in various pipeline stages""" + worker_manager, integrated_worker_type = request.getfixturevalue( + setup_worker_manager + ) + integrated_worker = worker_manager._worker + + worker_manager._on_start() + device = worker_manager._device_manager._device + mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage) + + if stage not in ["fetch_model"]: + monkeypatch.setattr( + integrated_worker, + "fetch_model", + MagicMock(return_value=FetchModelResult(b"result_bytes")), + ) + if stage not in ["fetch_model", "load_model"]: + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + device, + "get_model", + MagicMock(return_value=b"result_bytes"), + ) + if stage not in [ + "fetch_model", + "execute", + ]: + monkeypatch.setattr( + integrated_worker, + "execute", + 
MagicMock(return_value=ExecuteResult(b"result_bytes", [slice(0, 1)])), + ) + if stage not in [ + "fetch_model", + "execute", + "transform_output", + ]: + monkeypatch.setattr( + integrated_worker, + "transform_output", + MagicMock( + return_value=[TransformOutputResult(b"result", [], "c", "float32")] + ), + ) + + worker_manager._on_iteration() + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", error_message) + + +@pytest.mark.parametrize( + "setup_request_dispatcher", + [ + pytest.param("setup_request_dispatcher_model_bytes"), + pytest.param("setup_request_dispatcher_model_key"), + ], +) +@pytest.mark.parametrize( + "stage, error_message", + [ + pytest.param( + "fetch_inputs", + "Error fetching input.", + id="fetch input", + ), + pytest.param( + "transform_input", + "Error transforming input.", + id="transform input", + ), + ], +) +def test_dispatcher_pipeline_stage_errors_handled( + request: pytest.FixtureRequest, + setup_request_dispatcher: str, + monkeypatch: pytest.MonkeyPatch, + stage: str, + error_message: str, +) -> None: + """Ensures that the request dispatcher does not crash after a failure in various pipeline stages""" + request_dispatcher, integrated_worker_type = request.getfixturevalue( + setup_request_dispatcher + ) + integrated_worker = request_dispatcher._worker + + mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage) + + if stage not in ["fetch_inputs"]: + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=[FetchInputResult(result=[b"result"], meta=None)]), + ) + + request_dispatcher._on_iteration() + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", error_message) + + +def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensures that the worker manager does not crash after a failure in the + execute pipeline stage""" + + mock_reply_channel = MagicMock() + mock_reply_channel.send = MagicMock() + + 
mock_reply_fn = MagicMock() + + mock_response = MagicMock() + mock_response.schema.node.displayName = "Response" + mock_reply_fn.return_value = mock_response + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.error_handling.build_failure_reply", + mock_reply_fn, + ) + + test_exception = ValueError("Test ValueError") + exception_handler( + test_exception, mock_reply_channel, "Failure while fetching the model." + ) + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", "Failure while fetching the model.") + + +def test_dragon_feature_store_invalid_storage(): + """Verify that attempting to create a DragonFeatureStore without storage fails.""" + storage = None + + with pytest.raises(ValueError) as ex: + DragonFeatureStore(storage) + + assert "storage" in ex.value.args[0].lower() + assert "required" in ex.value.args[0].lower() diff --git a/tests/dragon/test_event_consumer.py b/tests/dragon/test_event_consumer.py new file mode 100644 index 0000000000..8a241bab19 --- /dev/null +++ b/tests/dragon/test_event_consumer.py @@ -0,0 +1,386 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import time +import typing as t +from unittest import mock + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer +from smartsim._core.mli.infrastructure.comm.event import ( + OnCreateConsumer, + OnShutdownRequested, + OnWriteFeatureStore, +) +from smartsim._core.mli.infrastructure.control.listener import ( + ConsumerRegistrationListener, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.log import get_logger + +logger = get_logger(__name__) + +# isort: off +from dragon import fli +from dragon.channels import Channel + +# isort: on + +if t.TYPE_CHECKING: + import conftest + + +# The tests in this file must run in a dragon environment +pytestmark = pytest.mark.dragon + + +def test_eventconsumer_eventpublisher_integration( + the_backbone: t.Any, test_dir: str +) -> None: + """Verify that the publisher and consumer integrate as expected when + multiple publishers and consumers are sending simultaneously. This + test closely tracks the test in tests/test_featurestore_base.py also named + test_eventconsumer_eventpublisher_integration but requires dragon entities. 
+ + :param the_backbone: The BackboneFeatureStore to use + :param test_dir: Automatically generated unique working + directories for individual test outputs + """ + + wmgr_channel = DragonCommChannel(create_local()) + capp_channel = DragonCommChannel(create_local()) + back_channel = DragonCommChannel(create_local()) + + wmgr_consumer_descriptor = wmgr_channel.descriptor + capp_consumer_descriptor = capp_channel.descriptor + back_consumer_descriptor = back_channel.descriptor + + # create some consumers to receive messages + wmgr_consumer = EventConsumer( + wmgr_channel, + the_backbone, + filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN], + ) + capp_consumer = EventConsumer( + capp_channel, + the_backbone, + ) + back_consumer = EventConsumer( + back_channel, + the_backbone, + filters=[OnCreateConsumer.CONSUMER_CREATED], + ) + + # create some broadcasters to publish messages + mock_worker_mgr = EventBroadcaster( + the_backbone, + channel_factory=DragonCommChannel.from_descriptor, + ) + mock_client_app = EventBroadcaster( + the_backbone, + channel_factory=DragonCommChannel.from_descriptor, + ) + + # register all of the consumers even though the OnCreateConsumer really should + # trigger its registration. event processing is tested elsewhere. 
+ the_backbone.notification_channels = [ + wmgr_consumer_descriptor, + capp_consumer_descriptor, + back_consumer_descriptor, + ] + + # simulate worker manager sending a notification to backend that it's alive + event_1 = OnCreateConsumer( + "test_eventconsumer_eventpublisher_integration", + wmgr_consumer_descriptor, + filters=[], + ) + mock_worker_mgr.send(event_1) + + # simulate the app updating a model a few times + for key in ["key-1", "key-2", "key-1"]: + event = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", + the_backbone.descriptor, + key, + ) + mock_client_app.send(event, timeout=0.1) + + # worker manager should only get updates about feature update + wmgr_messages = wmgr_consumer.recv() + assert len(wmgr_messages) == 3 + + # the backend should only receive messages about consumer creation + back_messages = back_consumer.recv() + assert len(back_messages) == 1 + + # hypothetical app has no filters and will get all events + app_messages = capp_consumer.recv() + assert len(app_messages) == 4 + + +@pytest.mark.parametrize( + " timeout, batch_timeout, exp_err_msg", + [(-1, 1, " timeout"), (1, -1, "batch_timeout")], +) +def test_eventconsumer_invalid_timeout( + timeout: float, + batch_timeout: float, + exp_err_msg: str, + test_dir: str, + the_backbone: BackboneFeatureStore, +) -> None: + """Verify that the event consumer raises an exception + when provided an invalid request timeout. 
+ + :param timeout: The request timeout for the event consumer recv call + :param batch_timeout: The batch timeout for the event consumer recv call + :param exp_err_msg: A unique value from the error message that should be raised + :param the_backbone: The BackboneFeatureStore to use + :param test_dir: Automatically generated unique working + directories for individual test outputs + """ + + wmgr_channel = DragonCommChannel(create_local()) + + # create some consumers to receive messages + wmgr_consumer = EventConsumer( + wmgr_channel, + the_backbone, + filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN], + ) + + # the consumer should report an error for the invalid timeout value + with pytest.raises(ValueError) as ex: + wmgr_consumer.recv(timeout=timeout, batch_timeout=batch_timeout) + + assert exp_err_msg in ex.value.args[0] + + +def test_eventconsumer_no_event_handler_registered( + the_backbone: t.Any, test_dir: str +) -> None: + """Verify that a consumer discards messages received + on a channel if no handler is registered. 
+ + :param the_backbone: The BackboneFeatureStore to use + :param test_dir: Automatically generated unique working + directories for individual test outputs + """ + + wmgr_channel = DragonCommChannel(create_local()) + + # create a consumer to receive messages + wmgr_consumer = EventConsumer(wmgr_channel, the_backbone, event_handler=None) + + # create a broadcasters to publish messages + mock_worker_mgr = EventBroadcaster( + the_backbone, + channel_factory=DragonCommChannel.from_descriptor, + ) + + # manually register the consumers since we don't have a backend running + the_backbone.notification_channels = [wmgr_channel.descriptor] + + # simulate the app updating a model a few times + for key in ["key-1", "key-2", "key-1"]: + event = OnWriteFeatureStore( + "test_eventconsumer_no_event_handler_registered", + the_backbone.descriptor, + key, + ) + mock_worker_mgr.send(event, timeout=0.1) + + # run the handler and let it discard messages + for _ in range(15): + wmgr_consumer.listen_once(0.2, 2.0) + + assert wmgr_consumer.listening + + +def test_eventconsumer_no_event_handler_registered_shutdown( + the_backbone: t.Any, test_dir: str +) -> None: + """Verify that a consumer without an event handler + registered still honors shutdown requests. 
+ + :param the_backbone: The BackboneFeatureStore to use + :param test_dir: Automatically generated unique working + directories for individual test outputs + """ + + wmgr_channel = DragonCommChannel(create_local()) + capp_channel = DragonCommChannel(create_local()) + + # create a consumer to receive messages + wmgr_consumer = EventConsumer(wmgr_channel, the_backbone) + + # create a broadcaster to publish messages + mock_worker_mgr = EventBroadcaster( + the_backbone, + channel_factory=DragonCommChannel.from_descriptor, + ) + + # manually register the consumers since we don't have a backend running + the_backbone.notification_channels = [ + wmgr_channel.descriptor, + capp_channel.descriptor, + ] + + # simulate the app updating a model a few times + for key in ["key-1", "key-2", "key-1"]: + event = OnWriteFeatureStore( + "test_eventconsumer_no_event_handler_registered_shutdown", + the_backbone.descriptor, + key, + ) + mock_worker_mgr.send(event, timeout=0.1) + + event = OnShutdownRequested( + "test_eventconsumer_no_event_handler_registered_shutdown" + ) + mock_worker_mgr.send(event, timeout=0.1) + + # wmgr will stop listening to messages when it is told to stop listening + wmgr_consumer.listen(timeout=0.1, batch_timeout=2.0) + + for _ in range(15): + wmgr_consumer.listen_once(timeout=0.1, batch_timeout=2.0) + + # confirm the messages were processed, discarded, and the shutdown was received + assert wmgr_consumer.listening == False + + +def test_eventconsumer_registration( + the_backbone: t.Any, test_dir: str, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that a consumer is correctly registered in + the backbone after sending a registration request. Then, + confirm the consumer is unregistered after sending the + un-register request. 
+ + :param the_backbone: The BackboneFeatureStore to use + :param test_dir: Automatically generated unique working + directories for individual test outputs + """ + + with monkeypatch.context() as patch: + registrar = ConsumerRegistrationListener( + the_backbone, 1.0, 2.0, as_service=False + ) + + # NOTE: service.execute(as_service=False) will complete the service life- + # cycle and remove the registrar from the backbone, so mock _on_shutdown + disabled_shutdown = mock.MagicMock() + patch.setattr(registrar, "_on_shutdown", disabled_shutdown) + + # initialize registrar resources + registrar.execute() + + # create a consumer that will be registered + wmgr_channel = DragonCommChannel(create_local()) + wmgr_consumer = EventConsumer(wmgr_channel, the_backbone) + + registered_channels = the_backbone.notification_channels + + # trigger the consumer-to-registrar handshake + wmgr_consumer.register() + + current_registrations: t.List[str] = [] + + # have the registrar run a few times to pick up the msg + for i in range(15): + registrar.execute() + current_registrations = the_backbone.notification_channels + if len(current_registrations) != len(registered_channels): + logger.debug(f"The event was processed on iteration {i}") + break + + # confirm the consumer is registered + assert wmgr_channel.descriptor in current_registrations + + # copy old list so we can compare against it. 
+ registered_channels = list(current_registrations) + + # trigger the consumer removal + wmgr_consumer.unregister() + + # have the registrar run a few times to pick up the msg + for i in range(15): + registrar.execute() + current_registrations = the_backbone.notification_channels + if len(current_registrations) != len(registered_channels): + logger.debug(f"The event was processed on iteration {i}") + break + + # confirm the consumer is no longer registered + assert wmgr_channel.descriptor not in current_registrations + + +def test_registrar_teardown( + the_backbone: t.Any, test_dir: str, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that the consumer registrar removes itself from + the backbone when it shuts down. + + :param the_backbone: The BackboneFeatureStore to use + :param test_dir: Automatically generated unique working + directories for individual test outputs + """ + + with monkeypatch.context() as patch: + registrar = ConsumerRegistrationListener( + the_backbone, 1.0, 2.0, as_service=False + ) + + # directly initialze registrar resources to avoid service life-cycle + registrar._create_eventing() + + # confirm the registrar is published to the backbone + cfg = the_backbone.wait_for([BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], 10) + assert BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in cfg + + # execute the entire service lifecycle 1x + registrar.execute() + + consumer_found = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in the_backbone + + for i in range(15): + time.sleep(0.1) + consumer_found = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in the_backbone + if not consumer_found: + logger.debug(f"Registrar removed from the backbone on iteration {i}") + break + + assert BackboneFeatureStore.MLI_REGISTRAR_CONSUMER not in the_backbone diff --git a/tests/dragon/test_featurestore.py b/tests/dragon/test_featurestore.py new file mode 100644 index 0000000000..019dcde7a0 --- /dev/null +++ b/tests/dragon/test_featurestore.py @@ -0,0 +1,327 @@ +# BSD 
2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +import multiprocessing as mp +import random +import time +import typing as t +import unittest.mock as mock +import uuid + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + time as bbtime, +) +from smartsim.log import get_logger + +logger = get_logger(__name__) + +# isort: off +from dragon import fli +from dragon.channels import Channel + +# isort: on + +if t.TYPE_CHECKING: + import conftest + + +# The tests in this file must run in a dragon environment +pytestmark = pytest.mark.dragon + + +def test_backbone_wait_for_no_keys( + the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that asking the backbone to wait for a value succeeds + immediately and does not cause a wait to occur if the supplied key + list is empty. + + :param the_backbone: the storage engine to use, prepopulated with + """ + # set a very low timeout to confirm that it does not wait + + with monkeypatch.context() as ctx: + # all keys should be found and the timeout should never be checked. + ctx.setattr(bbtime, "sleep", mock.MagicMock()) + + values = the_backbone.wait_for([]) + assert len(values) == 0 + + # confirm that no wait occurred + bbtime.sleep.assert_not_called() + + +def test_backbone_wait_for_prepopulated( + the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that asking the backbone to wait for a value succeed + immediately and do not cause a wait to occur if the data exists. + + :param the_backbone: the storage engine to use, prepopulated with + """ + # set a very low timeout to confirm that it does not wait + + with monkeypatch.context() as ctx: + # all keys should be found and the timeout should never be checked. 
+ ctx.setattr(bbtime, "sleep", mock.MagicMock()) + + values = the_backbone.wait_for([BackboneFeatureStore.MLI_WORKER_QUEUE], 0.1) + + # confirm that wait_for with one key returns one value + assert len(values) == 1 + + # confirm that the descriptor is non-null w/some non-trivial value + assert len(values[BackboneFeatureStore.MLI_WORKER_QUEUE]) > 5 + + # confirm that no wait occurred + bbtime.sleep.assert_not_called() + + +def test_backbone_wait_for_prepopulated_dupe( + the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that asking the backbone to wait for keys that are duplicated + results in a single value being returned for each key. + + :param the_backbone: the storage engine to use, prepopulated with + """ + # set a very low timeout to confirm that it does not wait + + key1, key2 = "key-1", "key-2" + value1, value2 = "i-am-value-1", "i-am-value-2" + the_backbone[key1] = value1 + the_backbone[key2] = value2 + + with monkeypatch.context() as ctx: + # all keys should be found and the timeout should never be checked. 
+ ctx.setattr(bbtime, "sleep", mock.MagicMock()) + + values = the_backbone.wait_for([key1, key2, key1]) # key1 is duplicated + + # confirm that wait_for with one key returns one value + assert len(values) == 2 + assert key1 in values + assert key2 in values + + assert values[key1] == value1 + assert values[key2] == value2 + + +def set_value_after_delay( + descriptor: str, key: str, value: str, delay: float = 5 +) -> None: + """Helper method to persist a random value into the backbone + + :param descriptor: the backbone feature store descriptor to attach to + :param key: the key to write to + :param value: a value to write to the key + :param delay: amount of delay to apply before writing the key + """ + time.sleep(delay) + + backbone = BackboneFeatureStore.from_descriptor(descriptor) + backbone[key] = value + logger.debug(f"set_value_after_delay wrote `{value} to backbone[`{key}`]") + + +@pytest.mark.parametrize( + "delay", + [ + pytest.param( + 0, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 1, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 2, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 4, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 8, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + ], +) +def test_backbone_wait_for_partial_prepopulated( + the_backbone: BackboneFeatureStore, delay: float +) -> None: + """Verify that when data is not all in the backbone, the `wait_for` operation + continues to poll until it finds everything it needs. 
+ + :param the_backbone: the storage engine to use, prepopulated with + :param delay: the number of seconds the second process will wait before + setting the target value in the backbone featurestore + """ + # set a very low timeout to confirm that it does not wait + wait_timeout = 10 + + key, value = str(uuid.uuid4()), str(random.random() * 10) + + logger.debug(f"Starting process to write {key} after {delay}s") + p = mp.Process( + target=set_value_after_delay, args=(the_backbone.descriptor, key, value, delay) + ) + p.start() + + p2 = mp.Process( + target=the_backbone.wait_for, + args=([BackboneFeatureStore.MLI_WORKER_QUEUE, key],), + kwargs={"timeout": wait_timeout}, + ) + p2.start() + + p.join() + p2.join() + + # both values should be written at this time + ret_vals = the_backbone.wait_for( + [key, BackboneFeatureStore.MLI_WORKER_QUEUE, key], 0.1 + ) + # confirm that wait_for with two keys returns two values + assert len(ret_vals) == 2, "values should contain values for both awaited keys" + + # confirm the pre-populated value has the correct output + assert ( + ret_vals[BackboneFeatureStore.MLI_WORKER_QUEUE] == "12345" + ) # mock descriptor value from fixture + + # confirm the population process completed and the awaited value is correct + assert ret_vals[key] == value, "verify order of values " + + +@pytest.mark.parametrize( + "num_keys", + [ + pytest.param( + 0, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 1, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 3, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 7, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 11, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + ], 
+) +def test_backbone_wait_for_multikey( + the_backbone: BackboneFeatureStore, + num_keys: int, + test_dir: str, +) -> None: + """Verify that asking the backbone to wait for multiple keys results + in that number of values being returned. + + :param the_backbone: the storage engine to use, prepopulated with + :param num_keys: the number of extra keys to set & request in the backbone + """ + # maximum delay allowed for setter processes + max_delay = 5 + + extra_keys = [str(uuid.uuid4()) for _ in range(num_keys)] + extra_values = [str(uuid.uuid4()) for _ in range(num_keys)] + extras = dict(zip(extra_keys, extra_values)) + delays = [random.random() * max_delay for _ in range(num_keys)] + processes = [] + + for key, value, delay in zip(extra_keys, extra_values, delays): + assert delay < max_delay, "write delay exceeds test timeout" + logger.debug(f"Delaying {key} write by {delay} seconds") + p = mp.Process( + target=set_value_after_delay, + args=(the_backbone.descriptor, key, value, delay), + ) + p.start() + processes.append(p) + + p2 = mp.Process( + target=the_backbone.wait_for, + args=(extra_keys,), + kwargs={"timeout": max_delay * 2}, + ) + p2.start() + for p in processes: + p.join(timeout=max_delay * 2) + p2.join( + timeout=max_delay * 2 + ) # give it 10 seconds longer than p2 timeout for backoff + + # use without a wait to verify all values are written + num_keys = len(extra_keys) + actual_values = the_backbone.wait_for(extra_keys, timeout=0.01) + assert len(extra_keys) == num_keys + + # confirm that wait_for returns all the expected values + assert len(actual_values) == num_keys + + # confirm that the returned values match (e.g. 
are returned in the right order) + for k in extras: + assert extras[k] == actual_values[k] diff --git a/tests/dragon/test_featurestore_base.py b/tests/dragon/test_featurestore_base.py new file mode 100644 index 0000000000..6daceb9061 --- /dev/null +++ b/tests/dragon/test_featurestore_base.py @@ -0,0 +1,844 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import pathlib +import time +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer +from smartsim._core.mli.infrastructure.comm.event import ( + OnCreateConsumer, + OnWriteFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.feature_store import ReservedKeys +from smartsim.error import SmartSimError + +from .channel import FileSystemCommChannel +from .feature_store import MemoryFeatureStore + +if t.TYPE_CHECKING: + import conftest + + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +def boom(*args, **kwargs) -> None: + """Helper function that blows up when used to mock up + some other function.""" + raise Exception(f"you shall not pass! 
{args}, {kwargs}") + + +def test_event_uid() -> None: + """Verify that all events include a unique identifier.""" + uids: t.Set[str] = set() + num_iters = 1000 + + # generate a bunch of events and keep track all the IDs + for i in range(num_iters): + event_a = OnCreateConsumer("test_event_uid", str(i), filters=[]) + event_b = OnWriteFeatureStore("test_event_uid", "test_event_uid", str(i)) + + uids.add(event_a.uid) + uids.add(event_b.uid) + + # verify each event created a unique ID + assert len(uids) == 2 * num_iters + + +def test_mli_reserved_keys_conversion() -> None: + """Verify that conversion from a string to an enum member + works as expected.""" + + for reserved_key in ReservedKeys: + # iterate through all keys and verify `from_string` works + assert ReservedKeys.contains(reserved_key.value) + + # show that the value (actual key) not the enum member name + # will not be incorrectly identified as reserved + assert not ReservedKeys.contains(str(reserved_key).split(".")[1]) + + +def test_mli_reserved_keys_writes() -> None: + """Verify that attempts to write to reserved keys are blocked from a + standard DragonFeatureStore but enabled with the BackboneFeatureStore.""" + + mock_storage = {} + dfs = DragonFeatureStore(mock_storage) + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + other = MemoryFeatureStore(mock_storage) + + expected_value = "value" + + for reserved_key in ReservedKeys: + # we expect every reserved key to fail using DragonFeatureStore... + with pytest.raises(SmartSimError) as ex: + dfs[reserved_key] = expected_value + + assert "reserved key" in ex.value.args[0] + + # ... 
and expect other feature stores to respect reserved keys + with pytest.raises(SmartSimError) as ex: + other[reserved_key] = expected_value + + assert "reserved key" in ex.value.args[0] + + # ...and those same keys to succeed on the backbone + backbone[reserved_key] = expected_value + actual_value = backbone[reserved_key] + assert actual_value == expected_value + + +def test_mli_consumers_read_by_key() -> None: + """Verify that the value returned from the mli consumers method is written + to the correct key and reads are allowed via standard dragon feature store.""" + + mock_storage = {} + dfs = DragonFeatureStore(mock_storage) + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + other = MemoryFeatureStore(mock_storage) + + expected_value = "value" + + # write using backbone that has permission to write reserved keys + backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] = expected_value + + # confirm read-only access to reserved keys from any FeatureStore + for fs in [dfs, backbone, other]: + assert fs[ReservedKeys.MLI_NOTIFY_CONSUMERS] == expected_value + + +def test_mli_consumers_read_by_backbone() -> None: + """Verify that the backbone reads the correct location + when using the backbone feature store API instead of mapping API.""" + + mock_storage = {} + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + expected_value = "value" + + backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] = expected_value + + # confirm reading via convenience method returns expected value + assert backbone.notification_channels[0] == expected_value + + +def test_mli_consumers_write_by_backbone() -> None: + """Verify that the backbone writes the correct location + when using the backbone feature store API instead of mapping API.""" + + mock_storage = {} + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + expected_value = ["value"] + + backbone.notification_channels = expected_value + + # confirm write using convenience method 
targets expected key + assert backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] == ",".join(expected_value) + + +def test_eventpublisher_broadcast_no_factory(test_dir: str) -> None: + """Verify that a broadcast operation without any registered subscribers + succeeds without raising Exceptions. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + consumer_descriptor = storage_path / "test-consumer" + + # NOTE: we're not putting any consumers into the backbone here! + backbone = BackboneFeatureStore(mock_storage) + + event = OnCreateConsumer( + "test_eventpublisher_broadcast_no_factory", consumer_descriptor, filters=[] + ) + + publisher = EventBroadcaster(backbone) + num_receivers = 0 + + # publishing this event without any known consumers registered should succeed + # but report that it didn't have anybody to send the event to + consumer_descriptor = storage_path / f"test-consumer" + event = OnCreateConsumer( + "test_eventpublisher_broadcast_no_factory", consumer_descriptor, filters=[] + ) + + num_receivers += publisher.send(event) + + # confirm no changes to the backbone occur when fetching the empty consumer key + key_in_features_store = ReservedKeys.MLI_NOTIFY_CONSUMERS in backbone + assert not key_in_features_store + + # confirm that the broadcast reports no events published + assert num_receivers == 0 + # confirm that the broadcast buffered the event for a later send + assert publisher.num_buffered == 1 + + +def test_eventpublisher_broadcast_to_empty_consumer_list(test_dir: str) -> None: + """Verify that a broadcast operation without any registered subscribers + succeeds without raising Exceptions. 
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + + # note: file-system descriptors are just paths + consumer_descriptor = storage_path / "test-consumer" + + # prep our backbone with a consumer list + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + backbone.notification_channels = [] + + event = OnCreateConsumer( + "test_eventpublisher_broadcast_to_empty_consumer_list", + consumer_descriptor, + filters=[], + ) + publisher = EventBroadcaster( + backbone, channel_factory=FileSystemCommChannel.from_descriptor + ) + num_receivers = publisher.send(event) + + registered_consumers = backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] + + # confirm that no consumers exist in backbone to send to + assert not registered_consumers + # confirm that the broadcast reports no events published + assert num_receivers == 0 + # confirm that the broadcast buffered the event for a later send + assert publisher.num_buffered == 1 + + +def test_eventpublisher_broadcast_without_channel_factory(test_dir: str) -> None: + """Verify that a broadcast operation reports an error if no channel + factory was supplied for constructing the consumer channels. 
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + + # note: file-system descriptors are just paths + consumer_descriptor = storage_path / "test-consumer" + + # prep our backbone with a consumer list + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + backbone.notification_channels = [consumer_descriptor] + + event = OnCreateConsumer( + "test_eventpublisher_broadcast_without_channel_factory", + consumer_descriptor, + filters=[], + ) + publisher = EventBroadcaster( + backbone, + # channel_factory=FileSystemCommChannel.from_descriptor # <--- not supplied + ) + + with pytest.raises(SmartSimError) as ex: + publisher.send(event) + + assert "factory" in ex.value.args[0] + + +def test_eventpublisher_broadcast_empties_buffer(test_dir: str) -> None: + """Verify that a successful broadcast clears messages from the event + buffer when a new message is sent and consumers are registered. 
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    consumer_descriptor = storage_path / "test-consumer"
+
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+    backbone.notification_channels = (consumer_descriptor,)
+
+    publisher = EventBroadcaster(
+        backbone, channel_factory=FileSystemCommChannel.from_descriptor
+    )
+
+    # mock building up some buffered events
+    num_buffered_events = 14
+    for i in range(num_buffered_events):
+        event = OnCreateConsumer(
+            "test_eventpublisher_broadcast_empties_buffer",
+            storage_path / f"test-consumer-{str(i)}",
+            [],
+        )
+        publisher._event_buffer.append(bytes(event))
+
+    event0 = OnCreateConsumer(
+        "test_eventpublisher_broadcast_empties_buffer",
+        storage_path / f"test-consumer-{str(num_buffered_events + 1)}",
+        [],
+    )
+
+    num_receivers = publisher.send(event0)
+    # 1 receiver x 15 total events == 15 events
+    assert num_receivers == num_buffered_events + 1
+
+
+@pytest.mark.parametrize(
+    "num_consumers, num_buffered, expected_num_sent",
+    [
+        pytest.param(0, 7, 0, id="0 x (7+1) - no consumers, multi-buffer"),
+        pytest.param(1, 7, 8, id="1 x (7+1) - single consumer, multi-buffer"),
+        pytest.param(2, 7, 16, id="2 x (7+1) - multi-consumer, multi-buffer"),
+        pytest.param(4, 4, 20, id="4 x (4+1) - multi-consumer, multi-buffer (odd #)"),
+        pytest.param(9, 0, 9, id="9 x (0+1) - multi-consumer, empty buffer"),
+    ],
+)
+def test_eventpublisher_broadcast_returns_total_sent(
+    test_dir: str, num_consumers: int, num_buffered: int, expected_num_sent: int
+) -> None:
+    """Verify that a successful broadcast returns the total number of events
+    sent, including buffered messages.
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param num_consumers: the number of consumers to mock setting up prior to send + :param num_buffered: the number of pre-buffered events to mock up + :param expected_num_sent: the expected result from calling send + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + + # note: file-system descriptors are just paths + consumers = [] + for i in range(num_consumers): + consumers.append(storage_path / f"test-consumer-{i}") + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + backbone.notification_channels = consumers + + publisher = EventBroadcaster( + backbone, channel_factory=FileSystemCommChannel.from_descriptor + ) + + # mock building up some buffered events + for i in range(num_buffered): + event = OnCreateConsumer( + "test_eventpublisher_broadcast_returns_total_sent", + storage_path / f"test-consumer-{str(i)}", + [], + ) + publisher._event_buffer.append(bytes(event)) + + assert publisher.num_buffered == num_buffered + + # this event will trigger clearing anything already in buffer + event0 = OnCreateConsumer( + "test_eventpublisher_broadcast_returns_total_sent", + storage_path / f"test-consumer-{num_buffered}", + [], + ) + + # num_receivers should contain a number that computes w/all consumers and all events + num_receivers = publisher.send(event0) + + assert num_receivers == expected_num_sent + + +def test_eventpublisher_prune_unused_consumer(test_dir: str) -> None: + """Verify that any unused consumers are pruned each time a new event is sent. 
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + + # note: file-system descriptors are just paths + consumer_descriptor = storage_path / "test-consumer" + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + + publisher = EventBroadcaster( + backbone, channel_factory=FileSystemCommChannel.from_descriptor + ) + + event = OnCreateConsumer( + "test_eventpublisher_prune_unused_consumer", + consumer_descriptor, + filters=[], + ) + + # the only registered cnosumer is in the event, expect no pruning + backbone.notification_channels = (consumer_descriptor,) + + publisher.send(event) + assert str(consumer_descriptor) in publisher._channel_cache + assert len(publisher._channel_cache) == 1 + + # add a new descriptor for another event... + consumer_descriptor2 = storage_path / "test-consumer-2" + # ... and remove the old descriptor from the backbone when it's looked up + backbone.notification_channels = (consumer_descriptor2,) + + event = OnCreateConsumer( + "test_eventpublisher_prune_unused_consumer", consumer_descriptor2, filters=[] + ) + + publisher.send(event) + + assert str(consumer_descriptor2) in publisher._channel_cache + assert str(consumer_descriptor) not in publisher._channel_cache + assert len(publisher._channel_cache) == 1 + + # test multi-consumer pruning by caching some extra channels + prune0, prune1, prune2 = "abc", "def", "ghi" + publisher._channel_cache[prune0] = "doesnt-matter-if-it-is-pruned" + publisher._channel_cache[prune1] = "doesnt-matter-if-it-is-pruned" + publisher._channel_cache[prune2] = "doesnt-matter-if-it-is-pruned" + + # add in one of our old channels so we prune the above items, send to these + backbone.notification_channels = (consumer_descriptor, consumer_descriptor2) + + publisher.send(event) + + assert str(consumer_descriptor2) in publisher._channel_cache + + # NOTE: 
we should NOT prune something that isn't used by this message but + # does appear in `backbone.notification_channels` + assert str(consumer_descriptor) in publisher._channel_cache + + # confirm all of our items that were not in the notification channels are gone + for pruned in [prune0, prune1, prune2]: + assert pruned not in publisher._channel_cache + + # confirm we have only the two expected items in the channel cache + assert len(publisher._channel_cache) == 2 + + +def test_eventpublisher_serialize_failure( + test_dir: str, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that errors during message serialization are raised to the caller. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param monkeypatch: pytest fixture for modifying behavior of existing code + with mock implementations + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + publisher = EventBroadcaster( + backbone, channel_factory=FileSystemCommChannel.from_descriptor + ) + + with monkeypatch.context() as patch: + event = OnCreateConsumer( + "test_eventpublisher_serialize_failure", target_descriptor, filters=[] + ) + + # patch the __bytes__ implementation to cause pickling to fail during send + def bad_bytes(self) -> bytes: + return b"abc" + + # this patch causes an attribute error when event pickling is attempted + patch.setattr(event, "__bytes__", bad_bytes) + + backbone.notification_channels = (target_descriptor,) + + # send a message into the channel + with pytest.raises(AttributeError) as ex: + publisher.send(event) + + assert "serialize" in ex.value.args[0] + + +def test_eventpublisher_factory_failure( + test_dir: str, monkeypatch: pytest.MonkeyPatch +) 
-> None: + """Verify that errors during channel construction are raised to the caller. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param monkeypatch: pytest fixture for modifying behavior of existing code + with mock implementations + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + def boom(descriptor: str) -> None: + raise Exception(f"you shall not pass! {descriptor}") + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + publisher = EventBroadcaster(backbone, channel_factory=boom) + + with monkeypatch.context() as patch: + event = OnCreateConsumer( + "test_eventpublisher_factory_failure", target_descriptor, filters=[] + ) + + backbone.notification_channels = (target_descriptor,) + + # send a message into the channel + with pytest.raises(SmartSimError) as ex: + publisher.send(event) + + assert "construct" in ex.value.args[0] + + +def test_eventpublisher_failure(test_dir: str, monkeypatch: pytest.MonkeyPatch) -> None: + """Verify that unexpected errors during message send are caught and wrapped in a + SmartSimError so they are not propagated directly to the caller. 
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param monkeypatch: pytest fixture for modifying behavior of existing code + with mock implementations + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + publisher = EventBroadcaster( + backbone, channel_factory=FileSystemCommChannel.from_descriptor + ) + + def boom(self) -> None: + raise Exception("That was unexpected...") + + with monkeypatch.context() as patch: + event = OnCreateConsumer( + "test_eventpublisher_failure", target_descriptor, filters=[] + ) + + # patch the _broadcast implementation to cause send to fail after + # after the event has been pickled + patch.setattr(publisher, "_broadcast", boom) + + backbone.notification_channels = (target_descriptor,) + + # Here, we see the exception raised by broadcast that isn't expected + # is not allowed directly out, and instead is wrapped in SmartSimError + with pytest.raises(SmartSimError) as ex: + publisher.send(event) + + assert "unexpected" in ex.value.args[0] + + +def test_eventconsumer_receive(test_dir: str) -> None: + """Verify that a consumer retrieves a message from the given channel. 
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    storage_path.mkdir(parents=True, exist_ok=True)
+
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    target_descriptor = str(storage_path / "test-consumer")
+
+    backbone = BackboneFeatureStore(mock_storage)
+    comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor)
+    event = OnCreateConsumer(
+        "test_eventconsumer_receive", target_descriptor, filters=[]
+    )
+
+    # simulate a sent event by writing directly to the input comm channel
+    comm_channel.send(bytes(event))
+
+    consumer = EventConsumer(comm_channel, backbone)
+
+    all_received: t.List[OnCreateConsumer] = consumer.recv()
+    assert len(all_received) == 1
+
+    # verify we received the same event that was raised
+    assert all_received[0].category == event.category
+    assert all_received[0].descriptor == event.descriptor
+
+
+@pytest.mark.parametrize("num_sent", [0, 1, 2, 4, 8, 16])
+def test_eventconsumer_receive_multi(test_dir: str, num_sent: int) -> None:
+    """Verify that a consumer retrieves multiple messages from the given channel.
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param num_sent: parameterized value used to vary the number of events + that are enqueued and validations are checked at multiple queue sizes + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + backbone = BackboneFeatureStore(mock_storage) + comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor) + + # simulate multiple sent events by writing directly to the input comm channel + for _ in range(num_sent): + event = OnCreateConsumer( + "test_eventconsumer_receive_multi", target_descriptor, filters=[] + ) + comm_channel.send(bytes(event)) + + consumer = EventConsumer(comm_channel, backbone) + + all_received: t.List[OnCreateConsumer] = consumer.recv() + assert len(all_received) == num_sent + + +def test_eventconsumer_receive_empty(test_dir: str) -> None: + """Verify that a consumer receiving an empty message ignores the + message and continues processing. 
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + backbone = BackboneFeatureStore(mock_storage) + comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor) + + # simulate a sent event by writing directly to the input comm channel + comm_channel.send(bytes(b"")) + + consumer = EventConsumer(comm_channel, backbone) + + messages = consumer.recv() + + # the messages array should be empty + assert not messages + + +def test_eventconsumer_eventpublisher_integration(test_dir: str) -> None: + """Verify that the publisher and consumer integrate as expected when + multiple publishers and consumers are sending simultaneously. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + mock_fs_descriptor = str(storage_path / f"mock-feature-store") + + wmgr_channel = FileSystemCommChannel(storage_path / "test-wmgr") + capp_channel = FileSystemCommChannel(storage_path / "test-capp") + back_channel = FileSystemCommChannel(storage_path / "test-backend") + + wmgr_consumer_descriptor = wmgr_channel.descriptor + capp_consumer_descriptor = capp_channel.descriptor + back_consumer_descriptor = back_channel.descriptor + + # create some consumers to receive messages + wmgr_consumer = EventConsumer( + wmgr_channel, + backbone, + filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN], + ) + capp_consumer = EventConsumer( + capp_channel, + backbone, + ) + back_consumer = EventConsumer( + back_channel, + backbone, 
+ filters=[OnCreateConsumer.CONSUMER_CREATED], + ) + + # create some broadcasters to publish messages + mock_worker_mgr = EventBroadcaster( + backbone, + channel_factory=FileSystemCommChannel.from_descriptor, + ) + mock_client_app = EventBroadcaster( + backbone, + channel_factory=FileSystemCommChannel.from_descriptor, + ) + + # register all of the consumers even though the OnCreateConsumer really should + # trigger its registration. event processing is tested elsewhere. + backbone.notification_channels = [ + wmgr_consumer_descriptor, + capp_consumer_descriptor, + back_consumer_descriptor, + ] + + # simulate worker manager sending a notification to backend that it's alive + event_1 = OnCreateConsumer( + "test_eventconsumer_eventpublisher_integration", + wmgr_consumer_descriptor, + filters=[], + ) + mock_worker_mgr.send(event_1) + + # simulate the app updating a model a few times + event_2 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-1" + ) + event_3 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-2" + ) + event_4 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-1" + ) + + mock_client_app.send(event_2) + mock_client_app.send(event_3) + mock_client_app.send(event_4) + + # worker manager should only get updates about feature update + wmgr_messages = wmgr_consumer.recv() + assert len(wmgr_messages) == 3 + + # the backend should only receive messages about consumer creation + back_messages = back_consumer.recv() + assert len(back_messages) == 1 + + # hypothetical app has no filters and will get all events + app_messages = capp_consumer.recv() + assert len(app_messages) == 4 + + +@pytest.mark.parametrize("invalid_timeout", [-100.0, -1.0, 0.0]) +def test_eventconsumer_batch_timeout( + invalid_timeout: float, + test_dir: str, +) -> None: + """Verify that a consumer allows only positive, non-zero values for timeout + if it 
is supplied. + + :param invalid_timeout: any invalid timeout that should fail validation + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + backbone = BackboneFeatureStore(mock_storage) + + channel = FileSystemCommChannel(storage_path / "test-wmgr") + + with pytest.raises(ValueError) as ex: + # try to create a consumer w/a max recv size of 0 + consumer = EventConsumer( + channel, + backbone, + filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN], + ) + consumer.recv(batch_timeout=invalid_timeout) + + assert "positive" in ex.value.args[0] + + +@pytest.mark.parametrize( + "wait_timeout, exp_wait_max", + [ + # aggregate the 1+1+1 into 3 on remaining parameters + pytest.param(1, 1 + 1 + 1, id="1s wait, 3 cycle steps"), + pytest.param(2, 3 + 2, id="2s wait, 4 cycle steps"), + pytest.param(4, 3 + 2 + 4, id="4s wait, 5 cycle steps"), + pytest.param(9, 3 + 2 + 4 + 8, id="9s wait, 6 cycle steps"), + # aggregate an entire cycle into 16 + pytest.param(19.5, 16 + 3 + 2 + 4, id="20s wait, repeat cycle"), + ], +) +def test_backbone_wait_timeout(wait_timeout: float, exp_wait_max: float) -> None: + """Verify that attempts to attach to the worker queue from the protoclient + timeout in an appropriate amount of time. Note: due to the backoff, we verify + the elapsed time is less than the 15s of a cycle of waits. 
+ + :param wait_timeout: Maximum amount of time (in seconds) to allow the backbone + to wait for the requested value to exist + :param exp_wait_max: Maximum amount of time (in seconds) to set as the upper + bound to allow the delays with backoff to occur + :param storage_for_dragon_fs: the dragon storage engine to use + """ + + # NOTE: exp_wait_time maps to the cycled backoff of [0.1, 0.2, 0.4, 0.8] + # with leeway added (by allowing 1s each for the 0.1 and 0.5 steps) + start_time = time.time() + + storage = {} + backbone = BackboneFeatureStore(storage) + + with pytest.raises(SmartSimError) as ex: + backbone.wait_for(["does-not-exist"], wait_timeout) + + assert "timeout" in str(ex.value.args[0]).lower() + + end_time = time.time() + elapsed = end_time - start_time + + # confirm that we met our timeout + assert elapsed > wait_timeout, f"below configured timeout {wait_timeout}" + + # confirm that the total wait time is aligned with the sleep cycle + assert elapsed < exp_wait_max, f"above expected max wait {exp_wait_max}" diff --git a/tests/dragon/test_featurestore_integration.py b/tests/dragon/test_featurestore_integration.py new file mode 100644 index 0000000000..23fdc55ab6 --- /dev/null +++ b/tests/dragon/test_featurestore_integration.py @@ -0,0 +1,213 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_util import ( + DEFAULT_CHANNEL_BUFFER_SIZE, + create_local, +) +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer +from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) + +# isort: off +from dragon.channels import Channel + +# isort: on + +if t.TYPE_CHECKING: + import conftest + + +# The tests in this file must run in a dragon environment +pytestmark = pytest.mark.dragon + + +@pytest.fixture(scope="module") +def the_worker_channel() -> DragonCommChannel: + """Fixture to create a valid descriptor for a worker channel + that can be attached to.""" + wmgr_channel_ = create_local() + wmgr_channel = DragonCommChannel(wmgr_channel_) + return wmgr_channel + + +@pytest.mark.parametrize( + "num_events, batch_timeout, max_batches_expected", + [ + pytest.param(1, 1.0, 
2, id="under 1s timeout"), + pytest.param(20, 1.0, 3, id="test 1s timeout 20x"), + pytest.param(30, 0.2, 5, id="test 0.2s timeout 30x"), + pytest.param(60, 0.4, 4, id="small batches"), + pytest.param(100, 0.1, 10, id="many small batches"), + ], +) +def test_eventconsumer_max_dequeue( + num_events: int, + batch_timeout: float, + max_batches_expected: int, + the_worker_channel: DragonCommChannel, + the_backbone: BackboneFeatureStore, +) -> None: + """Verify that a consumer does not sit and collect messages indefinitely + by checking that a consumer returns after a maximum timeout is exceeded. + + :param num_events: Total number of events to raise in the test + :param batch_timeout: Maximum wait time (in seconds) for a message to be sent + :param max_batches_expected: Maximum number of receives that should occur + :param the_storage: Dragon storage engine to use + """ + + # create some consumers to receive messages + wmgr_consumer = EventConsumer( + the_worker_channel, + the_backbone, + filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN], + ) + + # create a broadcaster to publish messages + mock_client_app = EventBroadcaster( + the_backbone, + channel_factory=DragonCommChannel.from_descriptor, + ) + + # register all of the consumers even though the OnCreateConsumer really should + # trigger its registration. event processing is tested elsewhere. 
+    the_backbone.notification_channels = [the_worker_channel.descriptor]
+
+    # simulate the app updating a model a lot of times
+    for key in (f"key-{i}" for i in range(num_events)):
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_max_dequeue", the_backbone.descriptor, key
+        )
+        mock_client_app.send(event, timeout=0.01)
+
+    num_dequeued = 0
+    num_batches = 0
+
+    while wmgr_messages := wmgr_consumer.recv(
+        timeout=0.1,
+        batch_timeout=batch_timeout,
+    ):
+        # worker manager should not get more than `max_num_msgs` events
+        num_dequeued += len(wmgr_messages)
+        num_batches += 1
+
+    # make sure we made all the expected dequeue calls and got everything
+    assert num_dequeued == num_events
+    assert num_batches > 0
+    assert num_batches < max_batches_expected, "too many recv calls were made"
+
+
+@pytest.mark.parametrize(
+    "buffer_size",
+    [
+        pytest.param(
+            -1,
+            id="replace negative, default to 500",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            0,
+            id="replace zero, default to 500",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            1,
+            id="non-zero buffer size: 1",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        # pytest.param(500, id="maximum size edge case: 500"),
+        pytest.param(
+            550,
+            id="larger than default: 550",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            800,
+            id="much larger than default: 800",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            1000,
+            id="very large buffer: 1000, unreliable in dragon-v0.10",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+    ],
+)
+def test_channel_buffer_size(
+    buffer_size: int,
+    the_storage: t.Any,
+) -> None:
+    """Verify that a channel used by an EventBroadcaster can buffer messages
+    until a configured maximum value is exceeded.
+ + :param buffer_size: Maximum number of messages allowed in a channel buffer + :param the_storage: The dragon storage engine to use + """ + + mock_storage = the_storage + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + + wmgr_channel_ = create_local(buffer_size) # <--- vary buffer size + wmgr_channel = DragonCommChannel(wmgr_channel_) + wmgr_consumer_descriptor = wmgr_channel.descriptor + + # create a broadcaster to publish messages. create no consumers to + # push the number of sent messages past the allotted buffer size + mock_client_app = EventBroadcaster( + backbone, + channel_factory=DragonCommChannel.from_descriptor, + ) + + # register all of the consumers even though the OnCreateConsumer really should + # trigger its registration. event processing is tested elsewhere. + backbone.notification_channels = [wmgr_consumer_descriptor] + + if buffer_size < 1: + # NOTE: we set this after creating the channel above to ensure + # the default parameter value was used during instantiation + buffer_size = DEFAULT_CHANNEL_BUFFER_SIZE + + # simulate the app updating a model a lot of times + for key in (f"key-{i}" for i in range(buffer_size)): + event = OnWriteFeatureStore( + "test_channel_buffer_size", backbone.descriptor, key + ) + mock_client_app.send(event, timeout=0.01) + + # adding 1 more over the configured buffer size should report the error + with pytest.raises(Exception) as ex: + mock_client_app.send(event, timeout=0.01) diff --git a/tests/dragon/test_inference_reply.py b/tests/dragon/test_inference_reply.py new file mode 100644 index 0000000000..bdc7be14bc --- /dev/null +++ b/tests/dragon/test_inference_reply.py @@ -0,0 +1,76 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.infrastructure.storage.feature_store import TensorKey
+from smartsim._core.mli.infrastructure.worker.worker import InferenceReply
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+handler = MessageHandler()
+
+
+@pytest.fixture
+def inference_reply() -> InferenceReply:
+    return InferenceReply()
+
+
+@pytest.fixture
+def fs_key() -> TensorKey:
+    return TensorKey("key", "descriptor")
+
+
+@pytest.mark.parametrize(
+    "outputs, expected",
+    [
+        ([b"output bytes"], True),
+        (None, False),
+        ([], False),
+    ],
+)
+def test_has_outputs(monkeypatch, inference_reply, outputs, expected):
+    """Test the has_outputs property with different values for outputs."""
+    monkeypatch.setattr(inference_reply, "outputs", outputs)
+    assert inference_reply.has_outputs == expected
+
+
+@pytest.mark.parametrize(
+    "output_keys, expected",
+    [
+        ([TensorKey("key", "descriptor")], True),  # fixtures aren't resolved here
+        (None, False),
+        ([], False),
+    ],
+)
+def test_has_output_keys(monkeypatch, inference_reply, output_keys, expected):
+    """Test the has_output_keys property with different values for output_keys."""
+    monkeypatch.setattr(inference_reply, "output_keys", output_keys)
+    assert inference_reply.has_output_keys == expected
diff --git a/tests/dragon/test_inference_request.py b/tests/dragon/test_inference_request.py
new file mode 100644
index 0000000000..f5c8b9bdc7
--- /dev/null
+++ b/tests/dragon/test_inference_request.py
@@ -0,0 +1,118 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2.
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.storage.feature_store import TensorKey +from smartsim._core.mli.infrastructure.worker.worker import InferenceRequest +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +handler = MessageHandler() + + +@pytest.fixture +def inference_request() -> InferenceRequest: + return InferenceRequest() + + +@pytest.fixture +def fs_key() -> TensorKey: + return TensorKey("key", "descriptor") + + +@pytest.mark.parametrize( + "raw_model, expected", + [ + (handler.build_model(b"bytes", "Model Name", "V1"), True), + (None, False), + ], +) +def test_has_raw_model(monkeypatch, inference_request, raw_model, expected): + """Test the has_raw_model property with different values for raw_model.""" + monkeypatch.setattr(inference_request, "raw_model", raw_model) + assert inference_request.has_raw_model == expected 
+ + +@pytest.mark.parametrize( + "model_key, expected", + [ + (fs_key, True), + (None, False), + ], +) +def test_has_model_key(monkeypatch, inference_request, model_key, expected): + """Test the has_model_key property with different values for model_key.""" + monkeypatch.setattr(inference_request, "model_key", model_key) + assert inference_request.has_model_key == expected + + +@pytest.mark.parametrize( + "raw_inputs, expected", + [([b"raw input bytes"], True), (None, False), ([], False)], +) +def test_has_raw_inputs(monkeypatch, inference_request, raw_inputs, expected): + """Test the has_raw_inputs property with different values for raw_inputs.""" + monkeypatch.setattr(inference_request, "raw_inputs", raw_inputs) + assert inference_request.has_raw_inputs == expected + + +@pytest.mark.parametrize( + "input_keys, expected", + [([fs_key], True), (None, False), ([], False)], +) +def test_has_input_keys(monkeypatch, inference_request, input_keys, expected): + """Test the has_input_keys property with different values for input_keys.""" + monkeypatch.setattr(inference_request, "input_keys", input_keys) + assert inference_request.has_input_keys == expected + + +@pytest.mark.parametrize( + "output_keys, expected", + [([fs_key], True), (None, False), ([], False)], +) +def test_has_output_keys(monkeypatch, inference_request, output_keys, expected): + """Test the has_output_keys property with different values for output_keys.""" + monkeypatch.setattr(inference_request, "output_keys", output_keys) + assert inference_request.has_output_keys == expected + + +@pytest.mark.parametrize( + "input_meta, expected", + [ + ([handler.build_tensor_descriptor("c", "float32", [1, 2, 3])], True), + (None, False), + ([], False), + ], +) +def test_has_input_meta(monkeypatch, inference_request, input_meta, expected): + """Test the has_input_meta property with different values for input_meta.""" + monkeypatch.setattr(inference_request, "input_meta", input_meta) + assert 
inference_request.has_input_meta == expected diff --git a/tests/dragon/test_protoclient.py b/tests/dragon/test_protoclient.py new file mode 100644 index 0000000000..f84417107d --- /dev/null +++ b/tests/dragon/test_protoclient.py @@ -0,0 +1,313 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+import os
+import pickle
+import time
+import typing as t
+from unittest.mock import MagicMock
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+# isort: off
+from dragon import fli
+from dragon.data.ddict.ddict import DDict
+
+# from ..ex..high_throughput_inference.mock_app import ProtoClient
+from smartsim._core.mli.client.protoclient import ProtoClient
+
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+WORK_QUEUE_KEY = BackboneFeatureStore.MLI_WORKER_QUEUE
+logger = get_logger(__name__)
+
+
+@pytest.fixture(scope="module")
+def the_worker_queue(the_backbone: BackboneFeatureStore) -> DragonFLIChannel:
+    """Fixture that creates a dragon FLI channel as a stand-in for the
+    worker queue created by the worker.
+
+    :param the_backbone: The backbone feature store to update
+    with the worker queue descriptor.
+
+    :returns: The attached `DragonFLIChannel`
+    """
+
+    # create the FLI
+    to_worker_channel = create_local()
+    fli_ = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None)
+    comm_channel = DragonFLIChannel(fli_)
+
+    # store the descriptor in the backbone
+    the_backbone.worker_queue = comm_channel.descriptor
+
+    try:
+        comm_channel.send(b"foo")
+    except Exception:
+        logger.exception("Test send from worker channel failed")
+
+    return comm_channel
+
+
+@pytest.mark.parametrize(
+    "backbone_timeout, exp_wait_max",
+    [
+        # aggregate the 1+1+1 into 3 on remaining parameters
+        pytest.param(0.5, 1 + 1 + 1, id="0.5s wait, 3 cycle steps"),
+        pytest.param(2, 3 + 2, id="2s wait, 4 cycle steps"),
+        pytest.param(4, 3 + 2 + 4, id="4s wait, 5 cycle steps"),
+    ],
+)
+def test_protoclient_timeout(
+    backbone_timeout: float,
+    exp_wait_max: float,
+    the_backbone: BackboneFeatureStore,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Verify that attempts to attach to the worker queue from the protoclient
+    timeout in an appropriate amount of time. Note: due to the backoff, we verify
+    the elapsed time is less than the 15s of a cycle of waits.
+ + :param backbone_timeout: a timeout for use when configuring a proto client + :param exp_wait_max: a ceiling for the expected time spent waiting for + the timeout + :param the_backbone: a pre-initialized backbone featurestore for setting up + the environment variable required by the client + """ + + # NOTE: exp_wait_time maps to the cycled backoff of [0.1, 0.2, 0.4, 0.8] + # with leeway added (by allowing 1s each for the 0.1 and 0.5 steps) + + with monkeypatch.context() as ctx, pytest.raises(SmartSimError) as ex: + start_time = time.time() + # remove the worker queue value from the backbone if it exists + # to ensure the timeout occurs + the_backbone.pop(BackboneFeatureStore.MLI_WORKER_QUEUE) + + ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor) + + ProtoClient(timing_on=False, backbone_timeout=backbone_timeout) + elapsed = time.time() - start_time + logger.info(f"ProtoClient timeout occurred in {elapsed} seconds") + + # confirm that we met our timeout + assert ( + elapsed >= backbone_timeout + ), f"below configured timeout {backbone_timeout}" + + # confirm that the total wait time is aligned with the sleep cycle + assert elapsed < exp_wait_max, f"above expected max wait {exp_wait_max}" + + +def test_protoclient_initialization_no_backbone( + monkeypatch: pytest.MonkeyPatch, the_worker_queue: DragonFLIChannel +): + """Verify that attempting to start the client without required environment variables + results in an exception. + + :param the_worker_queue: Passing the worker queue fixture to ensure + the worker queue environment is correctly configured. 
+ + NOTE: os.environ[BackboneFeatureStore.MLI_BACKBONE] is not set""" + + with monkeypatch.context() as patch, pytest.raises(SmartSimError) as ex: + patch.setenv(BackboneFeatureStore.MLI_BACKBONE, "") + + ProtoClient(timing_on=False) + + # confirm the missing value error has been raised + assert {"backbone", "configuration"}.issubset(set(ex.value.args[0].split(" "))) + + +def test_protoclient_initialization( + the_backbone: BackboneFeatureStore, + the_worker_queue: DragonFLIChannel, + monkeypatch: pytest.MonkeyPatch, +): + """Verify that attempting to start the client with required env vars results + in a fully initialized client. + + :param the_backbone: a pre-initialized backbone featurestore + :param the_worker_queue: an FLI channel the client will retrieve + from the backbone""" + + with monkeypatch.context() as ctx: + ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor) + # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone + + client = ProtoClient(timing_on=False) + + fs_descriptor = the_backbone.descriptor + wq_descriptor = the_worker_queue.descriptor + + # confirm the backbone was attached correctly + assert client._backbone is not None + assert client._backbone.descriptor == fs_descriptor + + # we expect the backbone to add its descriptor to the local env + assert os.environ[BackboneFeatureStore.MLI_BACKBONE] == fs_descriptor + + # confirm the worker queue is created and attached correctly + assert client._to_worker_fli is not None + assert client._to_worker_fli.descriptor == wq_descriptor + + # we expect the worker queue descriptor to be placed into the backbone + # we do NOT expect _from_worker_ch to be placed anywhere. 
it's a specific callback + assert the_backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] == wq_descriptor + + # confirm the worker channels are created + assert client._from_worker_ch is not None + assert client._to_worker_ch is not None + + # wrap the channels just to easily verify they produces a descriptor + assert DragonCommChannel(client._from_worker_ch).descriptor + assert DragonCommChannel(client._to_worker_ch).descriptor + + # confirm a publisher is created + assert client._publisher is not None + + +def test_protoclient_write_model( + the_backbone: BackboneFeatureStore, + the_worker_queue: DragonFLIChannel, + monkeypatch: pytest.MonkeyPatch, +): + """Verify that writing a model using the client causes the model data to be + written to a feature store. + + :param the_backbone: a pre-initialized backbone featurestore + :param the_worker_queue: Passing the worker queue fixture to ensure + the worker queue environment is correctly configured. + from the backbone + """ + + with monkeypatch.context() as ctx: + # we won't actually send here + client = ProtoClient(timing_on=False) + + ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor) + # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone + + client = ProtoClient(timing_on=False) + + model_key = "my-model" + model_bytes = b"12345" + + client.set_model(model_key, model_bytes) + + # confirm the client modified the underlying feature store + assert client._backbone[model_key] == model_bytes + + +@pytest.mark.parametrize( + "num_listeners, num_model_updates", + [(1, 1), (1, 4), (2, 4), (16, 4), (64, 8)], +) +def test_protoclient_write_model_notification_sent( + the_backbone: BackboneFeatureStore, + the_worker_queue: DragonFLIChannel, + monkeypatch: pytest.MonkeyPatch, + num_listeners: int, + num_model_updates: int, +): + """Verify that writing a model sends a key-written event. 
+
+    :param the_backbone: a pre-initialized backbone featurestore
+    :param the_worker_queue: an FLI channel the client will retrieve
+    from the backbone
+    :param num_listeners: vary the number of registered listeners
+    to verify that the event is broadcast to everyone
+    :param num_model_updates: vary the number of model updates sent
+    to verify the broadcast counts messages sent correctly
+    """
+
+    # we won't actually send here, but it won't try without registered listeners
+    listeners = [f"mock-ch-desc-{i}" for i in range(num_listeners)]
+
+    the_backbone[BackboneFeatureStore.MLI_BACKBONE] = the_backbone.descriptor
+    the_backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] = the_worker_queue.descriptor
+    the_backbone[BackboneFeatureStore.MLI_NOTIFY_CONSUMERS] = ",".join(listeners)
+    the_backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = None
+
+    with monkeypatch.context() as ctx:
+        ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+        # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+        client = ProtoClient(timing_on=False)
+
+        publisher = t.cast(EventBroadcaster, client._publisher)
+
+        # mock attaching to a channel given the mock-ch-desc in backbone
+        mock_send = MagicMock(return_value=None)
+        mock_comm_channel = MagicMock(**{"send": mock_send}, spec=DragonCommChannel)
+        mock_get_comm_channel = MagicMock(return_value=mock_comm_channel)
+        ctx.setattr(publisher, "_get_comm_channel", mock_get_comm_channel)
+
+        model_key = "my-model"
+        model_bytes = b"12345"
+
+        for i in range(num_model_updates):
+            client.set_model(model_key, model_bytes)
+
+        # confirm that a listener channel was attached
+        # once for each registered listener in backbone
+        assert mock_get_comm_channel.call_count == num_listeners * num_model_updates
+
+        # confirm the client raised the key-written event
+        assert (
+            mock_send.call_count == num_listeners * num_model_updates
+        ), f"Expected {num_listeners} sends with {num_listeners} registrations"
+
+        # with
at least 1 consumer registered, we can verify the message is sent + for call_args in mock_send.call_args_list: + send_args = call_args.args + event_bytes, timeout = send_args[0], send_args[1] + + assert event_bytes, "Expected event bytes to be supplied to send" + assert ( + timeout == 0.001 + ), "Expected default timeout on call to `publisher.send`, " + + # confirm the correct event was raised + event = t.cast( + OnWriteFeatureStore, + pickle.loads(event_bytes), + ) + assert event.descriptor == the_backbone.descriptor + assert event.key == model_key diff --git a/tests/dragon/test_reply_building.py b/tests/dragon/test_reply_building.py new file mode 100644 index 0000000000..48493b3c4d --- /dev/null +++ b/tests/dragon/test_reply_building.py @@ -0,0 +1,64 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.control.worker_manager import build_failure_reply + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.mark.parametrize( + "status, message", + [ + pytest.param("timeout", "Worker timed out", id="timeout"), + pytest.param("fail", "Failed while executing", id="fail"), + ], +) +def test_build_failure_reply(status: "Status", message: str): + "Ensures failure replies can be built successfully" + response = build_failure_reply(status, message) + display_name = response.schema.node.displayName # type: ignore + class_name = display_name.split(":")[-1] + assert class_name == "Response" + assert response.status == status + assert response.message == message + + +def test_build_failure_reply_fails(): + "Ensures ValueError is raised if a Status Enum is not used" + with pytest.raises(ValueError) as ex: + response = build_failure_reply("not a status enum", "message") + + assert "Error assigning status to response" in ex.value.args[0] diff --git a/tests/dragon/test_request_dispatcher.py b/tests/dragon/test_request_dispatcher.py new file mode 100644 index 0000000000..70d73e243f --- /dev/null +++ b/tests/dragon/test_request_dispatcher.py @@ -0,0 +1,233 @@ +# BSD 2-Clause License +# +# 
Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import gc +import os +import subprocess as sp +import time +import typing as t +from queue import Empty + +import numpy as np +import pytest + +from . 
import conftest +from .utils import msg_pump + +pytest.importorskip("dragon") + + +# isort: off +import dragon +import multiprocessing as mp + +import torch + +# isort: on + +from dragon import fli +from dragon.data.ddict.ddict import DDict +from dragon.managed_memory import MemoryAlloc + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.control.request_dispatcher import ( + RequestBatch, + RequestDispatcher, +) +from smartsim._core.mli.infrastructure.control.worker_manager import ( + EnvironmentConfigLoader, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim.log import get_logger + +logger = get_logger(__name__) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +try: + mp.set_start_method("dragon") +except Exception: + pass + + +@pytest.mark.parametrize("num_iterations", [4]) +def test_request_dispatcher( + num_iterations: int, + the_storage: DDict, + test_dir: str, +) -> None: + """Test the request dispatcher batching and queueing system + + This also includes setting a queue to disposable, checking that it is no + longer referenced by the dispatcher. 
+ """ + + to_worker_channel = create_local() + to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli) + + backbone_fs = BackboneFeatureStore(the_storage, allow_reserved_writes=True) + + # NOTE: env vars should be set prior to instantiating EnvironmentConfigLoader + # or test environment may be unable to send messages w/queue + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone_fs.descriptor + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=1000, + batch_size=2, + config_loader=config_loader, + worker_type=TorchWorker, + mem_pool_size=2 * 1024**2, + ) + + worker_queue = config_loader.get_queue() + if worker_queue is None: + logger.warning( + "FLI input queue not loaded correctly from config_loader: " + f"{config_loader._queue_descriptor}" + ) + + request_dispatcher._on_start() + + # put some messages into the work queue for the dispatcher to pickup + channels = [] + processes = [] + for i in range(num_iterations): + batch: t.Optional[RequestBatch] = None + mem_allocs = [] + tensors = [] + + # NOTE: creating callbacks in test to avoid a local channel being torn + # down when mock_messages terms but before the final response message is sent + + callback_channel = DragonCommChannel.from_local() + channels.append(callback_channel) + + process = conftest.function_as_dragon_proc( + msg_pump.mock_messages, + [ + worker_queue.descriptor, + backbone_fs.descriptor, + i, + callback_channel.descriptor, + ], + [], + [], + ) + processes.append(process) + process.start() + assert process.returncode is None, "The message pump failed to start" + + # give dragon some time to populate the message 
queues + for i in range(15): + try: + request_dispatcher._on_iteration() + batch = request_dispatcher.task_queue.get(timeout=1.0) + break + except Empty: + time.sleep(2) + logger.warning(f"Task queue is empty on iteration {i}") + continue + except Exception as exc: + logger.error(f"Task queue exception on iteration {i}") + raise exc + + assert batch is not None + assert batch.has_valid_requests + + model_key = batch.model_id.key + + try: + transform_result = batch.inputs + for transformed, dims, dtype in zip( + transform_result.transformed, + transform_result.dims, + transform_result.dtypes, + ): + mem_alloc = MemoryAlloc.attach(transformed) + mem_allocs.append(mem_alloc) + itemsize = np.empty((1), dtype=dtype).itemsize + tensors.append( + torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[0 : np.prod(dims) * itemsize], + dtype=dtype, + ).reshape(dims) + ) + ) + + assert len(batch.requests) == 2 + assert batch.model_id.key == model_key + assert model_key in request_dispatcher._queues + assert model_key in request_dispatcher._active_queues + assert len(request_dispatcher._queues[model_key]) == 1 + assert request_dispatcher._queues[model_key][0].empty() + assert request_dispatcher._queues[model_key][0].model_id.key == model_key + assert len(tensors) == 1 + assert tensors[0].shape == torch.Size([2, 2]) + + for tensor in tensors: + for sample_idx in range(tensor.shape[0]): + tensor_in = tensor[sample_idx] + tensor_out = (sample_idx + 1) * torch.ones( + (2,), dtype=torch.float32 + ) + assert torch.equal(tensor_in, tensor_out) + + except Exception as exc: + raise exc + finally: + for mem_alloc in mem_allocs: + mem_alloc.free() + + request_dispatcher._active_queues[model_key].make_disposable() + assert request_dispatcher._active_queues[model_key].can_be_removed + + request_dispatcher._on_iteration() + + assert model_key not in request_dispatcher._active_queues + assert model_key not in request_dispatcher._queues + + # Try to remove the dispatcher and free the 
memory + del request_dispatcher + gc.collect() diff --git a/tests/dragon/test_torch_worker.py b/tests/dragon/test_torch_worker.py new file mode 100644 index 0000000000..2a9e7d01bd --- /dev/null +++ b/tests/dragon/test_torch_worker.py @@ -0,0 +1,221 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import io +import typing as t + +import numpy as np +import pytest +import torch + +dragon = pytest.importorskip("dragon") +import dragon.globalservices.pool as dragon_gs_pool +from dragon.managed_memory import MemoryAlloc, MemoryPool +from torch import nn +from torch.nn import functional as F + +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + RequestBatch, + TransformInputResult, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger(__name__) +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +# simple MNIST in PyTorch +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x, y): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +torch_device = {"cpu": "cpu", "gpu": "cuda"} + + +def get_batch() -> torch.Tensor: + return torch.rand(20, 1, 28, 28) + + +def create_torch_model(): + n = Net() + example_forward_input = get_batch() + module = torch.jit.trace(n, [example_forward_input, example_forward_input]) + model_buffer = io.BytesIO() + torch.jit.save(module, model_buffer) + return model_buffer.getvalue() + + +def get_request() -> InferenceRequest: + + tensors = [get_batch() for _ in range(2)] + tensor_numpy = 
[tensor.numpy() for tensor in tensors] + serialized_tensors_descriptors = [ + MessageHandler.build_tensor_descriptor("c", "float32", list(tensor.shape)) + for tensor in tensors + ] + + return InferenceRequest( + model_key=ModelKey(key="model", descriptor="xyz"), + callback=None, + raw_inputs=tensor_numpy, + input_keys=None, + input_meta=serialized_tensors_descriptors, + output_keys=None, + raw_model=create_torch_model(), + batch_size=0, + ) + + +def get_request_batch_from_request( + request: InferenceRequest, inputs: t.Optional[TransformInputResult] = None +) -> RequestBatch: + + return RequestBatch([request], inputs, request.model_key) + + +sample_request: InferenceRequest = get_request() +sample_request_batch: RequestBatch = get_request_batch_from_request(sample_request) +worker = TorchWorker() + + +def test_load_model(mlutils) -> None: + fetch_model_result = FetchModelResult(sample_request.raw_model) + load_model_result = worker.load_model( + sample_request_batch, fetch_model_result, mlutils.get_test_device().lower() + ) + + assert load_model_result.model( + get_batch().to(torch_device[mlutils.get_test_device().lower()]), + get_batch().to(torch_device[mlutils.get_test_device().lower()]), + ).shape == torch.Size((20, 10)) + + +def test_transform_input(mlutils) -> None: + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_input_result = worker.transform_input( + sample_request_batch, [fetch_input_result], mem_pool + ) + + batch = get_batch().numpy() + assert transform_input_result.slices[0] == slice(0, batch.shape[0]) + + for tensor_index in range(2): + assert torch.Size(transform_input_result.dims[tensor_index]) == batch.shape + assert transform_input_result.dtypes[tensor_index] == str(batch.dtype) + mem_alloc = MemoryAlloc.attach(transform_input_result.transformed[tensor_index]) + itemsize = batch.itemsize + tensor = 
torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[ + 0 : np.prod(transform_input_result.dims[tensor_index]) * itemsize + ], + dtype=transform_input_result.dtypes[tensor_index], + ).reshape(transform_input_result.dims[tensor_index]) + ) + + assert torch.equal( + tensor, torch.from_numpy(sample_request.raw_inputs[tensor_index]) + ) + + mem_pool.destroy() + + +def test_execute(mlutils) -> None: + load_model_result = LoadModelResult( + Net().to(torch_device[mlutils.get_test_device().lower()]) + ) + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + request_batch = get_request_batch_from_request(sample_request, fetch_input_result) + + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_result = worker.transform_input( + request_batch, [fetch_input_result], mem_pool + ) + + execute_result = worker.execute( + request_batch, + load_model_result, + transform_result, + mlutils.get_test_device().lower(), + ) + + assert all( + result.shape == torch.Size((20, 10)) for result in execute_result.predictions + ) + + mem_pool.destroy() + + +def test_transform_output(mlutils): + tensors = [torch.rand((20, 10)) for _ in range(2)] + execute_result = ExecuteResult(tensors, [slice(0, 20)]) + + transformed_output = worker.transform_output(sample_request_batch, execute_result) + + assert transformed_output[0].outputs == [item.numpy().tobytes() for item in tensors] + assert transformed_output[0].shape == None + assert transformed_output[0].order == "c" + assert transformed_output[0].dtype == "float32" diff --git a/tests/dragon/test_worker_manager.py b/tests/dragon/test_worker_manager.py new file mode 100644 index 0000000000..4047a731fc --- /dev/null +++ b/tests/dragon/test_worker_manager.py @@ -0,0 +1,314 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import io +import logging +import pathlib +import time + +import pytest + +torch = pytest.importorskip("torch") +dragon = pytest.importorskip("dragon") + +import multiprocessing as mp + +try: + mp.set_start_method("dragon") +except Exception: + pass + +import os + +import torch.nn as nn +from dragon import fli + +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.control.worker_manager import ( + EnvironmentConfigLoader, + WorkerManager, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_util import create_ddict +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +from .utils.channel import FileSystemCommChannel + +logger = get_logger(__name__) +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +class MiniModel(nn.Module): + """A torch model that can be executed by the default torch worker""" + + def __init__(self): + """Initialize the model.""" + super().__init__() + + self._name = "mini-model" + self._net = torch.nn.Linear(2, 1) + + def forward(self, input): + """Execute a forward pass.""" + return self._net(input) + + @property + def bytes(self) -> bytes: + """Retrieve the serialized model + + :returns: The byte stream of the model file + """ + buffer = io.BytesIO() + scripted = torch.jit.trace(self._net, self.get_batch()) + torch.jit.save(scripted, buffer) + return buffer.getvalue() + + @classmethod + def get_batch(cls) -> "torch.Tensor": + """Generate a single batch of data with the correct + shape for inference. 
+ + :returns: The batch as a torch tensor + """ + return torch.randn((100, 2), dtype=torch.float32) + + +def create_model(model_path: pathlib.Path) -> pathlib.Path: + """Create a simple torch model and persist to disk for + testing purposes. + + :param model_path: The path to the torch model file + """ + if not model_path.parent.exists(): + model_path.parent.mkdir(parents=True, exist_ok=True) + + model_path.unlink(missing_ok=True) + + mini_model = MiniModel() + torch.save(mini_model, model_path) + + return model_path + + +def load_model() -> bytes: + """Create a simple torch model in memory for testing.""" + mini_model = MiniModel() + return mini_model.bytes + + +def mock_messages( + feature_store_root_dir: pathlib.Path, + comm_channel_root_dir: pathlib.Path, + kill_queue: mp.Queue, +) -> None: + """Mock event producer for triggering the inference pipeline. + + :param feature_store_root_dir: Path to a directory where a + FileSystemFeatureStore can read & write results + :param comm_channel_root_dir: Path to a directory where a + FileSystemCommChannel can read & write messages + :param kill_queue: Queue used by unit test to stop mock_message process + """ + feature_store_root_dir.mkdir(parents=True, exist_ok=True) + comm_channel_root_dir.mkdir(parents=True, exist_ok=True) + + iteration_number = 0 + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + backbone = config_loader.get_backbone() + + worker_queue = config_loader.get_queue() + if worker_queue is None: + queue_desc = config_loader._queue_descriptor + logger.warn( + f"FLI input queue not loaded correctly from config_loader: {queue_desc}" + ) + + model_key = "mini-model" + model_bytes = load_model() + backbone[model_key] = model_bytes + + while True: + if not kill_queue.empty(): + return + iteration_number += 1 + time.sleep(1) + + channel_key = 
comm_channel_root_dir / f"{iteration_number}/channel.txt" + callback_channel = FileSystemCommChannel(pathlib.Path(channel_key)) + + batch = MiniModel.get_batch() + shape = batch.shape + batch_bytes = batch.numpy().tobytes() + + logger.debug(f"Model content: {backbone[model_key][:20]}") + + input_descriptor = MessageHandler.build_tensor_descriptor( + "f", "float32", list(shape) + ) + + # The first request is always the metadata... + request = MessageHandler.build_request( + reply_channel=callback_channel.descriptor, + model=MessageHandler.build_model(model_bytes, "mini-model", "1.0"), + inputs=[input_descriptor], + outputs=[], + output_descriptors=[], + custom_attributes=None, + ) + request_bytes = MessageHandler.serialize_request(request) + fli: DragonFLIChannel = worker_queue + + with fli._fli.sendh(timeout=None, stream_channel=fli._channel) as sendh: + sendh.send_bytes(request_bytes) + sendh.send_bytes(batch_bytes) + + logger.info("published message") + + if iteration_number > 5: + return + + +def mock_mli_infrastructure_mgr() -> None: + """Create resources normally instanatiated by the infrastructure + management portion of the DragonBackend. + """ + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + integrated_worker = TorchWorker + + worker_manager = WorkerManager( + config_loader, + integrated_worker, + as_service=True, + cooldown=10, + device="cpu", + dispatcher_queue=mp.Queue(maxsize=0), + ) + worker_manager.execute() + + +@pytest.fixture +def prepare_environment(test_dir: str) -> pathlib.Path: + """Cleanup prior outputs to run demo repeatedly. 
+ + :param test_dir: the directory to prepare + :returns: The path to the log file + """ + path = pathlib.Path(f"{test_dir}/workermanager.log") + logging.basicConfig(filename=path.absolute(), level=logging.DEBUG) + return path + + +def test_worker_manager(prepare_environment: pathlib.Path) -> None: + """Test the worker manager. + + :param prepare_environment: Pass this fixture to configure + global resources before the worker manager executes + """ + + test_path = prepare_environment + fs_path = test_path / "feature_store" + comm_path = test_path / "comm_store" + + mgr_per_node = 1 + num_nodes = 2 + mem_per_node = 128 * 1024**2 + + storage = create_ddict(num_nodes, mgr_per_node, mem_per_node) + backbone = BackboneFeatureStore(storage, allow_reserved_writes=True) + + to_worker_channel = create_local() + to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + + to_worker_fli_comm_channel = DragonFLIChannel(to_worker_fli) + + # NOTE: env vars must be set prior to instantiating EnvironmentConfigLoader + # or test environment may be unable to send messages w/queue + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = ( + to_worker_fli_comm_channel.descriptor + ) + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + integrated_worker_type = TorchWorker + + worker_manager = WorkerManager( + config_loader, + integrated_worker_type, + as_service=True, + cooldown=5, + device="cpu", + dispatcher_queue=mp.Queue(maxsize=0), + ) + + worker_queue = config_loader.get_queue() + if worker_queue is None: + logger.warn( + f"FLI input queue not loaded correctly from config_loader: {config_loader._queue_descriptor}" + ) + backbone.worker_queue = to_worker_fli_comm_channel.descriptor + + # create a mock client application to populate 
the request queue + kill_queue = mp.Queue() + msg_pump = mp.Process( + target=mock_messages, + args=(fs_path, comm_path, kill_queue), + ) + msg_pump.start() + + # create a process to execute commands + process = mp.Process(target=mock_mli_infrastructure_mgr) + + # let it send some messages before starting the worker manager + msg_pump.join(timeout=5) + process.start() + msg_pump.join(timeout=5) + kill_queue.put_nowait("kill!") + process.join(timeout=5) + msg_pump.kill() + process.kill() diff --git a/tests/dragon/utils/__init__.py b/tests/dragon/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py new file mode 100644 index 0000000000..4c46359c2d --- /dev/null +++ b/tests/dragon/utils/channel.py @@ -0,0 +1,125 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import pathlib +import threading +import typing as t + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class FileSystemCommChannel(CommChannelBase): + """Passes messages by writing to a file""" + + def __init__(self, key: pathlib.Path) -> None: + """Initialize the FileSystemCommChannel instance. + + :param key: a path to the root directory of the feature store + """ + self._lock = threading.RLock() + + super().__init__(key.as_posix()) + self._file_path = key + + if not self._file_path.parent.exists(): + self._file_path.parent.mkdir(parents=True) + + self._file_path.touch() + + def send(self, value: bytes, timeout: float = 0) -> None: + """Send a message throuh the underlying communication channel. + + :param value: The value to send + :param timeout: maximum time to wait (in seconds) for messages to send + """ + with self._lock: + # write as text so we can add newlines as delimiters + with open(self._file_path, "a") as fp: + encoded_value = base64.b64encode(value).decode("utf-8") + fp.write(f"{encoded_value}\n") + logger.debug(f"FileSystemCommChannel {self._file_path} sent message") + + def recv(self, timeout: float = 0) -> t.List[bytes]: + """Receives message(s) through the underlying communication channel. 
+ + :param timeout: maximum time to wait (in seconds) for messages to arrive + :returns: the received message + :raises SmartSimError: if the descriptor points to a missing file + """ + with self._lock: + messages: t.List[bytes] = [] + if not self._file_path.exists(): + raise SmartSimError("Empty channel") + + # read as text so we can split on newlines + with open(self._file_path, "r") as fp: + lines = fp.readlines() + + if lines: + line = lines.pop(0) + event_bytes = base64.b64decode(line.encode("utf-8")) + messages.append(event_bytes) + + self.clear() + + # remove the first message only, write remainder back... + if len(lines) > 0: + with open(self._file_path, "w") as fp: + fp.writelines(lines) + + logger.debug( + f"FileSystemCommChannel {self._file_path} received message" + ) + + return messages + + def clear(self) -> None: + """Create an empty file for events.""" + if self._file_path.exists(): + self._file_path.unlink() + self._file_path.touch() + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "FileSystemCommChannel": + """A factory method that creates an instance from a descriptor string. + + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemCommChannel + """ + try: + path = pathlib.Path(descriptor) + return FileSystemCommChannel(path) + except: + logger.warning(f"failed to create fs comm channel: {descriptor}") + raise diff --git a/tests/dragon/utils/msg_pump.py b/tests/dragon/utils/msg_pump.py new file mode 100644 index 0000000000..8d69e57c63 --- /dev/null +++ b/tests/dragon/utils/msg_pump.py @@ -0,0 +1,225 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import io +import logging +import pathlib +import sys +import time +import typing as t + +import pytest + +pytest.importorskip("torch") +pytest.importorskip("dragon") + + +# isort: off +import dragon +import multiprocessing as mp +import torch +import torch.nn as nn + +# isort: on + +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger(__name__, log_level=logging.DEBUG) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +try: + mp.set_start_method("dragon") +except Exception: + pass + + +class MiniModel(nn.Module): + def __init__(self): + super().__init__() + + self._name = "mini-model" + self._net = torch.nn.Linear(2, 1) + + def forward(self, input): + return self._net(input) + + @property + def bytes(self) -> bytes: + """Returns the model serialized to a byte stream""" + buffer = io.BytesIO() + scripted = torch.jit.trace(self._net, self.get_batch()) + torch.jit.save(scripted, buffer) + return buffer.getvalue() + + @classmethod + def get_batch(cls) -> "torch.Tensor": + return torch.randn((100, 2), dtype=torch.float32) + + +def load_model() -> bytes: + """Create a simple torch model in memory for testing""" + mini_model = MiniModel() + return mini_model.bytes + + +def persist_model_file(model_path: pathlib.Path) -> pathlib.Path: + """Create a simple torch model and persist to disk for + testing purposes. 
+ + :returns: Path to the model file + """ + # test_path = pathlib.Path(work_dir) + if not model_path.parent.exists(): + model_path.parent.mkdir(parents=True, exist_ok=True) + + model_path.unlink(missing_ok=True) + + model = torch.nn.Linear(2, 1) + torch.save(model, model_path) + + return model_path + + +def _mock_messages( + dispatch_fli_descriptor: str, + fs_descriptor: str, + parent_iteration: int, + callback_descriptor: str, +) -> None: + """Mock event producer for triggering the inference pipeline.""" + model_key = "mini-model" + # mock_message sends 2 messages, so we offset by 2 * (# of iterations in caller) + offset = 2 * parent_iteration + + feature_store = BackboneFeatureStore.from_descriptor(fs_descriptor) + request_dispatcher_queue = DragonFLIChannel.from_descriptor(dispatch_fli_descriptor) + + feature_store[model_key] = load_model() + + for iteration_number in range(2): + logged_iteration = offset + iteration_number + logger.debug(f"Sending mock message {logged_iteration}") + + output_key = f"output-{iteration_number}" + + tensor = ( + (iteration_number + 1) * torch.ones((1, 2), dtype=torch.float32) + ).numpy() + fsd = feature_store.descriptor + + tensor_desc = MessageHandler.build_tensor_descriptor( + "c", "float32", list(tensor.shape) + ) + + message_tensor_output_key = MessageHandler.build_tensor_key(output_key, fsd) + message_model_key = MessageHandler.build_model_key(model_key, fsd) + + request = MessageHandler.build_request( + reply_channel=callback_descriptor, + model=message_model_key, + inputs=[tensor_desc], + outputs=[message_tensor_output_key], + output_descriptors=[], + custom_attributes=None, + ) + + logger.info(f"Sending request {iteration_number} to request_dispatcher_queue") + request_bytes = MessageHandler.serialize_request(request) + + logger.info("Sending msg_envelope") + + # cuid = request_dispatcher_queue._channel.cuid + # logger.info(f"\tInternal cuid: {cuid}") + + # send the header & body together so they arrive together + try: + 
request_dispatcher_queue.send_multiple([request_bytes, tensor.tobytes()]) + logger.info(f"\tenvelope 0: {request_bytes[:5]}...") + logger.info(f"\tenvelope 1: {tensor.tobytes()[:5]}...") + except Exception as ex: + logger.exception("Unable to send request envelope") + + logger.info("All messages sent") + + # keep the process alive for an extra 15 seconds to let the processor + # have access to the channels before they're destroyed + for _ in range(15): + time.sleep(1) + + +def mock_messages( + dispatch_fli_descriptor: str, + fs_descriptor: str, + parent_iteration: int, + callback_descriptor: str, +) -> int: + """Mock event producer for triggering the inference pipeline. Used + when starting using multiprocessing.""" + logger.info(f"{dispatch_fli_descriptor=}") + logger.info(f"{fs_descriptor=}") + logger.info(f"{parent_iteration=}") + logger.info(f"{callback_descriptor=}") + + try: + return _mock_messages( + dispatch_fli_descriptor, + fs_descriptor, + parent_iteration, + callback_descriptor, + ) + except Exception as ex: + logger.exception() + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + args = argparse.ArgumentParser() + + args.add_argument("--dispatch-fli-descriptor", type=str) + args.add_argument("--fs-descriptor", type=str) + args.add_argument("--parent-iteration", type=int) + args.add_argument("--callback-descriptor", type=str) + + args = args.parse_args() + + return_code = mock_messages( + args.dispatch_fli_descriptor, + args.fs_descriptor, + args.parent_iteration, + args.callback_descriptor, + ) + sys.exit(return_code) diff --git a/tests/dragon/utils/worker.py b/tests/dragon/utils/worker.py new file mode 100644 index 0000000000..0582cae566 --- /dev/null +++ b/tests/dragon/utils/worker.py @@ -0,0 +1,104 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import io +import typing as t + +import torch + +import smartsim._core.mli.infrastructure.worker.worker as mliw +import smartsim.error as sse +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class IntegratedTorchWorker(mliw.MachineLearningWorkerBase): + """A minimum implementation of a worker that executes a PyTorch model""" + + # @staticmethod + # def deserialize(request: InferenceRequest) -> t.List[t.Any]: + # # request.input_meta + # # request.raw_inputs + # return request + + @staticmethod + def load_model( + request: mliw.InferenceRequest, fetch_result: mliw.FetchModelResult, device: str + ) -> mliw.LoadModelResult: + model_bytes = fetch_result.model_bytes or request.raw_model + if not model_bytes: + raise ValueError("Unable to load model without reference object") + + model: torch.nn.Module = torch.load(io.BytesIO(model_bytes)) + result = mliw.LoadModelResult(model) + return result + + @staticmethod + def transform_input( + request: mliw.InferenceRequest, + fetch_result: mliw.FetchInputResult, + device: str, + ) -> mliw.TransformInputResult: + # extra metadata for assembly can be found in request.input_meta + raw_inputs = request.raw_inputs or fetch_result.inputs + + result: t.List[torch.Tensor] = [] + # should this happen here? + # consider - fortran to c data layout + # is there an intermediate representation before really doing torch.load? 
+ if raw_inputs: + result = [torch.load(io.BytesIO(item)) for item in raw_inputs] + + return mliw.TransformInputResult(result) + + @staticmethod + def execute( + request: mliw.InferenceRequest, + load_result: mliw.LoadModelResult, + transform_result: mliw.TransformInputResult, + ) -> mliw.ExecuteResult: + if not load_result.model: + raise sse.SmartSimError("Model must be loaded to execute") + + model = load_result.model + results = [model(tensor) for tensor in transform_result.transformed] + + execute_result = mliw.ExecuteResult(results) + return execute_result + + @staticmethod + def transform_output( + request: mliw.InferenceRequest, + execute_result: mliw.ExecuteResult, + result_device: str, + ) -> mliw.TransformOutputResult: + # send the original tensors... + execute_result.predictions = [t.detach() for t in execute_result.predictions] + # todo: solve sending all tensor metadata that coincisdes with each prediction + return mliw.TransformOutputResult( + execute_result.predictions, [1], "c", "float32" + ) diff --git a/tests/mli/__init__.py b/tests/mli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/mli/channel.py b/tests/mli/channel.py new file mode 100644 index 0000000000..4c46359c2d --- /dev/null +++ b/tests/mli/channel.py @@ -0,0 +1,125 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import pathlib +import threading +import typing as t + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class FileSystemCommChannel(CommChannelBase): + """Passes messages by writing to a file""" + + def __init__(self, key: pathlib.Path) -> None: + """Initialize the FileSystemCommChannel instance. + + :param key: a path to the root directory of the feature store + """ + self._lock = threading.RLock() + + super().__init__(key.as_posix()) + self._file_path = key + + if not self._file_path.parent.exists(): + self._file_path.parent.mkdir(parents=True) + + self._file_path.touch() + + def send(self, value: bytes, timeout: float = 0) -> None: + """Send a message throuh the underlying communication channel. 
+ + :param value: The value to send + :param timeout: maximum time to wait (in seconds) for messages to send + """ + with self._lock: + # write as text so we can add newlines as delimiters + with open(self._file_path, "a") as fp: + encoded_value = base64.b64encode(value).decode("utf-8") + fp.write(f"{encoded_value}\n") + logger.debug(f"FileSystemCommChannel {self._file_path} sent message") + + def recv(self, timeout: float = 0) -> t.List[bytes]: + """Receives message(s) through the underlying communication channel. + + :param timeout: maximum time to wait (in seconds) for messages to arrive + :returns: the received message + :raises SmartSimError: if the descriptor points to a missing file + """ + with self._lock: + messages: t.List[bytes] = [] + if not self._file_path.exists(): + raise SmartSimError("Empty channel") + + # read as text so we can split on newlines + with open(self._file_path, "r") as fp: + lines = fp.readlines() + + if lines: + line = lines.pop(0) + event_bytes = base64.b64decode(line.encode("utf-8")) + messages.append(event_bytes) + + self.clear() + + # remove the first message only, write remainder back... + if len(lines) > 0: + with open(self._file_path, "w") as fp: + fp.writelines(lines) + + logger.debug( + f"FileSystemCommChannel {self._file_path} received message" + ) + + return messages + + def clear(self) -> None: + """Create an empty file for events.""" + if self._file_path.exists(): + self._file_path.unlink() + self._file_path.touch() + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "FileSystemCommChannel": + """A factory method that creates an instance from a descriptor string. 
+ + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemCommChannel + """ + try: + path = pathlib.Path(descriptor) + return FileSystemCommChannel(path) + except: + logger.warning(f"failed to create fs comm channel: {descriptor}") + raise diff --git a/tests/mli/feature_store.py b/tests/mli/feature_store.py new file mode 100644 index 0000000000..7bc18253c8 --- /dev/null +++ b/tests/mli/feature_store.py @@ -0,0 +1,144 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+import pathlib
+import typing as t
+
+import smartsim.error as sse
+from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class MemoryFeatureStore(FeatureStore):
+    """A feature store with values persisted only in local memory"""
+
+    def __init__(self, storage: t.Optional[t.Dict[str, bytes]] = None) -> None:
+        """Initialize the MemoryFeatureStore instance"""
+        super().__init__("in-memory-fs")
+        if storage is None:
+            storage = {"_": "abc"}
+        self._storage: t.Dict[str, bytes] = storage
+
+    def _get(self, key: str) -> bytes:
+        """Retrieve an item using key
+
+        :param key: Unique key of an item to retrieve from the feature store"""
+        if key not in self._storage:
+            raise sse.SmartSimError(f"{key} not found in feature store")
+        return self._storage[key]
+
+    def _set(self, key: str, value: bytes) -> None:
+        """Assign a value using key
+
+        :param key: Unique key of an item to set in the feature store
+        :param value: Value to persist in the feature store"""
+        self._check_reserved(key)
+        self._storage[key] = value
+
+    def _contains(self, key: str) -> bool:
+        """Membership operator to test for a key existing within the feature store.
+        Return `True` if the key is found, `False` otherwise
+        :param key: Unique key of an item to retrieve from the feature store"""
+        return key in self._storage
+
+
+class FileSystemFeatureStore(FeatureStore):
+    """Alternative feature store implementation for testing.
Stores all + data on the file system""" + + def __init__(self, storage_dir: t.Union[pathlib.Path, str] = None) -> None: + """Initialize the FileSystemFeatureStore instance + + :param storage_dir: (optional) root directory to store all data relative to""" + if isinstance(storage_dir, str): + storage_dir = pathlib.Path(storage_dir) + self._storage_dir = storage_dir + super().__init__(storage_dir.as_posix()) + + def _get(self, key: str) -> bytes: + """Retrieve an item using key + + :param key: Unique key of an item to retrieve from the feature store""" + path = self._key_path(key) + if not path.exists(): + raise sse.SmartSimError(f"{path} not found in feature store") + return path.read_bytes() + + def _set(self, key: str, value: bytes) -> None: + """Assign a value using key + + :param key: Unique key of an item to set in the feature store + :param value: Value to persist in the feature store""" + path = self._key_path(key, create=True) + if isinstance(value, str): + value = value.encode("utf-8") + path.write_bytes(value) + + def _contains(self, key: str) -> bool: + """Membership operator to test for a key existing within the feature store. + + :param key: Unique key of an item to retrieve from the feature store + :returns: `True` if the key is found, `False` otherwise""" + path = self._key_path(key) + return path.exists() + + def _key_path(self, key: str, create: bool = False) -> pathlib.Path: + """Given a key, return a path that is optionally combined with a base + directory used by the FileSystemFeatureStore. 
+ + :param key: Unique key of an item to retrieve from the feature store""" + value = pathlib.Path(key) + + if self._storage_dir: + value = self._storage_dir / key + + if create: + value.parent.mkdir(parents=True, exist_ok=True) + + return value + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "FileSystemFeatureStore": + """A factory method that creates an instance from a descriptor string + + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemFeatureStore""" + try: + path = pathlib.Path(descriptor) + path.mkdir(parents=True, exist_ok=True) + if not path.is_dir(): + raise ValueError("FileSystemFeatureStore requires a directory path") + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + return FileSystemFeatureStore(path) + except: + logger.error(f"Error while creating FileSystemFeatureStore: {descriptor}") + raise diff --git a/tests/mli/test_integrated_torch_worker.py b/tests/mli/test_integrated_torch_worker.py new file mode 100644 index 0000000000..60f1f0c6b9 --- /dev/null +++ b/tests/mli/test_integrated_torch_worker.py @@ -0,0 +1,275 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pathlib +import typing as t + +import pytest +import torch + +# import smartsim.error as sse +# from smartsim._core.mli.infrastructure.control import workermanager as mli +# from smartsim._core.mli.message_handler import MessageHandler +from smartsim._core.utils import installed_redisai_backends + +# The tests in this file belong to the group_b group +pytestmark = pytest.mark.group_b + +# retrieved from pytest fixtures +is_dragon = pytest.test_launcher == "dragon" +torch_available = "torch" in installed_redisai_backends() + + +@pytest.fixture +def persist_torch_model(test_dir: str) -> pathlib.Path: + test_path = pathlib.Path(test_dir) + model_path = test_path / "basic.pt" + + model = torch.nn.Linear(2, 1) + torch.save(model, model_path) + + return model_path + + +# todo: move deserialization tests into suite for worker manager where serialization occurs + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_deserialize_direct_request(persist_torch_model: pathlib.Path) -> None: +# """Verify that a direct requestis deserialized properly""" +# worker = mli.IntegratedTorchWorker +# # feature_store = mli.MemoryFeatureStore() + +# model_bytes 
= persist_torch_model.read_bytes() +# input_tensor = torch.randn(2) + +# expected_callback_channel = b"faux_channel_descriptor_bytes" +# callback_channel = mli.DragonCommChannel.find(expected_callback_channel) + +# message_tensor_input = MessageHandler.build_tensor( +# input_tensor, "c", "float32", [2] +# ) + +# request = MessageHandler.build_request( +# reply_channel=callback_channel.descriptor, +# model=model_bytes, +# inputs=[message_tensor_input], +# outputs=[], +# custom_attributes=None, +# ) + +# msg_bytes = MessageHandler.serialize_request(request) + +# inference_request = worker.deserialize(msg_bytes) +# assert inference_request.callback._descriptor == expected_callback_channel + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_deserialize_indirect_request(persist_torch_model: pathlib.Path) -> None: +# """Verify that an indirect request is deserialized correctly""" +# worker = mli.IntegratedTorchWorker +# # feature_store = mli.MemoryFeatureStore() + +# model_key = "persisted-model" +# # model_bytes = persist_torch_model.read_bytes() +# # feature_store[model_key] = model_bytes + +# input_key = f"demo-input" +# # input_tensor = torch.randn(2) +# # feature_store[input_key] = input_tensor + +# expected_callback_channel = b"faux_channel_descriptor_bytes" +# callback_channel = mli.DragonCommChannel.find(expected_callback_channel) + +# output_key = f"demo-output" + +# message_tensor_output_key = MessageHandler.build_tensor_key(output_key) +# message_tensor_input_key = MessageHandler.build_tensor_key(input_key) +# message_model_key = MessageHandler.build_model_key(model_key) + +# request = MessageHandler.build_request( +# reply_channel=callback_channel.descriptor, +# model=message_model_key, +# inputs=[message_tensor_input_key], +# outputs=[message_tensor_output_key], +# custom_attributes=None, +# ) + +# msg_bytes = MessageHandler.serialize_request(request) + +# inference_request = worker.deserialize(msg_bytes) +# 
assert inference_request.callback._descriptor == expected_callback_channel + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_deserialize_mixed_mode_indirect_inputs( +# persist_torch_model: pathlib.Path, +# ) -> None: +# """Verify that a mixed mode (combining direct and indirect inputs, models, outputs) +# with indirect inputs is deserialized correctly""" +# worker = mli.IntegratedTorchWorker +# # feature_store = mli.MemoryFeatureStore() + +# # model_key = "persisted-model" +# model_bytes = persist_torch_model.read_bytes() +# # feature_store[model_key] = model_bytes + +# input_key = f"demo-input" +# # input_tensor = torch.randn(2) +# # feature_store[input_key] = input_tensor + +# expected_callback_channel = b"faux_channel_descriptor_bytes" +# callback_channel = mli.DragonCommChannel.find(expected_callback_channel) + +# output_key = f"demo-output" + +# message_tensor_output_key = MessageHandler.build_tensor_key(output_key) +# message_tensor_input_key = MessageHandler.build_tensor_key(input_key) +# # message_model_key = MessageHandler.build_model_key(model_key) + +# request = MessageHandler.build_request( +# reply_channel=callback_channel.descriptor, +# model=model_bytes, +# inputs=[message_tensor_input_key], +# # outputs=[message_tensor_output_key], +# outputs=[], +# custom_attributes=None, +# ) + +# msg_bytes = MessageHandler.serialize_request(request) + +# inference_request = worker.deserialize(msg_bytes) +# assert inference_request.callback._descriptor == expected_callback_channel + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_deserialize_mixed_mode_indirect_outputs( +# persist_torch_model: pathlib.Path, +# ) -> None: +# """Verify that a mixed mode (combining direct and indirect inputs, models, outputs) +# with indirect outputs is deserialized correctly""" +# worker = mli.IntegratedTorchWorker +# # feature_store = mli.MemoryFeatureStore() + +# # model_key = 
"persisted-model" +# model_bytes = persist_torch_model.read_bytes() +# # feature_store[model_key] = model_bytes + +# input_key = f"demo-input" +# input_tensor = torch.randn(2) +# # feature_store[input_key] = input_tensor + +# expected_callback_channel = b"faux_channel_descriptor_bytes" +# callback_channel = mli.DragonCommChannel.find(expected_callback_channel) + +# output_key = f"demo-output" + +# message_tensor_output_key = MessageHandler.build_tensor_key(output_key) +# # message_tensor_input_key = MessageHandler.build_tensor_key(input_key) +# # message_model_key = MessageHandler.build_model_key(model_key) +# message_tensor_input = MessageHandler.build_tensor( +# input_tensor, "c", "float32", [2] +# ) + +# request = MessageHandler.build_request( +# reply_channel=callback_channel.descriptor, +# model=model_bytes, +# inputs=[message_tensor_input], +# # outputs=[message_tensor_output_key], +# outputs=[message_tensor_output_key], +# custom_attributes=None, +# ) + +# msg_bytes = MessageHandler.serialize_request(request) + +# inference_request = worker.deserialize(msg_bytes) +# assert inference_request.callback._descriptor == expected_callback_channel + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_deserialize_mixed_mode_indirect_model( +# persist_torch_model: pathlib.Path, +# ) -> None: +# """Verify that a mixed mode (combining direct and indirect inputs, models, outputs) +# with indirect outputs is deserialized correctly""" +# worker = mli.IntegratedTorchWorker +# # feature_store = mli.MemoryFeatureStore() + +# model_key = "persisted-model" +# # model_bytes = persist_torch_model.read_bytes() +# # feature_store[model_key] = model_bytes + +# # input_key = f"demo-input" +# input_tensor = torch.randn(2) +# # feature_store[input_key] = input_tensor + +# expected_callback_channel = b"faux_channel_descriptor_bytes" +# callback_channel = mli.DragonCommChannel.find(expected_callback_channel) + +# output_key = f"demo-output" 
+ +# # message_tensor_output_key = MessageHandler.build_tensor_key(output_key) +# # message_tensor_input_key = MessageHandler.build_tensor_key(input_key) +# message_model_key = MessageHandler.build_model_key(model_key) +# message_tensor_input = MessageHandler.build_tensor( +# input_tensor, "c", "float32", [2] +# ) + +# request = MessageHandler.build_request( +# reply_channel=callback_channel.descriptor, +# model=message_model_key, +# inputs=[message_tensor_input], +# # outputs=[message_tensor_output_key], +# outputs=[], +# custom_attributes=None, +# ) + +# msg_bytes = MessageHandler.serialize_request(request) + +# inference_request = worker.deserialize(msg_bytes) +# assert inference_request.callback._descriptor == expected_callback_channel + + +# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +# def test_serialize(test_dir: str, persist_torch_model: pathlib.Path) -> None: +# """Verify that the worker correctly executes reply serialization""" +# worker = mli.IntegratedTorchWorker + +# reply = mli.InferenceReply() +# reply.output_keys = ["foo", "bar"] + +# # use the worker implementation of reply serialization to get bytes for +# # use on the callback channel +# reply_bytes = worker.serialize_reply(reply) +# assert reply_bytes is not None + +# # deserialize to verity the mapping in the worker.serialize_reply was correct +# actual_reply = MessageHandler.deserialize_response(reply_bytes) + +# actual_tensor_keys = [tk.key for tk in actual_reply.result.keys] +# assert set(actual_tensor_keys) == set(reply.output_keys) +# assert actual_reply.status == 200 +# assert actual_reply.statusMessage == "success" diff --git a/tests/mli/test_service.py b/tests/mli/test_service.py new file mode 100644 index 0000000000..3635f6ff78 --- /dev/null +++ b/tests/mli/test_service.py @@ -0,0 +1,290 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import datetime +import multiprocessing as mp +import pathlib +import time +import typing as t +from asyncore import loop + +import pytest +import torch + +import smartsim.error as sse +from smartsim._core.entrypoints.service import Service + +# The tests in this file belong to the group_b group +pytestmark = pytest.mark.group_a + + +class SimpleService(Service): + """Mock implementation of a service that counts method invocations + using the base class event hooks.""" + + def __init__( + self, + log: t.List[str], + quit_after: int = -1, + as_service: bool = False, + cooldown: float = 0, + loop_delay: float = 0, + hc_freq: float = -1, + run_for: float = 0, + ) -> None: + super().__init__(as_service, cooldown, loop_delay, hc_freq) + self._log = log + self._quit_after = quit_after + self.num_starts = 0 + self.num_shutdowns = 0 + self.num_health_checks = 0 + self.num_cooldowns = 0 + self.num_delays = 0 + self.num_iterations = 0 + self.num_can_shutdown = 0 + self.run_for = run_for + self.start_time = time.time() + + @property + def runtime(self) -> float: + return time.time() - self.start_time + + def _can_shutdown(self) -> bool: + self.num_can_shutdown += 1 + + if self._quit_after > -1 and self.num_iterations >= self._quit_after: + return True + if self.run_for > 0: + return self.runtime >= self.run_for + + def _on_start(self) -> None: + self.num_starts += 1 + + def _on_shutdown(self) -> None: + self.num_shutdowns += 1 + + def _on_health_check(self) -> None: + self.num_health_checks += 1 + + def _on_cooldown_elapsed(self) -> None: + self.num_cooldowns += 1 + + def _on_delay(self) -> None: + self.num_delays += 1 + + def _on_iteration(self) -> None: + self.num_iterations += 1 + + return self.num_iterations >= self._quit_after + + +def test_service_init() -> None: + """Verify expected default values after Service initialization""" + activity_log: t.List[str] = [] + service = SimpleService(activity_log) + + assert service._as_service is False + assert service._cooldown 
== 0 + assert service._loop_delay == 0 + + +def test_service_run_once() -> None: + """Verify the service completes after a single call to _on_iteration""" + activity_log: t.List[str] = [] + service = SimpleService(activity_log) + + service.execute() + + assert service.num_iterations == 1 + assert service.num_starts == 1 + assert service.num_cooldowns == 0 # it never exceeds a cooldown period + assert service.num_can_shutdown == 0 # it automatically exits in run once + assert service.num_shutdowns == 1 + + +@pytest.mark.parametrize( + "num_iterations", + [ + pytest.param(0, id="Immediate Shutdown"), + pytest.param(1, id="1x"), + pytest.param(2, id="2x"), + pytest.param(4, id="4x"), + pytest.param(8, id="8x"), + pytest.param(16, id="16x"), + pytest.param(32, id="32x"), + ], +) +def test_service_run_until_can_shutdown(num_iterations: int) -> None: + """Verify the service completes after a dynamic number of iterations + based on the return value of `_can_shutdown`""" + activity_log: t.List[str] = [] + + service = SimpleService(activity_log, quit_after=num_iterations, as_service=True) + + service.execute() + + if num_iterations == 0: + # no matter what, it should always execute the _on_iteration method + assert service.num_iterations == 1 + else: + # the shutdown check follows on_iteration. 
there will be one last call + assert service.num_iterations == num_iterations + + assert service.num_starts == 1 + assert service.num_shutdowns == 1 + + +@pytest.mark.parametrize( + "cooldown", + [ + pytest.param(1, id="1s"), + pytest.param(3, id="3s"), + pytest.param(5, id="5s"), + ], +) +def test_service_cooldown(cooldown: int) -> None: + """Verify that the cooldown period is respected""" + activity_log: t.List[str] = [] + + service = SimpleService( + activity_log, + quit_after=1, + as_service=True, + cooldown=cooldown, + loop_delay=0, + ) + + ts0 = datetime.datetime.now() + service.execute() + ts1 = datetime.datetime.now() + + fudge_factor = 1.1 # allow a little bit of wiggle room for the loop + duration_in_seconds = (ts1 - ts0).total_seconds() + + assert duration_in_seconds <= cooldown * fudge_factor + assert service.num_cooldowns == 1 + assert service.num_shutdowns == 1 + + +@pytest.mark.parametrize( + "delay, num_iterations", + [ + pytest.param(1, 3, id="1s delay, 3x"), + pytest.param(3, 2, id="2s delay, 2x"), + pytest.param(5, 1, id="5s delay, 1x"), + ], +) +def test_service_delay(delay: int, num_iterations: int) -> None: + """Verify that a delay is correctly added between iterations""" + activity_log: t.List[str] = [] + + service = SimpleService( + activity_log, + quit_after=num_iterations, + as_service=True, + cooldown=0, + loop_delay=delay, + ) + + ts0 = datetime.datetime.now() + service.execute() + ts1 = datetime.datetime.now() + + # the expected duration is the sum of the delay between each iteration + expected_duration = (num_iterations + 1) * delay + duration_in_seconds = (ts1 - ts0).total_seconds() + + assert duration_in_seconds <= expected_duration + assert service.num_cooldowns == 0 + assert service.num_shutdowns == 1 + + +@pytest.mark.parametrize( + "health_check_freq, run_for", + [ + pytest.param(1, 5.5, id="1s freq, 10x"), + pytest.param(5, 10.5, id="5s freq, 2x"), + pytest.param(0.1, 5.1, id="0.1s freq, 50x"), + ], +) +def 
test_service_health_check_freq(health_check_freq: float, run_for: float) -> None:
+    """Verify that the health check frequency is honored
+
+    :param health_check_freq: The desired frequency of the health check
+    :param run_for: A fixed duration to allow the service to run
+    """
+    activity_log: t.List[str] = []
+
+    service = SimpleService(
+        activity_log,
+        quit_after=-1,
+        as_service=True,
+        cooldown=0,
+        hc_freq=health_check_freq,
+        run_for=run_for,
+    )
+
+    ts0 = datetime.datetime.now()
+    service.execute()
+    ts1 = datetime.datetime.now()
+
+    # the expected duration is the sum of the delay between each iteration
+    expected_hc_count = run_for // health_check_freq
+
+    # allow some wiggle room for frequency comparison
+    assert expected_hc_count - 1 <= service.num_health_checks <= expected_hc_count + 1
+
+    assert service.num_cooldowns == 0
+    assert service.num_shutdowns == 1
+
+
+def test_service_health_check_freq_unbound() -> None:
+    """Verify that a health check frequency of zero is treated as
+    "always on" and is called each loop iteration
+
+    :param health_check_freq: The desired frequency of the health check
+    :param run_for: A fixed duration to allow the service to run
+    """
+    health_check_freq: float = 0.0
+    run_for: float = 5
+
+    activity_log: t.List[str] = []
+
+    service = SimpleService(
+        activity_log,
+        quit_after=-1,
+        as_service=True,
+        cooldown=0,
+        hc_freq=health_check_freq,
+        run_for=run_for,
+    )
+
+    service.execute()
+
+    # allow some wiggle room for frequency comparison
+    assert service.num_health_checks == service.num_iterations
+    assert service.num_cooldowns == 0
+    assert service.num_shutdowns == 1
diff --git a/tests/mli/worker.py b/tests/mli/worker.py
new file mode 100644
index 0000000000..0582cae566
--- /dev/null
+++ b/tests/mli/worker.py
@@ -0,0 +1,104 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import io +import typing as t + +import torch + +import smartsim._core.mli.infrastructure.worker.worker as mliw +import smartsim.error as sse +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class IntegratedTorchWorker(mliw.MachineLearningWorkerBase): + """A minimum implementation of a worker that executes a PyTorch model""" + + # @staticmethod + # def deserialize(request: InferenceRequest) -> t.List[t.Any]: + # # request.input_meta + # # request.raw_inputs + # return request + + @staticmethod + def load_model( + request: mliw.InferenceRequest, fetch_result: mliw.FetchModelResult, device: str + ) -> mliw.LoadModelResult: + model_bytes = fetch_result.model_bytes or request.raw_model + if not model_bytes: + raise ValueError("Unable to load model without reference object") + + model: torch.nn.Module = torch.load(io.BytesIO(model_bytes)) + result = mliw.LoadModelResult(model) + return result + + @staticmethod + def transform_input( + request: mliw.InferenceRequest, + fetch_result: mliw.FetchInputResult, + device: str, + ) -> mliw.TransformInputResult: + # extra metadata for assembly can be found in request.input_meta + raw_inputs = request.raw_inputs or fetch_result.inputs + + result: t.List[torch.Tensor] = [] + # should this happen here? + # consider - fortran to c data layout + # is there an intermediate representation before really doing torch.load? 
+ if raw_inputs: + result = [torch.load(io.BytesIO(item)) for item in raw_inputs] + + return mliw.TransformInputResult(result) + + @staticmethod + def execute( + request: mliw.InferenceRequest, + load_result: mliw.LoadModelResult, + transform_result: mliw.TransformInputResult, + ) -> mliw.ExecuteResult: + if not load_result.model: + raise sse.SmartSimError("Model must be loaded to execute") + + model = load_result.model + results = [model(tensor) for tensor in transform_result.transformed] + + execute_result = mliw.ExecuteResult(results) + return execute_result + + @staticmethod + def transform_output( + request: mliw.InferenceRequest, + execute_result: mliw.ExecuteResult, + result_device: str, + ) -> mliw.TransformOutputResult: + # send the original tensors... + execute_result.predictions = [t.detach() for t in execute_result.predictions] + # todo: solve sending all tensor metadata that coincisdes with each prediction + return mliw.TransformOutputResult( + execute_result.predictions, [1], "c", "float32" + ) diff --git a/tests/on_wlm/test_dragon.py b/tests/on_wlm/test_dragon.py index a05d381415..1bef3cac8d 100644 --- a/tests/on_wlm/test_dragon.py +++ b/tests/on_wlm/test_dragon.py @@ -56,7 +56,7 @@ def test_dragon_global_path(global_dragon_teardown, wlmutils, test_dir, monkeypa def test_dragon_exp_path(global_dragon_teardown, wlmutils, test_dir, monkeypatch): monkeypatch.delenv("SMARTSIM_DRAGON_SERVER_PATH", raising=False) - monkeypatch.delenv("SMARTSIM_DRAGON_SERVER_PATH_EXP", raising=False) + monkeypatch.delenv("_SMARTSIM_DRAGON_SERVER_PATH_EXP", raising=False) exp: Experiment = Experiment( "test_dragon_connection", exp_path=test_dir, diff --git a/tests/test_config.py b/tests/test_config.py index 00a1fcdd36..5a84103ffd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -66,9 +66,9 @@ def get_redisai_env( """ env = os.environ.copy() if rai_path is not None: - env["RAI_PATH"] = rai_path + env["SMARTSIM_RAI_LIB"] = rai_path else: - env.pop("RAI_PATH", 
None) + env.pop("SMARTSIM_RAI_LIB", None) if lib_path is not None: env["SMARTSIM_DEP_INSTALL_PATH"] = lib_path @@ -85,7 +85,7 @@ def make_file(filepath: str) -> None: def test_redisai_invalid_rai_path(test_dir, monkeypatch): - """An invalid RAI_PATH and valid SMARTSIM_DEP_INSTALL_PATH should fail""" + """An invalid SMARTSIM_RAI_LIB and valid SMARTSIM_DEP_INSTALL_PATH should fail""" rai_file_path = os.path.join(test_dir, "lib", "mock-redisai.so") make_file(os.path.join(test_dir, "lib", "redisai.so")) @@ -94,7 +94,7 @@ def test_redisai_invalid_rai_path(test_dir, monkeypatch): config = Config() - # Fail when no file exists @ RAI_PATH + # Fail when no file exists @ SMARTSIM_RAI_LIB with pytest.raises(SSConfigError) as ex: _ = config.redisai @@ -102,7 +102,7 @@ def test_redisai_invalid_rai_path(test_dir, monkeypatch): def test_redisai_valid_rai_path(test_dir, monkeypatch): - """A valid RAI_PATH should override valid SMARTSIM_DEP_INSTALL_PATH and succeed""" + """A valid SMARTSIM_RAI_LIB should override valid SMARTSIM_DEP_INSTALL_PATH and succeed""" rai_file_path = os.path.join(test_dir, "lib", "mock-redisai.so") make_file(rai_file_path) @@ -117,7 +117,7 @@ def test_redisai_valid_rai_path(test_dir, monkeypatch): def test_redisai_invalid_lib_path(test_dir, monkeypatch): - """Invalid RAI_PATH and invalid SMARTSIM_DEP_INSTALL_PATH should fail""" + """Invalid SMARTSIM_RAI_LIB and invalid SMARTSIM_DEP_INSTALL_PATH should fail""" rai_file_path = f"{test_dir}/railib/redisai.so" @@ -133,7 +133,7 @@ def test_redisai_invalid_lib_path(test_dir, monkeypatch): def test_redisai_valid_lib_path(test_dir, monkeypatch): - """Valid RAI_PATH and invalid SMARTSIM_DEP_INSTALL_PATH should succeed""" + """Valid SMARTSIM_RAI_LIB and invalid SMARTSIM_DEP_INSTALL_PATH should succeed""" rai_file_path = os.path.join(test_dir, "lib", "mock-redisai.so") make_file(rai_file_path) @@ -147,7 +147,7 @@ def test_redisai_valid_lib_path(test_dir, monkeypatch): def test_redisai_valid_lib_path_null_rai(test_dir, 
monkeypatch): - """Missing RAI_PATH and valid SMARTSIM_DEP_INSTALL_PATH should succeed""" + """Missing SMARTSIM_RAI_LIB and valid SMARTSIM_DEP_INSTALL_PATH should succeed""" rai_file_path: t.Optional[str] = None lib_file_path = os.path.join(test_dir, "lib", "redisai.so") @@ -166,11 +166,11 @@ def test_redis_conf(): assert Path(config.database_conf).is_file() assert isinstance(config.database_conf, str) - os.environ["REDIS_CONF"] = "not/a/path" + os.environ["SMARTSIM_REDIS_CONF"] = "not/a/path" config = Config() with pytest.raises(SSConfigError): config.database_conf - os.environ.pop("REDIS_CONF") + os.environ.pop("SMARTSIM_REDIS_CONF") def test_redis_exe(): @@ -178,11 +178,11 @@ def test_redis_exe(): assert Path(config.database_exe).is_file() assert isinstance(config.database_exe, str) - os.environ["REDIS_PATH"] = "not/a/path" + os.environ["SMARTSIM_REDIS_SERVER_EXE"] = "not/a/path" config = Config() with pytest.raises(SSConfigError): config.database_exe - os.environ.pop("REDIS_PATH") + os.environ.pop("SMARTSIM_REDIS_SERVER_EXE") def test_redis_cli(): @@ -190,11 +190,11 @@ def test_redis_cli(): assert Path(config.redisai).is_file() assert isinstance(config.redisai, str) - os.environ["REDIS_CLI_PATH"] = "not/a/path" + os.environ["SMARTSIM_REDIS_CLI_EXE"] = "not/a/path" config = Config() with pytest.raises(SSConfigError): config.database_cli - os.environ.pop("REDIS_CLI_PATH") + os.environ.pop("SMARTSIM_REDIS_CLI_EXE") @pytest.mark.parametrize( diff --git a/tests/test_dragon_comm_utils.py b/tests/test_dragon_comm_utils.py new file mode 100644 index 0000000000..a6f9c206a4 --- /dev/null +++ b/tests/test_dragon_comm_utils.py @@ -0,0 +1,257 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import base64 +import pathlib +import uuid + +import pytest + +from smartsim.error.errors import SmartSimError + +dragon = pytest.importorskip("dragon") + +# isort: off +import dragon.channels as dch +import dragon.infrastructure.parameters as dp +import dragon.managed_memory as dm +import dragon.fli as fli + +# isort: on + +from smartsim._core.mli.comm.channel import dragon_util +from smartsim.log import get_logger + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +logger = get_logger(__name__) + + +@pytest.fixture(scope="function") +def the_pool() -> dm.MemoryPool: + """Creates a memory pool.""" + raw_pool_descriptor = dp.this_process.default_pd + descriptor_ = base64.b64decode(raw_pool_descriptor) + + pool = dm.MemoryPool.attach(descriptor_) + return pool + + +@pytest.fixture(scope="function") +def the_channel() -> dch.Channel: + """Creates a Channel attached to the local memory pool.""" + channel = dch.Channel.make_process_local() + return channel + + +@pytest.fixture(scope="function") +def the_fli(the_channel) -> fli.FLInterface: + """Creates an FLI attached to the local memory pool.""" + fli_ = fli.FLInterface(main_ch=the_channel, manager_ch=None) + return fli_ + + +def test_descriptor_to_channel_empty() -> None: + """Verify that `descriptor_to_channel` raises an exception when + provided with an empty descriptor.""" + descriptor = "" + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "empty" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()], +) +def test_descriptor_to_channel_b64fail(descriptor: str) -> None: + """Verify that `descriptor_to_channel` raises an exception when + provided with an incorrectly encoded descriptor. 
+ + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "base64" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + [str(uuid.uuid4())], +) +def test_descriptor_to_channel_channel_fail(descriptor: str) -> None: + """Verify that `descriptor_to_channel` raises an exception when a correctly + formatted descriptor that does not describe a real channel is passed. + + :param descriptor: A well-formed descriptor that does not refer to an active channel + """ + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + # ensure we're receiving the right exception + assert "address" in ex.value.args[0] + assert "channel" in ex.value.args[0] + + +def test_descriptor_to_channel_channel_not_available(the_channel: dch.Channel) -> None: + """Verify that `descriptor_to_channel` raises an exception when a channel + is no longer available. + + :param the_channel: A dragon channel + """ + + # get a good descriptor & wipe out the channel so it can't be attached + descriptor = dragon_util.channel_to_descriptor(the_channel) + the_channel.destroy() + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "address" in ex.value.args[0] + + +def test_descriptor_to_channel_happy_path(the_channel: dch.Channel) -> None: + """Verify that `descriptor_to_channel` works as expected when provided + a valid descriptor + + :param the_channel: A dragon channel + """ + + # get a good descriptor + descriptor = dragon_util.channel_to_descriptor(the_channel) + + reattached = dragon_util.descriptor_to_channel(descriptor) + assert reattached + + # and just make sure creation of the descriptor is transitive + assert dragon_util.channel_to_descriptor(reattached) == descriptor + + +def test_descriptor_to_fli_empty() -> None: + """Verify that `descriptor_to_fli` raises an exception when + provided with an
 empty descriptor.""" + descriptor = "" + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + assert "empty" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()], +) +def test_descriptor_to_fli_b64fail(descriptor: str) -> None: + """Verify that `descriptor_to_fli` raises an exception when + provided with an incorrectly encoded descriptor. + + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + assert "base64" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + [str(uuid.uuid4())], +) +def test_descriptor_to_fli_fli_fail(descriptor: str) -> None: + """Verify that `descriptor_to_fli` raises an exception when a correctly + formatted descriptor that does not describe a real FLI is passed. + + :param descriptor: A well-formed descriptor that does not refer to an active FLI + """ + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + # ensure we're receiving the right exception + assert "address" in ex.value.args[0] + assert "fli" in ex.value.args[0].lower() + + +def test_descriptor_to_fli_fli_not_available( + the_fli: fli.FLInterface, the_channel: dch.Channel +) -> None: + """Verify that `descriptor_to_fli` raises an exception when a channel + is no longer available.
+ + :param the_fli: A dragon FLInterface + :param the_channel: A dragon channel + """ + + # get a good descriptor & wipe out the FLI so it can't be attached + descriptor = dragon_util.channel_to_descriptor(the_fli) + the_fli.destroy() + the_channel.destroy() + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + # ensure we're receiving the right exception + assert "address" in ex.value.args[0] + + +def test_descriptor_to_fli_happy_path(the_fli: dch.Channel) -> None: + """Verify that `descriptor_to_fli` works as expected when provided + a valid descriptor + + :param the_fli: A dragon FLInterface + """ + + # get a good descriptor + descriptor = dragon_util.channel_to_descriptor(the_fli) + + reattached = dragon_util.descriptor_to_fli(descriptor) + assert reattached + + # and just make sure creation of the descriptor is transitive + assert dragon_util.channel_to_descriptor(reattached) == descriptor + + +def test_pool_to_descriptor_empty() -> None: + """Verify that `pool_to_descriptor` raises an exception when + provided with a null pool.""" + + with pytest.raises(ValueError) as ex: + dragon_util.pool_to_descriptor(None) + + +def test_pool_to_happy_path(the_pool) -> None: + """Verify that `pool_to_descriptor` creates a descriptor + when supplied with a valid memory pool.""" + + descriptor = dragon_util.pool_to_descriptor(the_pool) + assert descriptor diff --git a/tests/test_dragon_installer.py b/tests/test_dragon_installer.py index b23a1a7ef0..a58d711721 100644 --- a/tests/test_dragon_installer.py +++ b/tests/test_dragon_installer.py @@ -31,12 +31,17 @@ from collections import namedtuple import pytest +from github.GitRelease import GitRelease from github.GitReleaseAsset import GitReleaseAsset from github.Requester import Requester import smartsim +import smartsim._core._install.utils import smartsim._core.utils.helpers as helpers from smartsim._core._cli.scripts.dragon_install import ( + DEFAULT_DRAGON_REPO, + 
DEFAULT_DRAGON_VERSION, + DragonInstallRequest, cleanup, create_dotenv, install_dragon, @@ -58,14 +63,25 @@ def test_archive(test_dir: str, archive_path: pathlib.Path) -> pathlib.Path: """Fixture for returning a simple tarfile to test on""" num_files = 10 + + archive_name = archive_path.name + archive_name = archive_name.replace(".tar.gz", "") + with tarfile.TarFile.open(archive_path, mode="w:gz") as tar: - mock_whl = pathlib.Path(test_dir) / "mock.whl" + mock_whl = pathlib.Path(test_dir) / archive_name / f"{archive_name}.whl" + mock_whl.parent.mkdir(parents=True, exist_ok=True) mock_whl.touch() + tar.add(mock_whl) + for i in range(num_files): - content = pathlib.Path(test_dir) / f"{i:04}.txt" + content = pathlib.Path(test_dir) / archive_name / f"{i:04}.txt" content.write_text(f"i am file {i}\n") tar.add(content) + content.unlink() + + mock_whl.unlink() + return archive_path @@ -118,11 +134,41 @@ def test_assets(monkeypatch: pytest.MonkeyPatch) -> t.Dict[str, GitReleaseAsset] _git_attr(value=f"http://foo/{archive_name}"), ) monkeypatch.setattr(asset, "_name", _git_attr(value=archive_name)) + monkeypatch.setattr(asset, "_id", _git_attr(value=123)) assets.append(asset) return assets +@pytest.fixture +def test_releases(monkeypatch: pytest.MonkeyPatch) -> t.Dict[str, GitRelease]: + requester = Requester( + auth=None, + base_url="https://github.com", + user_agent="mozilla", + per_page=10, + verify=False, + timeout=1, + retry=1, + pool_size=1, + ) + headers = {"mock-header": "mock-value"} + attributes = {"title": "mock-title"} + completed = True + + releases: t.List[GitRelease] = [] + + for python_version in ["py3.9", "py3.10", "py3.11"]: + for dragon_version in ["dragon-0.8", "dragon-0.9", "dragon-0.10"]: + attributes = { + "title": f"{python_version}-{dragon_version}-release", + "tag_name": f"v{dragon_version}-weekly", + } + releases.append(GitRelease(requester, headers, attributes, completed)) + + return releases + + def test_cleanup_no_op(archive_path: pathlib.Path) 
-> None: """Ensure that the cleanup method doesn't bomb when called with missing archive path; simulate a failed download""" @@ -143,17 +189,25 @@ def test_cleanup_archive_exists(test_archive: pathlib.Path) -> None: assert not test_archive.exists() -def test_retrieve_cached( - test_dir: str, - # archive_path: pathlib.Path, +@pytest.mark.skip("Deprecated due to builder.py changes") +def test_retrieve_updated( test_archive: pathlib.Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - """Verify that a previously retrieved asset archive is re-used""" - with tarfile.TarFile.open(test_archive) as tar: - tar.extractall(test_dir) + """Verify that a previously retrieved asset archive is not re-used if a new + version is found""" + + old_asset_id = 100 + asset_id = 123 - ts1 = test_archive.parent.stat().st_ctime + def mock__retrieve_archive(source_, destination_) -> None: + mock_extraction_dir = pathlib.Path(destination_) + with tarfile.TarFile.open(test_archive) as tar: + tar.extractall(mock_extraction_dir) + + # we'll use the mock extract to create the files that would normally be downloaded + expected_output_dir = test_archive.parent / str(asset_id) + old_output_dir = test_archive.parent / str(old_asset_id) requester = Requester( auth=None, @@ -174,14 +228,22 @@ def test_retrieve_cached( # ensure mocked asset has values that we use... monkeypatch.setattr(asset, "_browser_download_url", _git_attr(value="http://foo")) monkeypatch.setattr(asset, "_name", _git_attr(value=mock_archive_name)) + monkeypatch.setattr(asset, "_id", _git_attr(value=asset_id)) + monkeypatch.setattr( + smartsim._core._install.utils, + "retrieve", + lambda s_, d_: mock__retrieve_archive(s_, expected_output_dir), + ) # mock the retrieval of the updated archive + + # tell it to retrieve. 
it should return the path to the new download, not the old one + request = DragonInstallRequest(test_archive.parent) + asset_path = retrieve_asset(request, asset) - asset_path = retrieve_asset(test_archive.parent, asset) - ts2 = asset_path.stat().st_ctime + # sanity check we don't have the same paths + assert old_output_dir != expected_output_dir - assert ( - asset_path == test_archive.parent - ) # show that the expected path matches the output path - assert ts1 == ts2 # show that the file wasn't changed... + # verify the "cached" copy wasn't used + assert asset_path == expected_output_dir @pytest.mark.parametrize( @@ -214,11 +276,13 @@ def test_retrieve_cached( ) def test_retrieve_asset_info( test_assets: t.Collection[GitReleaseAsset], + test_releases: t.Collection[GitRelease], monkeypatch: pytest.MonkeyPatch, dragon_pin: str, pyv: str, is_found: bool, is_crayex: bool, + test_dir: str, ) -> None: """Verify that an information is retrieved correctly based on the python version, platform (e.g. 
CrayEX, !CrayEx), and target dragon pin""" @@ -234,20 +298,23 @@ def test_retrieve_asset_info( "is_crayex_platform", lambda: is_crayex, ) + # avoid hitting github API ctx.setattr( smartsim._core._cli.scripts.dragon_install, - "dragon_pin", - lambda: dragon_pin, + "_get_all_releases", + lambda x: test_releases, ) # avoid hitting github API ctx.setattr( smartsim._core._cli.scripts.dragon_install, "_get_release_assets", - lambda: test_assets, + lambda x: test_assets, ) + request = DragonInstallRequest(test_dir, version=dragon_pin) + if is_found: - chosen_asset = retrieve_asset_info() + chosen_asset = retrieve_asset_info(request) assert chosen_asset assert pyv in chosen_asset.name @@ -259,7 +326,7 @@ def test_retrieve_asset_info( assert "crayex" not in chosen_asset.name.lower() else: with pytest.raises(SmartSimCLIActionCancelled): - retrieve_asset_info() + retrieve_asset_info(request) def test_check_for_utility_missing(test_dir: str) -> None: @@ -357,11 +424,12 @@ def mock_util_check(util: str) -> bool: assert is_cray == platform_result -def test_install_package_no_wheel(extraction_dir: pathlib.Path): +def test_install_package_no_wheel(test_dir: str, extraction_dir: pathlib.Path): """Verify that a missing wheel does not blow up and has a failure retcode""" exp_path = extraction_dir + request = DragonInstallRequest(test_dir) - result = install_package(exp_path) + result = install_package(request, exp_path) assert result != 0 @@ -370,7 +438,9 @@ def test_install_macos(monkeypatch: pytest.MonkeyPatch, extraction_dir: pathlib. with monkeypatch.context() as ctx: ctx.setattr(sys, "platform", "darwin") - result = install_dragon(extraction_dir) + request = DragonInstallRequest(extraction_dir) + + result = install_dragon(request) assert result == 1 @@ -387,7 +457,7 @@ def test_create_dotenv(monkeypatch: pytest.MonkeyPatch, test_dir: str): # ensure no .env exists before trying to create it. 
assert not exp_env_path.exists() - create_dotenv(mock_dragon_root) + create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION) # ensure the .env is created as side-effect of create_dotenv assert exp_env_path.exists() @@ -409,7 +479,7 @@ def test_create_dotenv_existing_dir(monkeypatch: pytest.MonkeyPatch, test_dir: s # ensure no .env exists before trying to create it. assert not exp_env_path.exists() - create_dotenv(mock_dragon_root) + create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION) # ensure the .env is created as side-effect of create_dotenv assert exp_env_path.exists() @@ -434,17 +504,25 @@ def test_create_dotenv_existing_dotenv(monkeypatch: pytest.MonkeyPatch, test_dir # ensure .env exists so we can update it assert exp_env_path.exists() - create_dotenv(mock_dragon_root) + create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION) # ensure the .env is created as side-effect of create_dotenv assert exp_env_path.exists() # ensure file was overwritten and env vars are not duplicated dotenv_content = exp_env_path.read_text(encoding="utf-8") - split_content = dotenv_content.split(var_name) - - # split to confirm env var only appars once - assert len(split_content) == 2 + lines = [ + line for line in dotenv_content.split("\n") if line and not "#" in line + ] + for line in lines: + if line.startswith(var_name): + # make sure the var isn't defined recursively + # DRAGON_BASE_DIR=$DRAGON_BASE_DIR + assert var_name not in line[len(var_name) + 1 :] + else: + # make sure any values reference the original base dir var + if var_name in line: + assert f"${var_name}" in line def test_create_dotenv_format(monkeypatch: pytest.MonkeyPatch, test_dir: str): @@ -456,13 +534,13 @@ def test_create_dotenv_format(monkeypatch: pytest.MonkeyPatch, test_dir: str): with monkeypatch.context() as ctx: ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path) - create_dotenv(mock_dragon_root) + create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION) # ensure the .env is created as 
side-effect of create_dotenv content = exp_env_path.read_text(encoding="utf-8") # ensure we have values written, but ignore empty lines - lines = [line for line in content.split("\n") if line] + lines = [line for line in content.split("\n") if line and not "#" in line] assert lines # ensure each line is formatted as key=value diff --git a/tests/test_dragon_launcher.py b/tests/test_dragon_launcher.py index 4bd07e920c..a894757918 100644 --- a/tests/test_dragon_launcher.py +++ b/tests/test_dragon_launcher.py @@ -37,7 +37,10 @@ import zmq import smartsim._core.config -from smartsim._core._cli.scripts.dragon_install import create_dotenv +from smartsim._core._cli.scripts.dragon_install import ( + DEFAULT_DRAGON_VERSION, + create_dotenv, +) from smartsim._core.config.config import get_config from smartsim._core.launcher.dragon.dragonLauncher import ( DragonConnector, @@ -494,7 +497,7 @@ def test_load_env_env_file_created(monkeypatch: pytest.MonkeyPatch, test_dir: st with monkeypatch.context() as ctx: ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path) - create_dotenv(mock_dragon_root) + create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION) dragon_conf = smartsim._core.config.CONFIG.dragon_dotenv # verify config does exist @@ -507,7 +510,26 @@ def test_load_env_env_file_created(monkeypatch: pytest.MonkeyPatch, test_dir: st assert loaded_env # confirm .env was parsed as expected by inspecting a key + assert "DRAGON_BASE_DIR" in loaded_env + base_dir = loaded_env["DRAGON_BASE_DIR"] + assert "DRAGON_ROOT_DIR" in loaded_env + assert loaded_env["DRAGON_ROOT_DIR"] == base_dir + + assert "DRAGON_INCLUDE_DIR" in loaded_env + assert loaded_env["DRAGON_INCLUDE_DIR"] == f"{base_dir}/include" + + assert "DRAGON_LIB_DIR" in loaded_env + assert loaded_env["DRAGON_LIB_DIR"] == f"{base_dir}/lib" + + assert "DRAGON_VERSION" in loaded_env + assert loaded_env["DRAGON_VERSION"] == DEFAULT_DRAGON_VERSION + + assert "PATH" in loaded_env + assert loaded_env["PATH"] == 
f"{base_dir}/bin" + + assert "LD_LIBRARY_PATH" in loaded_env + assert loaded_env["LD_LIBRARY_PATH"] == f"{base_dir}/lib" def test_load_env_cached_env(monkeypatch: pytest.MonkeyPatch, test_dir: str): @@ -517,7 +539,7 @@ def test_load_env_cached_env(monkeypatch: pytest.MonkeyPatch, test_dir: str): with monkeypatch.context() as ctx: ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path) - create_dotenv(mock_dragon_root) + create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION) # load config w/launcher connector = DragonConnector() @@ -541,7 +563,7 @@ def test_merge_env(monkeypatch: pytest.MonkeyPatch, test_dir: str): with monkeypatch.context() as ctx: ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path) - create_dotenv(mock_dragon_root) + create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION) # load config w/launcher connector = DragonConnector() diff --git a/tests/test_dragon_run_policy.py b/tests/test_dragon_run_policy.py index 1d8d069fab..5e8642c052 100644 --- a/tests/test_dragon_run_policy.py +++ b/tests/test_dragon_run_policy.py @@ -114,9 +114,6 @@ def test_create_run_policy_non_run_request(dragon_request: DragonRequest) -> Non policy = DragonBackend.create_run_policy(dragon_request, "localhost") assert policy is not None, "Default policy was not returned" - assert ( - policy.device == Policy.Device.DEFAULT - ), "Default device was not Device.DEFAULT" assert policy.cpu_affinity == [], "Default cpu affinity was not empty" assert policy.gpu_affinity == [], "Default gpu affinity was not empty" @@ -140,10 +137,8 @@ def test_create_run_policy_run_request_no_run_policy() -> None: policy = DragonBackend.create_run_policy(run_req, "localhost") - assert policy.device == Policy.Device.DEFAULT assert set(policy.cpu_affinity) == set() assert policy.gpu_affinity == [] - assert policy.affinity == Policy.Affinity.DEFAULT @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @@ -167,7 +162,6 @@ def 
test_create_run_policy_run_request_default_run_policy() -> None: assert set(policy.cpu_affinity) == set() assert set(policy.gpu_affinity) == set() - assert policy.affinity == Policy.Affinity.DEFAULT @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @@ -192,7 +186,6 @@ def test_create_run_policy_run_request_cpu_affinity_no_device() -> None: assert set(policy.cpu_affinity) == affinity assert policy.gpu_affinity == [] - assert policy.affinity == Policy.Affinity.SPECIFIC @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @@ -216,7 +209,6 @@ def test_create_run_policy_run_request_cpu_affinity() -> None: assert set(policy.cpu_affinity) == affinity assert policy.gpu_affinity == [] - assert policy.affinity == Policy.Affinity.SPECIFIC @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @@ -240,7 +232,6 @@ def test_create_run_policy_run_request_gpu_affinity() -> None: assert policy.cpu_affinity == [] assert set(policy.gpu_affinity) == set(affinity) - assert policy.affinity == Policy.Affinity.SPECIFIC @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") diff --git a/tests/test_dragon_run_request.py b/tests/test_dragon_run_request.py index 7514deab19..62ac572eb2 100644 --- a/tests/test_dragon_run_request.py +++ b/tests/test_dragon_run_request.py @@ -30,18 +30,14 @@ import time from unittest.mock import MagicMock +import pydantic.error_wrappers import pytest -from pydantic import ValidationError + +from smartsim._core.launcher.dragon.pqueue import NodePrioritizer # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b - -try: - import dragon - - dragon_loaded = True -except: - dragon_loaded = False +dragon = pytest.importorskip("dragon") from smartsim._core.config import CONFIG from smartsim._core.schemas.dragonRequests import * @@ -56,38 +52,6 @@ ) -class NodeMock(MagicMock): - def __init__( - self, name: 
t.Optional[str] = None, num_gpus: int = 2, num_cpus: int = 8 - ) -> None: - super().__init__() - self._mock_id = name - NodeMock._num_gpus = num_gpus - NodeMock._num_cpus = num_cpus - - @property - def hostname(self) -> str: - if self._mock_id: - return self._mock_id - return create_short_id_str() - - @property - def num_cpus(self) -> str: - return NodeMock._num_cpus - - @property - def num_gpus(self) -> str: - return NodeMock._num_gpus - - def _set_id(self, value: str) -> None: - self._mock_id = value - - def gpus(self, parent: t.Any = None) -> t.List[str]: - if self._num_gpus: - return [f"{self.hostname}-gpu{i}" for i in range(NodeMock._num_gpus)] - return [] - - class GroupStateMock(MagicMock): def Running(self) -> MagicMock: running = MagicMock(**{"__str__.return_value": "Running"}) @@ -102,59 +66,59 @@ class ProcessGroupMock(MagicMock): puids = [121, 122] -def node_mock() -> NodeMock: - return NodeMock() - - def get_mock_backend( - monkeypatch: pytest.MonkeyPatch, num_gpus: int = 2 + monkeypatch: pytest.MonkeyPatch, num_cpus: int, num_gpus: int ) -> "DragonBackend": - + # create all the necessary namespaces as raw magic mocks + monkeypatch.setitem(sys.modules, "dragon.data.ddict.ddict", MagicMock()) + monkeypatch.setitem(sys.modules, "dragon.native.machine", MagicMock()) + monkeypatch.setitem(sys.modules, "dragon.native.group_state", MagicMock()) + monkeypatch.setitem(sys.modules, "dragon.native.process_group", MagicMock()) + monkeypatch.setitem(sys.modules, "dragon.native.process", MagicMock()) + monkeypatch.setitem(sys.modules, "dragon.infrastructure.connection", MagicMock()) + monkeypatch.setitem(sys.modules, "dragon.infrastructure.policy", MagicMock()) + monkeypatch.setitem(sys.modules, "dragon.infrastructure.process_desc", MagicMock()) + monkeypatch.setitem(sys.modules, "dragon.data.ddict.ddict", MagicMock()) + + node_list = ["node1", "node2", "node3"] + system_mock = MagicMock(return_value=MagicMock(nodes=node_list)) + node_mock = lambda x: 
MagicMock(hostname=x, num_cpus=num_cpus, num_gpus=num_gpus) + process_group_mock = MagicMock(return_value=ProcessGroupMock()) process_mock = MagicMock(returncode=0) - process_group_mock = MagicMock(**{"Process.return_value": ProcessGroupMock()}) - process_module_mock = MagicMock() - process_module_mock.Process = process_mock - node_mock = NodeMock(num_gpus=num_gpus) - system_mock = MagicMock(nodes=["node1", "node2", "node3"]) + policy_mock = MagicMock(return_value=MagicMock()) + group_state_mock = GroupStateMock() + + # customize members that must perform specific actions within the namespaces monkeypatch.setitem( sys.modules, "dragon", MagicMock( **{ - "native.machine.Node.return_value": node_mock, - "native.machine.System.return_value": system_mock, - "native.group_state": GroupStateMock(), - "native.process_group.ProcessGroup.return_value": ProcessGroupMock(), + "native.machine.Node": node_mock, + "native.machine.System": system_mock, + "native.group_state": group_state_mock, + "native.process_group.ProcessGroup": process_group_mock, + "native.process_group.Process": process_mock, + "native.process.Process": process_mock, + "infrastructure.policy.Policy": policy_mock, } ), ) - monkeypatch.setitem( - sys.modules, - "dragon.infrastructure.connection", - MagicMock(), - ) - monkeypatch.setitem( - sys.modules, - "dragon.infrastructure.policy", - MagicMock(**{"Policy.return_value": MagicMock()}), - ) - monkeypatch.setitem(sys.modules, "dragon.native.process", process_module_mock) - monkeypatch.setitem(sys.modules, "dragon.native.process_group", process_group_mock) - monkeypatch.setitem(sys.modules, "dragon.native.group_state", GroupStateMock()) - monkeypatch.setitem( - sys.modules, - "dragon.native.machine", - MagicMock( - **{"System.return_value": system_mock, "Node.return_value": node_mock} - ), - ) from smartsim._core.launcher.dragon.dragonBackend import DragonBackend dragon_backend = DragonBackend(pid=99999) - monkeypatch.setattr( - dragon_backend, "_free_hosts", 
collections.deque(dragon_backend._hosts) + + # NOTE: we're manually updating these values due to issue w/mocking namespaces + dragon_backend._prioritizer = NodePrioritizer( + [ + MagicMock(num_cpus=num_cpus, num_gpus=num_gpus, hostname=node) + for node in node_list + ], + dragon_backend._queue_lock, ) + dragon_backend._cpus = [num_cpus] * len(node_list) + dragon_backend._gpus = [num_gpus] * len(node_list) return dragon_backend @@ -212,16 +176,14 @@ def set_mock_group_infos( } monkeypatch.setattr(dragon_backend, "_group_infos", group_infos) - monkeypatch.setattr(dragon_backend, "_free_hosts", collections.deque(hosts[1:3])) - monkeypatch.setattr(dragon_backend, "_allocated_hosts", {hosts[0]: "abc123-1"}) + monkeypatch.setattr(dragon_backend, "_allocated_hosts", {hosts[0]: {"abc123-1"}}) monkeypatch.setattr(dragon_backend, "_running_steps", ["abc123-1"]) return group_infos -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_handshake_request(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) handshake_req = DragonHandshakeRequest() handshake_resp = dragon_backend.process_request(handshake_req) @@ -230,9 +192,8 @@ def test_handshake_request(monkeypatch: pytest.MonkeyPatch) -> None: assert handshake_resp.dragon_pid == 99999 -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -259,9 +220,9 @@ def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None: assert dragon_backend._running_steps == [step_id] assert len(dragon_backend._queued_steps) == 0 - assert len(dragon_backend._free_hosts) == 1 - assert 
dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id - assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id + assert len(dragon_backend.free_hosts) == 1 + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]] + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]] monkeypatch.setattr( dragon_backend._group_infos[step_id].process_group, "status", "Running" @@ -271,9 +232,9 @@ def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None: assert dragon_backend._running_steps == [step_id] assert len(dragon_backend._queued_steps) == 0 - assert len(dragon_backend._free_hosts) == 1 - assert dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id - assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id + assert len(dragon_backend.free_hosts) == 1 + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]] + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]] dragon_backend._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED @@ -281,9 +242,8 @@ def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None: assert not dragon_backend._running_steps -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_deny_run_request(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) dragon_backend._shutdown_requested = True @@ -309,7 +269,7 @@ def test_deny_run_request(monkeypatch: pytest.MonkeyPatch) -> None: def test_run_request_with_empty_policy(monkeypatch: pytest.MonkeyPatch) -> None: """Verify that a policy is applied to a run request""" - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -325,10 +285,9 @@ def 
test_run_request_with_empty_policy(monkeypatch: pytest.MonkeyPatch) -> None: assert run_req.policy is None -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None: """Verify that a policy is applied to a run request""" - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -356,9 +315,9 @@ def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None: assert dragon_backend._running_steps == [step_id] assert len(dragon_backend._queued_steps) == 0 - assert len(dragon_backend._free_hosts) == 1 - assert dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id - assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id + assert len(dragon_backend._prioritizer.unassigned()) == 1 + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]] + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]] monkeypatch.setattr( dragon_backend._group_infos[step_id].process_group, "status", "Running" @@ -368,9 +327,9 @@ def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None: assert dragon_backend._running_steps == [step_id] assert len(dragon_backend._queued_steps) == 0 - assert len(dragon_backend._free_hosts) == 1 - assert dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id - assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id + assert len(dragon_backend._prioritizer.unassigned()) == 1 + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]] + assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]] dragon_backend._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED @@ -378,9 +337,8 @@ def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None: assert not 
dragon_backend._running_steps -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_udpate_status_request(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) group_infos = set_mock_group_infos(monkeypatch, dragon_backend) @@ -395,9 +353,8 @@ def test_udpate_status_request(monkeypatch: pytest.MonkeyPatch) -> None: } -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_stop_request(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) group_infos = set_mock_group_infos(monkeypatch, dragon_backend) running_steps = [ @@ -424,10 +381,9 @@ def test_stop_request(monkeypatch: pytest.MonkeyPatch) -> None: ) assert len(dragon_backend._allocated_hosts) == 0 - assert len(dragon_backend._free_hosts) == 3 + assert len(dragon_backend._prioritizer.unassigned()) == 3 -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize( "immediate, kill_jobs, frontend_shutdown", [ @@ -446,7 +402,7 @@ def test_shutdown_request( frontend_shutdown: bool, ) -> None: monkeypatch.setenv("SMARTSIM_FLAG_TELEMETRY", "0") - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) monkeypatch.setattr(dragon_backend, "_cooldown_period", 1) set_mock_group_infos(monkeypatch, dragon_backend) @@ -486,11 +442,10 @@ def test_shutdown_request( assert dragon_backend._has_cooled_down == kill_jobs -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize("telemetry_flag", ["0", "1"]) def test_cooldown_is_set(monkeypatch: pytest.MonkeyPatch, telemetry_flag: str) -> None: monkeypatch.setenv("SMARTSIM_FLAG_TELEMETRY", telemetry_flag) - dragon_backend = 
get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) expected_cooldown = ( 2 * CONFIG.telemetry_frequency + 5 if int(telemetry_flag) > 0 else 5 @@ -502,19 +457,17 @@ def test_cooldown_is_set(monkeypatch: pytest.MonkeyPatch, telemetry_flag: str) - assert dragon_backend.cooldown_period == expected_cooldown -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_heartbeat_and_time(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) first_heartbeat = dragon_backend.last_heartbeat assert dragon_backend.current_time > first_heartbeat dragon_backend._heartbeat() assert dragon_backend.last_heartbeat > first_heartbeat -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize("num_nodes", [1, 3, 100]) def test_can_honor(monkeypatch: pytest.MonkeyPatch, num_nodes: int) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -527,18 +480,42 @@ def test_can_honor(monkeypatch: pytest.MonkeyPatch, num_nodes: int) -> None: pmi_enabled=False, ) - assert dragon_backend._can_honor(run_req)[0] == ( - num_nodes <= len(dragon_backend._hosts) - ) + can_honor, error_msg = dragon_backend._can_honor(run_req) + + nodes_in_range = num_nodes <= len(dragon_backend._hosts) + assert can_honor == nodes_in_range + assert error_msg is None if nodes_in_range else error_msg is not None + + +@pytest.mark.parametrize("num_nodes", [-10, -1, 0]) +def test_can_honor_invalid_num_nodes( + monkeypatch: pytest.MonkeyPatch, num_nodes: int +) -> None: + """Verify that requests for invalid numbers of nodes (negative, zero) are rejected""" + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) + + with 
pytest.raises(pydantic.error_wrappers.ValidationError) as ex: + DragonRunRequest( + exe="sleep", + exe_args=["5"], + path="/a/fake/path", + nodes=num_nodes, + tasks=1, + tasks_per_node=1, + env={}, + current_env={}, + pmi_enabled=False, + ) -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize("affinity", [[0], [0, 1], list(range(8))]) def test_can_honor_cpu_affinity( monkeypatch: pytest.MonkeyPatch, affinity: t.List[int] ) -> None: """Verify that valid CPU affinities are accepted""" - dragon_backend = get_mock_backend(monkeypatch) + num_cpus, num_gpus = 8, 0 + dragon_backend = get_mock_backend(monkeypatch, num_cpus=num_cpus, num_gpus=num_gpus) + run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -555,11 +532,10 @@ def test_can_honor_cpu_affinity( assert dragon_backend._can_honor(run_req)[0] -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_can_honor_cpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> None: """Verify that invalid CPU affinities are NOT accepted NOTE: negative values are captured by the Pydantic schema""" - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -576,13 +552,15 @@ def test_can_honor_cpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> assert not dragon_backend._can_honor(run_req)[0] -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize("affinity", [[0], [0, 1]]) def test_can_honor_gpu_affinity( monkeypatch: pytest.MonkeyPatch, affinity: t.List[int] ) -> None: """Verify that valid GPU affinities are accepted""" - dragon_backend = get_mock_backend(monkeypatch) + + num_cpus, num_gpus = 8, 2 + dragon_backend = get_mock_backend(monkeypatch, num_cpus=num_cpus, num_gpus=num_gpus) + run_req = DragonRunRequest( exe="sleep", 
exe_args=["5"], @@ -599,11 +577,10 @@ def test_can_honor_gpu_affinity( assert dragon_backend._can_honor(run_req)[0] -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_can_honor_gpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> None: """Verify that invalid GPU affinities are NOT accepted NOTE: negative values are captured by the Pydantic schema""" - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) run_req = DragonRunRequest( exe="sleep", exe_args=["5"], @@ -620,46 +597,45 @@ def test_can_honor_gpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> assert not dragon_backend._can_honor(run_req)[0] -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_can_honor_gpu_device_not_available(monkeypatch: pytest.MonkeyPatch) -> None: """Verify that a request for a GPU if none exists is not accepted""" # create a mock node class that always reports no GPUs available - dragon_backend = get_mock_backend(monkeypatch, num_gpus=0) - - run_req = DragonRunRequest( - exe="sleep", - exe_args=["5"], - path="/a/fake/path", - nodes=2, - tasks=1, - tasks_per_node=1, - env={}, - current_env={}, - pmi_enabled=False, - # specify GPU device w/no affinity - policy=DragonRunPolicy(gpu_affinity=[0]), - ) - - assert not dragon_backend._can_honor(run_req)[0] + with monkeypatch.context() as ctx: + dragon_backend = get_mock_backend(ctx, num_cpus=8, num_gpus=0) + + run_req = DragonRunRequest( + exe="sleep", + exe_args=["5"], + path="/a/fake/path", + nodes=2, + tasks=1, + tasks_per_node=1, + env={}, + current_env={}, + pmi_enabled=False, + # specify GPU device w/no affinity + policy=DragonRunPolicy(gpu_affinity=[0]), + ) + can_honor, _ = dragon_backend._can_honor(run_req) + assert not can_honor -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_get_id(monkeypatch: pytest.MonkeyPatch) -> 
None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) step_id = next(dragon_backend._step_ids) assert step_id.endswith("0") assert step_id != next(dragon_backend._step_ids) -@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") def test_view(monkeypatch: pytest.MonkeyPatch) -> None: - dragon_backend = get_mock_backend(monkeypatch) + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) set_mock_group_infos(monkeypatch, dragon_backend) hosts = dragon_backend.hosts + dragon_backend._prioritizer.increment(hosts[0]) - expected_message = textwrap.dedent(f"""\ + expected_msg = textwrap.dedent(f"""\ Dragon server backend update | Host | Status | |--------|----------| @@ -667,7 +643,7 @@ def test_view(monkeypatch: pytest.MonkeyPatch) -> None: | {hosts[1]} | Free | | {hosts[2]} | Free | | Step | Status | Hosts | Return codes | Num procs | - |----------|--------------|-------------|----------------|-------------| + |----------|--------------|-----------------|----------------|-------------| | abc123-1 | Running | {hosts[0]} | | 1 | | del999-2 | Cancelled | {hosts[1]} | -9 | 1 | | c101vz-3 | Completed | {hosts[1]},{hosts[2]} | 0 | 2 | @@ -676,6 +652,110 @@ def test_view(monkeypatch: pytest.MonkeyPatch) -> None: # get rid of white space to make the comparison easier actual_msg = dragon_backend.status_message.replace(" ", "") - expected_message = expected_message.replace(" ", "") + expected_msg = expected_msg.replace(" ", "") + + # ignore dashes in separators (hostname changes may cause column expansion) + while actual_msg.find("--") > -1: + actual_msg = actual_msg.replace("--", "-") + while expected_msg.find("--") > -1: + expected_msg = expected_msg.replace("--", "-") + + assert actual_msg == expected_msg + + +def test_can_honor_hosts_unavailable_hosts(monkeypatch: pytest.MonkeyPatch) -> None: + """Verify that requesting nodes with invalid names causes 
number of available + nodes check to fail due to valid # of named nodes being under num_nodes""" + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) + + # let's supply 2 invalid and 1 valid hostname + actual_hosts = list(dragon_backend._hosts) + actual_hosts[0] = f"x{actual_hosts[0]}" + actual_hosts[1] = f"x{actual_hosts[1]}" + + host_list = ",".join(actual_hosts) + + run_req = DragonRunRequest( + exe="sleep", + exe_args=["5"], + path="/a/fake/path", + nodes=2, # <----- requesting 2 of 3 available nodes + hostlist=host_list, # <--- only one valid name available + tasks=1, + tasks_per_node=1, + env={}, + current_env={}, + pmi_enabled=False, + policy=DragonRunPolicy(), + ) + + can_honor, error_msg = dragon_backend._can_honor(run_req) + + # confirm the failure is indicated + assert not can_honor + # confirm failure message indicates number of nodes requested as cause + assert "named hosts" in error_msg + + +def test_can_honor_hosts_unavailable_hosts_ok(monkeypatch: pytest.MonkeyPatch) -> None: + """Verify that requesting nodes with invalid names causes number of available + nodes check to be reduced but still passes if enough valid named nodes are passed""" + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) + + # let's supply 2 valid and 1 invalid hostname + actual_hosts = list(dragon_backend._hosts) + actual_hosts[0] = f"x{actual_hosts[0]}" + + host_list = ",".join(actual_hosts) + + run_req = DragonRunRequest( + exe="sleep", + exe_args=["5"], + path="/a/fake/path", + nodes=2, # <----- requesting 2 of 3 available nodes + hostlist=host_list, # <--- two valid names are available + tasks=1, + tasks_per_node=1, + env={}, + current_env={}, + pmi_enabled=False, + policy=DragonRunPolicy(), + ) + + can_honor, error_msg = dragon_backend._can_honor(run_req) + + # confirm the request is granted (enough valid named hosts remain) + assert can_honor, error_msg + # confirm no error message is returned when the request can be honored + assert error_msg is None, error_msg + 
+ +def test_can_honor_hosts_1_hosts_requested(monkeypatch: pytest.MonkeyPatch) -> None: + """Verify that a single-node request is honored when at least one valid + named host remains after invalid hostnames are filtered out""" + dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0) + + # let's supply 2 valid and 1 invalid hostname + actual_hosts = list(dragon_backend._hosts) + actual_hosts[0] = f"x{actual_hosts[0]}" + + host_list = ",".join(actual_hosts) + + run_req = DragonRunRequest( + exe="sleep", + exe_args=["5"], + path="/a/fake/path", + nodes=1, # <----- requesting 1 of 3 available nodes + hostlist=host_list, # <--- two valid names are available + tasks=1, + tasks_per_node=1, + env={}, + current_env={}, + pmi_enabled=False, + policy=DragonRunPolicy(), + ) + + can_honor, error_msg = dragon_backend._can_honor(run_req) - assert actual_msg == expected_message + # confirm the request is granted + assert can_honor, error_msg diff --git a/tests/test_dragon_runsettings.py b/tests/test_dragon_runsettings.py index 34e8510e82..8c7600c74c 100644 --- a/tests/test_dragon_runsettings.py +++ b/tests/test_dragon_runsettings.py @@ -96,3 +96,122 @@ def test_dragon_runsettings_gpu_affinity(): # ensure the value is not changed when we extend the list rs.run_args["gpu-affinity"] = "7,8,9" assert rs.run_args["gpu-affinity"] != ",".join(str(val) for val in exp_value) + + +def test_dragon_runsettings_hostlist_null(): + """Verify that passing a null hostlist is treated as a failure""" + rs = DragonRunSettings(exe="sleep", exe_args=["1"]) + + # baseline check that no host list exists + stored_list = rs.run_args.get("host-list", None) + assert stored_list is None + + with pytest.raises(ValueError) as ex: + rs.set_hostlist(None) + + assert "empty hostlist" in ex.value.args[0] + + +def test_dragon_runsettings_hostlist_empty(): + """Verify that passing an empty hostlist is treated as a failure""" + rs = DragonRunSettings(exe="sleep", 
exe_args=["1"]) + + # baseline check that no host list exists + stored_list = rs.run_args.get("host-list", None) + assert stored_list is None + + with pytest.raises(ValueError) as ex: + rs.set_hostlist([]) + + assert "empty hostlist" in ex.value.args[0] + + +@pytest.mark.parametrize("hostlist_csv", [" ", " , , , ", ",", ",,,"]) +def test_dragon_runsettings_hostlist_whitespace_handling(hostlist_csv: str): + """Verify that passing a hostlist with emptystring host names is treated as a failure""" + rs = DragonRunSettings(exe="sleep", exe_args=["1"]) + + # baseline check that no host list exists + stored_list = rs.run_args.get("host-list", None) + assert stored_list is None + + # empty string as hostname in list + with pytest.raises(ValueError) as ex: + rs.set_hostlist(hostlist_csv) + + assert "invalid names" in ex.value.args[0] + + +@pytest.mark.parametrize( + "hostlist_csv", [[" "], [" ", "", " ", " "], ["", " "], ["", "", "", ""]] +) +def test_dragon_runsettings_hostlist_whitespace_handling_list(hostlist_csv: str): + """Verify that passing a hostlist with emptystring host names contained in a list + is treated as a failure""" + rs = DragonRunSettings(exe="sleep", exe_args=["1"]) + + # baseline check that no host list exists + stored_list = rs.run_args.get("host-list", None) + assert stored_list is None + + # empty string as hostname in list + with pytest.raises(ValueError) as ex: + rs.set_hostlist(hostlist_csv) + + assert "invalid names" in ex.value.args[0] + + +def test_dragon_runsettings_hostlist_as_csv(): + """Verify that a hostlist is stored properly when passing in a CSV string""" + rs = DragonRunSettings(exe="sleep", exe_args=["1"]) + + # baseline check that no host list exists + stored_list = rs.run_args.get("host-list", None) + assert stored_list is None + + hostnames = ["host0", "host1", "host2", "host3", "host4"] + + # set the host list with ideal comma separated values + input0 = ",".join(hostnames) + + # set the host list with a string of comma separated 
values + # including extra whitespace + input1 = ", ".join(hostnames) + + for hosts_input in [input0, input1]: + rs.set_hostlist(hosts_input) + + stored_list = rs.run_args.get("host-list", None) + assert stored_list + + # confirm that all values from the original list are retrieved + split_stored_list = stored_list.split(",") + assert set(hostnames) == set(split_stored_list) + + +def test_dragon_runsettings_hostlist_as_csv(): + """Verify that a hostlist is stored properly when passing in a CSV string""" + rs = DragonRunSettings(exe="sleep", exe_args=["1"]) + + # baseline check that no host list exists + stored_list = rs.run_args.get("host-list", None) + assert stored_list is None + + hostnames = ["host0", "host1", "host2", "host3", "host4"] + + # set the host list with ideal comma separated values + input0 = ",".join(hostnames) + + # set the host list with a string of comma separated values + # including extra whitespace + input1 = ", ".join(hostnames) + + for hosts_input in [input0, input1]: + rs.set_hostlist(hosts_input) + + stored_list = rs.run_args.get("host-list", None) + assert stored_list + + # confirm that all values from the original list are retrieved + split_stored_list = stored_list.split(",") + assert set(hostnames) == set(split_stored_list) diff --git a/tests/test_dragon_step.py b/tests/test_dragon_step.py index 19f408e0bd..f933fb7bc2 100644 --- a/tests/test_dragon_step.py +++ b/tests/test_dragon_step.py @@ -73,12 +73,18 @@ def dragon_batch_step(test_dir: str) -> DragonBatchStep: cpu_affinities = [[], [0, 1, 2], [], [3, 4, 5, 6]] gpu_affinities = [[], [], [0, 1, 2], [3, 4, 5, 6]] + # specify 3 hostnames to select from but require only 2 nodes + num_nodes = 2 + hostnames = ["host1", "host2", "host3"] + # assign some unique affinities to each run setting instance for index, rs in enumerate(settings): if gpu_affinities[index]: rs.set_node_feature("gpu") rs.set_cpu_affinity(cpu_affinities[index]) rs.set_gpu_affinity(gpu_affinities[index]) + 
rs.set_hostlist(hostnames) + rs.set_nodes(num_nodes) steps = list( DragonStep(name_, test_dir, rs_) for name_, rs_ in zip(names, settings) @@ -374,6 +380,11 @@ def test_dragon_batch_step_write_request_file( cpu_affinities = [[], [0, 1, 2], [], [3, 4, 5, 6]] gpu_affinities = [[], [], [0, 1, 2], [3, 4, 5, 6]] + hostnames = ["host1", "host2", "host3"] + num_nodes = 2 + + # parse requests file path from the launch command + # e.g. dragon python launch_cmd = dragon_batch_step.get_launch_cmd() requests_file = get_request_path_from_batch_script(launch_cmd) @@ -392,3 +403,5 @@ def test_dragon_batch_step_write_request_file( assert run_request assert run_request.policy.cpu_affinity == cpu_affinities[index] assert run_request.policy.gpu_affinity == gpu_affinities[index] + assert run_request.nodes == num_nodes + assert run_request.hostlist == ",".join(hostnames) diff --git a/tests/test_message_handler/__init__.py b/tests/test_message_handler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_message_handler/test_build_model.py b/tests/test_message_handler/test_build_model.py new file mode 100644 index 0000000000..56c1c8764c --- /dev/null +++ b/tests/test_message_handler/test_build_model.py @@ -0,0 +1,72 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + + +def test_build_model_successful(): + expected_data = b"model data" + expected_name = "model name" + expected_version = "v0.0.1" + model = handler.build_model(expected_data, expected_name, expected_version) + assert model.data == expected_data + assert model.name == expected_name + assert model.version == expected_version + + +@pytest.mark.parametrize( + "data, name, version", + [ + pytest.param( + 100, + "model name", + "v0.0.1", + id="bad data type", + ), + pytest.param( + b"model data", + 1, + "v0.0.1", + id="bad name type", + ), + pytest.param( + b"model data", + "model name", + 0.1, + id="bad version type", + ), + ], +) +def test_build_model_unsuccessful(data, name, version): + with pytest.raises(ValueError): + model = handler.build_model(data, name, version) diff --git a/tests/test_message_handler/test_build_model_key.py b/tests/test_message_handler/test_build_model_key.py new file mode 100644 index 0000000000..6c9b3dc951 --- /dev/null +++ 
b/tests/test_message_handler/test_build_model_key.py @@ -0,0 +1,47 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + + +def test_build_model_key_successful(): + fsd = "mock-feature-store-descriptor" + model_key = handler.build_model_key("tensor_key", fsd) + assert model_key.key == "tensor_key" + assert model_key.descriptor == fsd + + +def test_build_model_key_unsuccessful(): + with pytest.raises(ValueError): + fsd = "mock-feature-store-descriptor" + model_key = handler.build_model_key(100, fsd) diff --git a/tests/test_message_handler/test_build_request_attributes.py b/tests/test_message_handler/test_build_request_attributes.py new file mode 100644 index 0000000000..5b1e09b0aa --- /dev/null +++ b/tests/test_message_handler/test_build_request_attributes.py @@ -0,0 +1,55 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + + +def test_build_torch_request_attributes_successful(): + attribute = handler.build_torch_request_attributes("sparse") + assert attribute.tensorType == "sparse" + + +def test_build_torch_request_attributes_unsuccessful(): + with pytest.raises(ValueError): + attribute = handler.build_torch_request_attributes("invalid!") + + +def test_build_tf_request_attributes_successful(): + attribute = handler.build_tf_request_attributes(name="tfcnn", tensor_type="sparse") + assert attribute.tensorType == "sparse" + assert attribute.name == "tfcnn" + + +def test_build_tf_request_attributes_unsuccessful(): + with pytest.raises(ValueError): + attribute = handler.build_tf_request_attributes("tf_fail", "invalid!") diff --git a/tests/test_message_handler/test_build_tensor_desc.py b/tests/test_message_handler/test_build_tensor_desc.py new file mode 100644 index 0000000000..45126fb16c --- /dev/null +++ b/tests/test_message_handler/test_build_tensor_desc.py @@ -0,0 +1,90 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + + +@pytest.mark.parametrize( + "dtype, order, dimension", + [ + pytest.param( + "int8", + "c", + [3, 2, 5], + id="small torch tensor", + ), + pytest.param( + "int64", + "c", + [1040, 1040, 3], + id="medium torch tensor", + ), + ], +) +def test_build_tensor_descriptor_successful(dtype, order, dimension): + built_tensor_descriptor = handler.build_tensor_descriptor(order, dtype, dimension) + assert built_tensor_descriptor is not None + assert built_tensor_descriptor.order == order + assert built_tensor_descriptor.dataType == dtype + for i, j in zip(built_tensor_descriptor.dimensions, dimension): + assert i == j + + +@pytest.mark.parametrize( + "dtype, order, dimension", + [ + pytest.param( + "bad_order", + "int8", + [3, 2, 5], + id="bad order type", + ), + pytest.param( + "f", + "bad_num_type", + [3, 2, 5], + id="bad numerical type", + ), + pytest.param( + "f", + "int8", + "bad shape type", + id="bad shape type", + ), + ], +) +def test_build_tensor_descriptor_unsuccessful(dtype, order, dimension): + with pytest.raises(ValueError): + built_tensor_descriptor = handler.build_tensor_descriptor( + order, dtype, dimension + ) diff --git a/tests/test_message_handler/test_build_tensor_key.py b/tests/test_message_handler/test_build_tensor_key.py new file mode 100644 index 0000000000..6a28b80c4f --- /dev/null +++ b/tests/test_message_handler/test_build_tensor_key.py @@ -0,0 +1,46 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. 
+# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + + +def test_build_tensor_key_successful(): + fsd = "mock-feature-store-descriptor" + tensor_key = handler.build_tensor_key("tensor_key", fsd) + assert tensor_key.key == "tensor_key" + + +def test_build_tensor_key_unsuccessful(): + with pytest.raises(ValueError): + fsd = "mock-feature-store-descriptor" + tensor_key = handler.build_tensor_key(100, fsd) diff --git a/tests/test_message_handler/test_output_descriptor.py b/tests/test_message_handler/test_output_descriptor.py new file mode 100644 index 0000000000..beb9a47657 --- /dev/null +++ b/tests/test_message_handler/test_output_descriptor.py @@ -0,0 +1,78 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + +fsd = "mock-feature-store-descriptor" +tensor_key = handler.build_tensor_key("key", fsd) + + +@pytest.mark.parametrize( + "order, keys, dtype, dimension", + [ + pytest.param("c", [tensor_key], "int8", [1, 2, 3, 4], id="all specified"), + pytest.param( + "c", [tensor_key, tensor_key], "none", [1, 2, 3, 4], id="none dtype" + ), + pytest.param("c", [tensor_key], "int8", [], id="empty dimensions"), + pytest.param("c", [], "int8", [1, 2, 3, 4], id="empty keys"), + ], +) +def test_build_output_tensor_descriptor_successful(dtype, keys, order, dimension): + built_descriptor = handler.build_output_tensor_descriptor( + order, keys, dtype, dimension + ) + assert built_descriptor is not None + assert built_descriptor.order == order + assert len(built_descriptor.optionalKeys) == len(keys) + assert built_descriptor.optionalDatatype == dtype + for i, j in zip(built_descriptor.optionalDimension, dimension): + assert i == j + + +@pytest.mark.parametrize( + "order, keys, dtype, dimension", + [ + pytest.param("bad_order", [], "int8", [3, 2, 5], id="bad order type"), + pytest.param( + "f", [tensor_key], "bad_num_type", [3, 2, 5], id="bad numerical type" + ), + pytest.param("f", [tensor_key], "int8", "bad shape type", id="bad shape type"), + pytest.param("f", ["tensor_key"], "int8", [3, 2, 5], id="bad key type"), + ], +) +def test_build_output_tensor_descriptor_unsuccessful(order, keys, dtype, dimension): + with pytest.raises(ValueError): + built_tensor = handler.build_output_tensor_descriptor( + order, keys, dtype, dimension + ) diff --git a/tests/test_message_handler/test_request.py b/tests/test_message_handler/test_request.py new file mode 100644 index 0000000000..a60818f7dd --- /dev/null +++ b/tests/test_message_handler/test_request.py @@ -0,0 +1,449 @@ +# BSD 2-Clause License +# 
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +fsd = "mock-feature-store-descriptor" + +model_key = MessageHandler.build_model_key("model_key", fsd) +model = MessageHandler.build_model(b"model data", "model_name", "v0.0.1") + +input_key1 = MessageHandler.build_tensor_key("input_key1", fsd) +input_key2 = MessageHandler.build_tensor_key("input_key2", fsd) + +output_key1 = MessageHandler.build_tensor_key("output_key1", fsd) +output_key2 = MessageHandler.build_tensor_key("output_key2", fsd) + +output_descriptor1 = MessageHandler.build_output_tensor_descriptor( + "c", [output_key1, output_key2], "int64", [] +) +output_descriptor2 = MessageHandler.build_output_tensor_descriptor("f", [], "auto", []) +output_descriptor3 = MessageHandler.build_output_tensor_descriptor( + "c", [output_key1], "none", [1, 2, 3] +) +torch_attributes = MessageHandler.build_torch_request_attributes("sparse") +tf_attributes = MessageHandler.build_tf_request_attributes( + name="tf", tensor_type="sparse" +) + +tensor_1 = MessageHandler.build_tensor_descriptor("c", "int8", [1]) +tensor_2 = MessageHandler.build_tensor_descriptor("c", "int64", [3, 2]) +tensor_3 = MessageHandler.build_tensor_descriptor("f", "int8", [1]) +tensor_4 = MessageHandler.build_tensor_descriptor("f", "int64", [3, 2]) + + +tf_indirect_request = MessageHandler.build_request( + b"reply", + model, + [input_key1, input_key2], + [output_key1, output_key2], + [output_descriptor1, output_descriptor2, output_descriptor3], + tf_attributes, +) + +tf_direct_request = MessageHandler.build_request( + b"reply", + model, + [tensor_3, tensor_4], + [], + [output_descriptor1, output_descriptor2], + tf_attributes, +) + +torch_indirect_request = MessageHandler.build_request( + b"reply", + model, + [input_key1, input_key2], + [output_key1, output_key2], + [output_descriptor1, output_descriptor2, output_descriptor3], + 
torch_attributes, +) + +torch_direct_request = MessageHandler.build_request( + b"reply", + model, + [tensor_1, tensor_2], + [], + [output_descriptor1, output_descriptor2], + torch_attributes, +) + + +@pytest.mark.parametrize( + "reply_channel, model, input, output, output_descriptors, custom_attributes", + [ + pytest.param( + "reply channel", + model_key, + [input_key1, input_key2], + [output_key1, output_key2], + [output_descriptor1], + torch_attributes, + ), + pytest.param( + "another reply channel", + model, + [input_key1], + [output_key2], + [output_descriptor1], + tf_attributes, + ), + pytest.param( + "another reply channel", + model, + [input_key1], + [output_key2], + [output_descriptor1], + torch_attributes, + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + [output_key1], + [output_descriptor1], + None, + ), + ], +) +def test_build_request_indirect_successful( + reply_channel, model, input, output, output_descriptors, custom_attributes +): + built_request = MessageHandler.build_request( + reply_channel, + model, + input, + output, + output_descriptors, + custom_attributes, + ) + assert built_request is not None + assert built_request.replyChannel.descriptor == reply_channel + if built_request.model.which() == "key": + assert built_request.model.key.key == model.key + else: + assert built_request.model.data.data == model.data + assert built_request.model.data.name == model.name + assert built_request.model.data.version == model.version + assert built_request.input.which() == "keys" + assert built_request.input.keys[0].key == input[0].key + assert len(built_request.input.keys) == len(input) + assert len(built_request.output) == len(output) + for i, j in zip(built_request.outputDescriptors, output_descriptors): + assert i.order == j.order + if built_request.customAttributes.which() == "tf": + assert ( + built_request.customAttributes.tf.tensorType == custom_attributes.tensorType + ) + elif built_request.customAttributes.which() == "torch": 
+ assert ( + built_request.customAttributes.torch.tensorType + == custom_attributes.tensorType + ) + else: + assert built_request.customAttributes.none == custom_attributes + + +@pytest.mark.parametrize( + "reply_channel, model, input, output, output_descriptors, custom_attributes", + [ + pytest.param( + [], + model_key, + [input_key1, input_key2], + [output_key1, output_key2], + [output_descriptor1], + tf_attributes, + id="bad channel", + ), + pytest.param( + "reply channel", + "bad model", + [input_key1], + [output_key2], + [output_descriptor1], + torch_attributes, + id="bad model", + ), + pytest.param( + "reply channel", + model_key, + ["input_key1", "input_key2"], + [output_key1, output_key2], + [output_descriptor1], + tf_attributes, + id="bad inputs", + ), + pytest.param( + "reply channel", + model_key, + [torch_attributes], + [output_key1, output_key2], + [output_descriptor1], + torch_attributes, + id="bad input schema type", + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + ["output_key1", "output_key2"], + [output_descriptor1], + tf_attributes, + id="bad outputs", + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + [torch_attributes], + [output_descriptor1], + tf_attributes, + id="bad output schema type", + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + [output_key1, output_key2], + [output_descriptor1], + "bad attributes", + id="bad custom attributes", + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + [output_key1, output_key2], + [output_descriptor1], + model_key, + id="bad custom attributes schema type", + ), + pytest.param( + "reply channel", + model_key, + [input_key1], + [output_key1, output_key2], + "bad descriptors", + torch_attributes, + id="bad output descriptors", + ), + ], +) +def test_build_request_indirect_unsuccessful( + reply_channel, model, input, output, output_descriptors, custom_attributes +): + with pytest.raises(ValueError): + built_request = 
MessageHandler.build_request( + reply_channel, + model, + input, + output, + output_descriptors, + custom_attributes, + ) + + +@pytest.mark.parametrize( + "reply_channel, model, input, output, output_descriptors, custom_attributes", + [ + pytest.param( + "reply channel", + model_key, + [tensor_1, tensor_2], + [], + [output_descriptor2], + torch_attributes, + ), + pytest.param( + "another reply channel", + model, + [tensor_1], + [], + [output_descriptor3], + tf_attributes, + ), + pytest.param( + "another reply channel", + model, + [tensor_2], + [], + [output_descriptor1], + tf_attributes, + ), + pytest.param( + "another reply channel", + model, + [tensor_1], + [], + [output_descriptor1], + None, + ), + ], +) +def test_build_request_direct_successful( + reply_channel, model, input, output, output_descriptors, custom_attributes +): + built_request = MessageHandler.build_request( + reply_channel, + model, + input, + output, + output_descriptors, + custom_attributes, + ) + assert built_request is not None + assert built_request.replyChannel.descriptor == reply_channel + if built_request.model.which() == "key": + assert built_request.model.key.key == model.key + else: + assert built_request.model.data.data == model.data + assert built_request.model.data.name == model.name + assert built_request.model.data.version == model.version + assert built_request.input.which() == "descriptors" + assert len(built_request.input.descriptors) == len(input) + assert len(built_request.output) == len(output) + for i, j in zip(built_request.outputDescriptors, output_descriptors): + assert i.order == j.order + if built_request.customAttributes.which() == "tf": + assert ( + built_request.customAttributes.tf.tensorType == custom_attributes.tensorType + ) + elif built_request.customAttributes.which() == "torch": + assert ( + built_request.customAttributes.torch.tensorType + == custom_attributes.tensorType + ) + else: + assert built_request.customAttributes.none == custom_attributes + + 
+@pytest.mark.parametrize( + "reply_channel, model, input, output, output_descriptors, custom_attributes", + [ + pytest.param( + [], + model_key, + [tensor_3, tensor_4], + [], + [output_descriptor2], + tf_attributes, + id="bad channel", + ), + pytest.param( + b"reply channel", + "bad model", + [tensor_4], + [], + [output_descriptor2], + tf_attributes, + id="bad model", + ), + pytest.param( + b"reply channel", + model_key, + ["input_key1", "input_key2"], + [], + [output_descriptor2], + torch_attributes, + id="bad inputs", + ), + pytest.param( + b"reply channel", + model_key, + [], + ["output_key1", "output_key2"], + [output_descriptor2], + tf_attributes, + id="bad outputs", + ), + pytest.param( + b"reply channel", + model_key, + [tensor_4], + [], + [output_descriptor2], + "bad attributes", + id="bad custom attributes", + ), + pytest.param( + b"reply_channel", + model_key, + [tensor_3, tensor_4], + [], + ["output_descriptor2"], + torch_attributes, + id="bad output descriptors", + ), + ], +) +def test_build_request_direct_unsuccessful( + reply_channel, model, input, output, output_descriptors, custom_attributes +): + with pytest.raises(ValueError): + built_request = MessageHandler.build_request( + reply_channel, + model, + input, + output, + output_descriptors, + custom_attributes, + ) + + +@pytest.mark.parametrize( + "req", + [ + pytest.param(tf_indirect_request, id="tf indirect"), + pytest.param(tf_direct_request, id="tf direct"), + pytest.param(torch_indirect_request, id="indirect"), + pytest.param(torch_direct_request, id="direct"), + ], +) +def test_serialize_request_successful(req): + serialized = MessageHandler.serialize_request(req) + assert type(serialized) == bytes + + deserialized = MessageHandler.deserialize_request(serialized) + assert deserialized.to_dict() == req.to_dict() + + +def test_serialization_fails(): + with pytest.raises(ValueError): + bad_request = MessageHandler.serialize_request(tensor_1) + + +def test_deserialization_fails(): + with 
pytest.raises(ValueError): + new_req = torch_direct_request.copy() + req_bytes = MessageHandler.serialize_request(new_req) + req_bytes = req_bytes + b"extra bytes" + deser = MessageHandler.deserialize_request(req_bytes) diff --git a/tests/test_message_handler/test_response.py b/tests/test_message_handler/test_response.py new file mode 100644 index 0000000000..86774132ec --- /dev/null +++ b/tests/test_message_handler/test_response.py @@ -0,0 +1,191 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +fsd = "mock-feature-store-descriptor" + +result_key1 = MessageHandler.build_tensor_key("result_key1", fsd) +result_key2 = MessageHandler.build_tensor_key("result_key2", fsd) + +torch_attributes = MessageHandler.build_torch_response_attributes() +tf_attributes = MessageHandler.build_tf_response_attributes() + +tensor1 = MessageHandler.build_tensor_descriptor("c", "int8", [1]) +tensor2 = MessageHandler.build_tensor_descriptor("c", "int64", [3, 2]) + + +tf_indirect_response = MessageHandler.build_response( + "complete", + "Success!", + [result_key1, result_key2], + tf_attributes, +) + +tf_direct_response = MessageHandler.build_response( + "complete", + "Success again!", + [tensor2, tensor1], + tf_attributes, +) + +torch_indirect_response = MessageHandler.build_response( + "complete", + "Success!", + [result_key1, result_key2], + torch_attributes, +) + +torch_direct_response = MessageHandler.build_response( + "complete", + "Success again!", + [tensor1, tensor2], + torch_attributes, +) + + +@pytest.mark.parametrize( + "status, status_message, result, custom_attribute", + [ + pytest.param( + 200, + "Yay, it worked!", + [tensor1, tensor2], + None, + id="tensor descriptor list", + ), + pytest.param( + 200, + "Yay, it worked!", + [result_key1, result_key2], + tf_attributes, + id="tensor key list", + ), + ], +) +def test_build_response_successful(status, status_message, result, custom_attribute): + response = MessageHandler.build_response( + status=status, + message=status_message, + result=result, + custom_attributes=custom_attribute, + ) + assert response is not None + assert response.status == status + assert response.message == status_message + if response.result.which() == "keys": + assert response.result.keys[0].to_dict() == result[0].to_dict() + else: + assert 
response.result.descriptors[0].to_dict() == result[0].to_dict() + + +@pytest.mark.parametrize( + "status, status_message, result, custom_attribute", + [ + pytest.param( + "bad status", + "Yay, it worked!", + [tensor1, tensor2], + None, + id="bad status", + ), + pytest.param( + "complete", + 200, + [tensor2], + torch_attributes, + id="bad status message", + ), + pytest.param( + "complete", + "Yay, it worked!", + ["result_key1", "result_key2"], + tf_attributes, + id="bad result", + ), + pytest.param( + "complete", + "Yay, it worked!", + [tf_attributes], + tf_attributes, + id="bad result type", + ), + pytest.param( + "complete", + "Yay, it worked!", + [tensor2, tensor1], + "custom attributes", + id="bad custom attributes", + ), + pytest.param( + "complete", + "Yay, it worked!", + [tensor2, tensor1], + result_key1, + id="bad custom attributes type", + ), + ], +) +def test_build_response_unsuccessful(status, status_message, result, custom_attribute): + with pytest.raises(ValueError): + response = MessageHandler.build_response( + status, status_message, result, custom_attribute + ) + + +@pytest.mark.parametrize( + "response", + [ + pytest.param(torch_indirect_response, id="indirect"), + pytest.param(torch_direct_response, id="direct"), + pytest.param(tf_indirect_response, id="tf indirect"), + pytest.param(tf_direct_response, id="tf direct"), + ], +) +def test_serialize_response(response): + serialized = MessageHandler.serialize_response(response) + assert type(serialized) == bytes + + deserialized = MessageHandler.deserialize_response(serialized) + assert deserialized.to_dict() == response.to_dict() + + +def test_serialization_fails(): + with pytest.raises(ValueError): + bad_response = MessageHandler.serialize_response(result_key1) + + +def test_deserialization_fails(): + with pytest.raises(ValueError): + new_resp = torch_direct_response.copy() + resp_bytes = MessageHandler.serialize_response(new_resp) + resp_bytes = resp_bytes + b"extra bytes" + deser = 
MessageHandler.deserialize_response(resp_bytes) diff --git a/tests/test_node_prioritizer.py b/tests/test_node_prioritizer.py new file mode 100644 index 0000000000..905c0ecc90 --- /dev/null +++ b/tests/test_node_prioritizer.py @@ -0,0 +1,553 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import random
+import threading
+import typing as t
+
+import pytest
+
+from smartsim._core.launcher.dragon.pqueue import NodePrioritizer, PrioritizerFilter
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+# The tests in this file belong to the group_b group
+pytestmark = pytest.mark.group_b
+
+
+logger = get_logger(__name__)
+
+
+class MockNode:
+    def __init__(self, hostname: str, num_cpus: int, num_gpus: int) -> None:
+        self.hostname = hostname
+        self.num_cpus = num_cpus
+        self.num_gpus = num_gpus
+
+
+def mock_node_hosts(
+    num_cpu_nodes: int, num_gpu_nodes: int
+) -> t.Tuple[t.List[MockNode], t.List[MockNode]]:
+    cpu_hosts = [f"cpu-node-{i}" for i in range(num_cpu_nodes)]
+    gpu_hosts = [f"gpu-node-{i}" for i in range(num_gpu_nodes)]
+
+    return cpu_hosts, gpu_hosts
+
+
+def mock_node_builder(num_cpu_nodes: int, num_gpu_nodes: int) -> t.List[MockNode]:
+    nodes = []
+    cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+
+    nodes.extend(MockNode(hostname, 4, 0) for hostname in cpu_hosts)
+    nodes.extend(MockNode(hostname, 4, 4) for hostname in gpu_hosts)
+
+    return nodes
+
+
+def test_node_prioritizer_init_null() -> None:
+    """Verify that the prioritizer reports failures to send a valid node set
+    if a null value is passed"""
+    lock = threading.RLock()
+    with pytest.raises(SmartSimError) as ex:
+        NodePrioritizer(None, lock)
+
+    assert "Missing" in ex.value.args[0]
+
+
+def test_node_prioritizer_init_empty() -> None:
+    """Verify that the prioritizer reports failures to send a valid node set
+    if an empty list is passed"""
+    lock = threading.RLock()
+    with pytest.raises(SmartSimError) as ex:
+        NodePrioritizer([], lock)
+
+    assert "Missing" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+    "num_cpu_nodes,num_gpu_nodes", [(1, 1), (2, 1), (1, 2), (8, 4), (1000, 200)]
+)
+def test_node_prioritizer_init_ok(num_cpu_nodes: int, num_gpu_nodes: int) -> None:
+    """Verify that initialization with a valid node list results
in the + appropriate cpu & gpu ref counts, and complete ref map""" + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + # perform prioritizer initialization + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # get a copy of all the expected host names + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + all_hosts = cpu_hosts + gpu_hosts + assert len(all_hosts) == num_cpu_nodes + num_gpu_nodes + + # verify tracking data is initialized correctly for all nodes + for hostname in all_hosts: + # show that the ref map is tracking the node + assert hostname in p._nodes + + tracking_info = p.get_tracking_info(hostname) + + # show that the node is created w/zero ref counts + assert tracking_info.num_refs == 0 + + # show that the node is created and marked as not dirty (unchanged) + # assert tracking_info.is_dirty == False + + # iterate through known cpu node keys and verify prioritizer initialization + for hostname in cpu_hosts: + # show that the device ref counters are appropriately assigned + cpu_ref = next((n for n in p._cpu_refs if n.hostname == hostname), None) + assert cpu_ref, "CPU-only node not found in cpu ref set" + + gpu_ref = next((n for n in p._gpu_refs if n.hostname == hostname), None) + assert not gpu_ref, "CPU-only node should not be found in gpu ref set" + + # iterate through known GPU node keys and verify prioritizer initialization + for hostname in gpu_hosts: + # show that the device ref counters are appropriately assigned + gpu_ref = next((n for n in p._gpu_refs if n.hostname == hostname), None) + assert gpu_ref, "GPU-only node not found in gpu ref set" + + cpu_ref = next((n for n in p._cpu_refs if n.hostname == hostname), None) + assert not cpu_ref, "GPU-only node should not be found in cpu ref set" + + # verify we have all hosts in the ref map + assert set(p._nodes.keys()) == set(all_hosts) + + # verify we have no extra hosts in ref map + assert len(p._nodes.keys()) == len(set(all_hosts)) + + +def 
test_node_prioritizer_direct_increment() -> None: + """Verify that performing the increment operation causes the expected + side effect on the intended records""" + + num_cpu_nodes, num_gpu_nodes = 32, 8 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + exclude_index = 2 + exclude_host0 = cpu_hosts[exclude_index] + exclude_host1 = gpu_hosts[exclude_index] + exclusions = [exclude_host0, exclude_host1] + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # let's increment each element in a predictable way and verify + for node in nodes: + if node.hostname in exclusions: + # expect 1 cpu and 1 gpu node at zero and not incremented + continue + + if node.num_gpus == 0: + num_increments = random.randint(0, num_cpu_nodes - 1) + else: + num_increments = random.randint(0, num_gpu_nodes - 1) + + # increment this node some random number of times + for _ in range(num_increments): + p.increment(node.hostname) + + # ... 
and verify the correct incrementing is applied + tracking_info = p.get_tracking_info(node.hostname) + assert tracking_info.num_refs == num_increments + + # verify the excluded cpu node was never changed + tracking_info0 = p.get_tracking_info(exclude_host0) + assert tracking_info0.num_refs == 0 + + # verify the excluded gpu node was never changed + tracking_info1 = p.get_tracking_info(exclude_host1) + assert tracking_info1.num_refs == 0 + + +def test_node_prioritizer_indirect_increment() -> None: + """Verify that performing the increment operation indirectly affects + each available node until we run out of nodes to return""" + + num_cpu_nodes, num_gpu_nodes = 8, 0 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # verify starting state + for node in p._nodes.values(): + tracking_info = p.get_tracking_info(node.hostname) + + assert node.num_refs == 0 # <--- ref count starts at zero + assert tracking_info.num_refs == 0 # <--- ref count starts at zero + + # perform indirect + for node in p._nodes.values(): + tracking_info = p.get_tracking_info(node.hostname) + + # apply `next` operation and verify tracking info reflects new ref + node = p.next(PrioritizerFilter.CPU) + tracking_info = p.get_tracking_info(node.hostname) + + # verify side-effects + assert tracking_info.num_refs > 0 # <--- ref count should now be > 0 + + # we expect it to give back only "clean" nodes from next* + assert tracking_info.is_dirty == False # NOTE: this is "hidden" by protocol + + # every node should be incremented now. 
prioritizer shouldn't have anything to give + tracking_info = p.next(PrioritizerFilter.CPU) + assert tracking_info is None # <--- get_next shouldn't have any nodes to give + + +def test_node_prioritizer_indirect_decrement_availability() -> None: + """Verify that a node who is decremented (dirty) is made assignable + on a subsequent request""" + + num_cpu_nodes, num_gpu_nodes = 1, 0 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # increment our only node... + p.increment(cpu_hosts[0]) + + tracking_info = p.next() + assert tracking_info is None, "No nodes should be assignable" + + # perform a decrement... + p.decrement(cpu_hosts[0]) + + # ... and confirm that the node is available again + tracking_info = p.next() + assert tracking_info is not None, "A node should be assignable" + + +def test_node_prioritizer_multi_increment() -> None: + """Verify that retrieving multiple nodes via `next_n` API correctly + increments reference counts and returns appropriate results""" + + num_cpu_nodes, num_gpu_nodes = 8, 0 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # Mark some nodes as dirty to verify retrieval + p.increment(cpu_hosts[0]) + assert p.get_tracking_info(cpu_hosts[0]).num_refs > 0 + + p.increment(cpu_hosts[2]) + assert p.get_tracking_info(cpu_hosts[2]).num_refs > 0 + + p.increment(cpu_hosts[4]) + assert p.get_tracking_info(cpu_hosts[4]).num_refs > 0 + + # use next_n w/the minimum allowed value + all_tracking_info = p.next_n(1, PrioritizerFilter.CPU) # <---- next_n(1) + + # confirm the number requested is honored + assert len(all_tracking_info) == 1 + # ensure no unavailable node is returned + assert all_tracking_info[0].hostname not in [ + cpu_hosts[0], + cpu_hosts[2], + 
cpu_hosts[4], + ] + + # use next_n w/value that exceeds available number of open nodes + # 3 direct increments in setup, 1 out of next_n(1), 4 left + all_tracking_info = p.next_n(5, PrioritizerFilter.CPU) + + # confirm that no nodes are returned, even though 4 out of 5 requested are available + assert len(all_tracking_info) == 0 + + +def test_node_prioritizer_multi_increment_validate_n() -> None: + """Verify that retrieving multiple nodes via `next_n` API correctly + reports failures when the request size is above pool size""" + + num_cpu_nodes, num_gpu_nodes = 8, 0 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # we have 8 total cpu nodes available... request too many nodes + all_tracking_info = p.next_n(9, PrioritizerFilter.CPU) + assert len(all_tracking_info) == 0 + + all_tracking_info = p.next_n(num_cpu_nodes * 1000, PrioritizerFilter.CPU) + assert len(all_tracking_info) == 0 + + +def test_node_prioritizer_indirect_direct_interleaved_increments() -> None: + """Verify that interleaving indirect and direct increments results in + expected ref counts""" + + num_cpu_nodes, num_gpu_nodes = 8, 4 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # perform some set of non-popped increments + p.increment(gpu_hosts[1]) + p.increment(gpu_hosts[3]) + p.increment(gpu_hosts[3]) + + # increment 0th item 1x + p.increment(cpu_hosts[0]) + + # increment 3th item 2x + p.increment(cpu_hosts[3]) + p.increment(cpu_hosts[3]) + + # increment last item 3x + p.increment(cpu_hosts[7]) + p.increment(cpu_hosts[7]) + p.increment(cpu_hosts[7]) + + tracking_info = p.get_tracking_info(gpu_hosts[1]) + assert tracking_info.num_refs == 1 + + tracking_info = p.get_tracking_info(gpu_hosts[3]) + assert 
tracking_info.num_refs == 2 + + nodes = [n for n in p._nodes.values() if n.num_refs == 0 and n.num_gpus == 0] + + # we should skip the 0-th item in the heap due to direct increment + tracking_info = p.next(PrioritizerFilter.CPU) + assert tracking_info.num_refs == 1 + # confirm we get a cpu node + assert "cpu-node" in tracking_info.hostname + + # this should pull the next item right out + tracking_info = p.next(PrioritizerFilter.CPU) + assert tracking_info.num_refs == 1 + assert "cpu-node" in tracking_info.hostname + + # ensure we pull from gpu nodes and the 0th item is returned + tracking_info = p.next(PrioritizerFilter.GPU) + assert tracking_info.num_refs == 1 + assert "gpu-node" in tracking_info.hostname + + # we should step over the 3-th node on this iteration + tracking_info = p.next(PrioritizerFilter.CPU) + assert tracking_info.num_refs == 1 + assert "cpu-node" in tracking_info.hostname + + # and ensure that heap also steps over a direct increment + tracking_info = p.next(PrioritizerFilter.GPU) + assert tracking_info.num_refs == 1 + assert "gpu-node" in tracking_info.hostname + + # and another GPU request should return nothing + tracking_info = p.next(PrioritizerFilter.GPU) + assert tracking_info is None + + +def test_node_prioritizer_decrement_floor() -> None: + """Verify that repeatedly decrementing ref counts does not + allow negative ref counts""" + + num_cpu_nodes, num_gpu_nodes = 8, 4 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # try a ton of decrements on all the items in the prioritizer + for _ in range(len(nodes) * 100): + index = random.randint(0, num_cpu_nodes - 1) + p.decrement(cpu_hosts[index]) + + index = random.randint(0, num_gpu_nodes - 1) + p.decrement(gpu_hosts[index]) + + for node in nodes: + tracking_info = p.get_tracking_info(node.hostname) + assert tracking_info.num_refs == 0 + + 
+@pytest.mark.parametrize("num_requested", [1, 2, 3]) +def test_node_prioritizer_multi_increment_subheap(num_requested: int) -> None: + """Verify that retrieving multiple nodes via `next_n` API correctly + increments reference counts and returns appropriate results + when requesting an in-bounds number of nodes""" + + num_cpu_nodes, num_gpu_nodes = 8, 0 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # Mark some nodes as dirty to verify retrieval + p.increment(cpu_hosts[0]) + p.increment(cpu_hosts[2]) + p.increment(cpu_hosts[4]) + + hostnames = [cpu_hosts[0], cpu_hosts[1], cpu_hosts[2], cpu_hosts[3], cpu_hosts[5]] + + # request n == {num_requested} nodes from set of 3 available + all_tracking_info = p.next_n( + num_requested, + hosts=hostnames, + ) # <---- w/0,2,4 assigned, only 1,3,5 from hostnames can work + + # all parameterizations should result in a matching output size + assert len(all_tracking_info) == num_requested + + +def test_node_prioritizer_multi_increment_subheap_assigned() -> None: + """Verify that retrieving multiple nodes via `next_n` API does + not return anything when the number requested cannot be satisfied + by the given subheap due to prior assignment""" + + num_cpu_nodes, num_gpu_nodes = 8, 0 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # Mark some nodes as dirty to verify retrieval + p.increment(cpu_hosts[0]) + p.increment(cpu_hosts[2]) + + hostnames = [ + cpu_hosts[0], + "x" + cpu_hosts[2], + ] # <--- we can't get 2 from 1 valid node name + + # request n == {num_requested} nodes from set of 3 available + num_requested = 2 + all_tracking_info = p.next_n(num_requested, hosts=hostnames) + + # w/0,2 assigned, nothing can be returned + assert 
len(all_tracking_info) == 0 + + +def test_node_prioritizer_empty_subheap_next_w_no_hosts() -> None: + """Verify that retrieving multiple nodes via `next_n` API does + with an empty host list uses the entire available host list""" + + num_cpu_nodes, num_gpu_nodes = 8, 0 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # Mark some nodes as dirty to verify retrieval + p.increment(cpu_hosts[0]) + p.increment(cpu_hosts[2]) + + hostnames = [] + + # request n == {num_requested} nodes from set of 3 available + num_requested = 1 + node = p.next(hosts=hostnames) + assert node + + # assert "No hostnames provided" == ex.value.args[0] + + +def test_node_prioritizer_empty_subheap_next_n_w_hosts() -> None: + """Verify that retrieving multiple nodes via `next_n` API does + not blow up with an empty host list""" + + num_cpu_nodes, num_gpu_nodes = 8, 0 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # Mark some nodes as dirty to verify retrieval + p.increment(cpu_hosts[0]) + p.increment(cpu_hosts[2]) + + hostnames = [] + + # request n == {num_requested} nodes from set of 3 available + num_requested = 1 + node = p.next_n(num_requested, hosts=hostnames) + assert node is not None + + +@pytest.mark.parametrize("num_requested", [-100, -1, 0]) +def test_node_prioritizer_empty_subheap_next_n(num_requested: int) -> None: + """Verify that retrieving a node via `next_n` API does + not allow a request with num_items < 1""" + + num_cpu_nodes, num_gpu_nodes = 8, 0 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # Mark some nodes as dirty to verify retrieval 
+ p.increment(cpu_hosts[0]) + p.increment(cpu_hosts[2]) + + # request n == {num_requested} nodes from set of 3 available + with pytest.raises(ValueError) as ex: + p.next_n(num_requested) + + assert "Number of items requested" in ex.value.args[0] + + +@pytest.mark.parametrize("num_requested", [-100, -1, 0]) +def test_node_prioritizer_empty_subheap_next_n(num_requested: int) -> None: + """Verify that retrieving multiple nodes via `next_n` API does + not allow a request with num_items < 1""" + + num_cpu_nodes, num_gpu_nodes = 8, 0 + cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes) + nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes) + + lock = threading.RLock() + p = NodePrioritizer(nodes, lock) + + # Mark some nodes as dirty to verify retrieval + p.increment(cpu_hosts[0]) + p.increment(cpu_hosts[2]) + + hostnames = [cpu_hosts[0], cpu_hosts[2]] + + # request n == {num_requested} nodes from set of 3 available + with pytest.raises(ValueError) as ex: + p.next_n(num_requested, hosts=hostnames) + + assert "Number of items requested" in ex.value.args[0] From 5bdafc5f93fd56bf94ca5a7979a28f185c7c7ebf Mon Sep 17 00:00:00 2001 From: Chris McBride Date: Mon, 28 Oct 2024 11:59:12 -0400 Subject: [PATCH 2/5] Merge core refactor into v1.0 feature branch (#756) Combine the `core-refactor` feature branch with `mli-feature` in `v1.0` branch --------- Co-authored-by: Amanda Richardson Co-authored-by: Amanda Richardson Co-authored-by: Matt Drozt Co-authored-by: Julia Putko <81587103+juliaputko@users.noreply.github.com> Co-authored-by: amandarichardsonn <30413257+amandarichardsonn@users.noreply.github.com> Co-authored-by: Alyssa Cote <46540273+AlyssaCote@users.noreply.github.com> Co-authored-by: Al Rigazzi Co-authored-by: Julia Putko Co-authored-by: Matt Drozt --- .github/workflows/run_tests.yml | 13 +- .pylintrc | 2 +- Makefile | 10 +- conftest.py | 306 +++-- doc/changelog.md | 1 + doc/installation_instructions/basic.rst | 10 +- .../experiment_doc_examples/exp.py 
| 2 +- .../ml_inference/Inference-in-SmartSim.ipynb | 5 +- pyproject.toml | 32 + smartsim/_core/__init__.py | 4 +- smartsim/_core/_cli/build.py | 124 +- smartsim/_core/_cli/cli.py | 2 +- smartsim/_core/_cli/dbcli.py | 8 +- smartsim/_core/_cli/info.py | 41 +- smartsim/_core/_cli/scripts/dragon_install.py | 2 +- smartsim/_core/_cli/utils.py | 25 +- smartsim/_core/_cli/validate.py | 164 +-- smartsim/_core/_install/buildenv.py | 10 +- smartsim/_core/_install/builder.py | 96 +- smartsim/_core/arguments/shell.py | 42 + smartsim/_core/commands/__init__.py | 29 + smartsim/_core/commands/command.py | 98 ++ smartsim/_core/commands/command_list.py | 107 ++ smartsim/_core/commands/launch_commands.py | 51 + smartsim/_core/config/config.py | 90 +- smartsim/_core/control/__init__.py | 1 - smartsim/_core/control/controller.py | 956 --------------- smartsim/_core/control/controller_utils.py | 77 -- smartsim/_core/control/interval.py | 112 ++ smartsim/_core/control/job.py | 54 +- smartsim/_core/control/jobmanager.py | 364 ------ smartsim/_core/control/launch_history.py | 96 ++ smartsim/_core/control/manifest.py | 157 ++- ...previewrenderer.py => preview_renderer.py} | 20 +- smartsim/_core/dispatch.py | 389 ++++++ smartsim/_core/entrypoints/colocated.py | 348 ------ smartsim/_core/entrypoints/dragon.py | 16 +- smartsim/_core/entrypoints/dragon_client.py | 2 +- smartsim/_core/entrypoints/file_operations.py | 293 +++++ smartsim/_core/entrypoints/indirect.py | 4 +- smartsim/_core/entrypoints/redis.py | 192 --- ...lemetrymonitor.py => telemetry_monitor.py} | 0 smartsim/_core/generation/generator.py | 502 ++++---- smartsim/_core/generation/modelwriter.py | 158 --- .../_core/generation/operations/operations.py | 280 +++++ .../generation/operations/utils/helpers.py | 27 + smartsim/_core/launcher/__init__.py | 10 +- smartsim/_core/launcher/colocated.py | 244 ---- .../{dragonBackend.py => dragon_backend.py} | 40 +- ...dragonConnector.py => dragon_connector.py} | 63 +- .../{dragonLauncher.py 
=> dragon_launcher.py} | 143 ++- .../{dragonSockets.py => dragon_sockets.py} | 4 +- smartsim/_core/launcher/launcher.py | 14 +- smartsim/_core/launcher/local/local.py | 12 +- .../lsf/{lsfCommands.py => lsf_commands.py} | 0 .../lsf/{lsfLauncher.py => lsf_launcher.py} | 12 +- .../lsf/{lsfParser.py => lsf_parser.py} | 0 .../pbs/{pbsCommands.py => pbs_commands.py} | 0 .../pbs/{pbsLauncher.py => pbs_launcher.py} | 12 +- .../pbs/{pbsParser.py => pbs_parser.py} | 0 .../sge/{sgeCommands.py => sge_commands.py} | 0 .../sge/{sgeLauncher.py => sge_launcher.py} | 16 +- .../sge/{sgeParser.py => sge_parser.py} | 0 .../{slurmCommands.py => slurm_commands.py} | 0 .../{slurmLauncher.py => slurm_launcher.py} | 10 +- .../slurm/{slurmParser.py => slurm_parser.py} | 0 smartsim/_core/launcher/step/__init__.py | 16 +- .../step/{alpsStep.py => alps_step.py} | 19 +- .../step/{dragonStep.py => dragon_step.py} | 12 +- .../step/{localStep.py => local_step.py} | 18 +- .../launcher/step/{lsfStep.py => lsf_step.py} | 22 +- .../launcher/step/{mpiStep.py => mpi_step.py} | 42 +- .../launcher/step/{pbsStep.py => pbs_step.py} | 7 +- .../launcher/step/{sgeStep.py => sge_step.py} | 0 .../step/{slurmStep.py => slurm_step.py} | 32 +- smartsim/_core/launcher/step/step.py | 34 +- .../launcher/{stepInfo.py => step_info.py} | 238 ++-- .../{stepMapping.py => step_mapping.py} | 0 .../{taskManager.py => task_manager.py} | 0 .../{launcherUtil.py => launcher_util.py} | 0 smartsim/_core/mli/comm/channel/dragon_fli.py | 4 +- .../mli/infrastructure/control/dragon_util.py | 79 ++ smartsim/_core/schemas/__init__.py | 4 +- .../{dragonRequests.py => dragon_requests.py} | 0 ...dragonResponses.py => dragon_responses.py} | 4 +- {tests => smartsim/_core/shell}/__init__.py | 0 smartsim/_core/shell/shell_launcher.py | 268 +++++ smartsim/_core/utils/__init__.py | 2 - smartsim/_core/utils/helpers.py | 288 ++++- smartsim/_core/utils/launcher.py | 99 ++ smartsim/_core/utils/redis.py | 238 ---- smartsim/_core/utils/serialize.py | 
102 +- smartsim/_core/utils/telemetry/collector.py | 67 +- smartsim/_core/utils/telemetry/manifest.py | 28 +- smartsim/_core/utils/telemetry/telemetry.py | 57 +- smartsim/_core/utils/telemetry/util.py | 10 +- smartsim/builders/__init__.py | 28 + smartsim/builders/ensemble.py | 432 +++++++ smartsim/builders/utils/strategies.py | 262 +++++ smartsim/database/__init__.py | 2 +- smartsim/database/orchestrator.py | 371 +++--- smartsim/entity/__init__.py | 7 +- smartsim/entity/{strategies.py => _mock.py} | 44 +- smartsim/entity/application.py | 263 +++++ smartsim/entity/dbnode.py | 48 +- smartsim/entity/dbobject.py | 22 +- smartsim/entity/ensemble.py | 573 --------- smartsim/entity/entity.py | 50 +- smartsim/entity/entityList.py | 144 --- smartsim/entity/files.py | 162 +-- smartsim/entity/model.py | 701 ----------- smartsim/error/errors.py | 18 +- smartsim/experiment.py | 1043 +++++------------ smartsim/launchable/__init__.py | 34 + smartsim/launchable/base_job.py | 43 + smartsim/launchable/base_job_group.py | 91 ++ smartsim/launchable/colocated_job_group.py | 75 ++ smartsim/launchable/job.py | 160 +++ smartsim/launchable/job_group.py | 96 ++ smartsim/launchable/launchable.py | 38 + smartsim/launchable/mpmd_job.py | 118 ++ smartsim/launchable/mpmd_pair.py | 43 + smartsim/ml/data.py | 136 +-- smartsim/ml/tf/utils.py | 12 +- smartsim/ml/torch/data.py | 14 +- smartsim/settings/__init__.py | 91 +- smartsim/settings/arguments/__init__.py | 30 + smartsim/settings/arguments/batch/__init__.py | 35 + smartsim/settings/arguments/batch/lsf.py | 163 +++ smartsim/settings/arguments/batch/pbs.py | 186 +++ smartsim/settings/arguments/batch/slurm.py | 156 +++ .../settings/arguments/batch_arguments.py | 109 ++ .../settings/arguments/launch/__init__.py | 19 + .../launch/alps.py} | 181 ++- .../launch/dragon.py} | 67 +- smartsim/settings/arguments/launch/local.py | 87 ++ smartsim/settings/arguments/launch/lsf.py | 152 +++ smartsim/settings/arguments/launch/mpi.py | 255 ++++ 
smartsim/settings/arguments/launch/pals.py | 162 +++ smartsim/settings/arguments/launch/slurm.py | 353 ++++++ .../settings/arguments/launch_arguments.py | 75 ++ smartsim/settings/base.py | 689 ----------- smartsim/settings/base_settings.py | 30 + smartsim/settings/batch_command.py | 35 + smartsim/settings/batch_settings.py | 174 +++ smartsim/settings/common.py | 49 + smartsim/settings/containers.py | 173 --- smartsim/settings/launch_command.py | 41 + smartsim/settings/launch_settings.py | 226 ++++ smartsim/settings/lsfSettings.py | 560 --------- smartsim/settings/mpiSettings.py | 350 ------ smartsim/settings/palsSettings.py | 233 ---- smartsim/settings/pbsSettings.py | 263 ----- smartsim/settings/settings.py | 205 ---- .../{sgeSettings.py => sge_settings.py} | 0 smartsim/settings/slurmSettings.py | 513 -------- smartsim/status.py | 28 +- .../preview/plain_text/activeinfra.template | 10 +- .../preview/plain_text/base.template | 22 +- .../plain_text/clientconfig_debug.template | 10 +- .../plain_text/clientconfig_info.template | 10 +- .../clientconfigcolo_debug.template | 26 +- .../plain_text/clientconfigcolo_info.template | 14 +- .../plain_text/ensemble_debug.template | 12 +- .../preview/plain_text/ensemble_info.template | 18 +- .../preview/plain_text/model_debug.template | 34 +- .../preview/plain_text/model_info.template | 22 +- .../plain_text/orchestrator_debug.template | 32 +- .../plain_text/orchestrator_info.template | 12 +- smartsim/types.py | 32 + smartsim/wlm/pbs.py | 2 +- smartsim/wlm/slurm.py | 23 +- tests/_legacy/__init__.py | 25 + .../backends/run_sklearn_onnx.py | 2 +- tests/{ => _legacy}/backends/run_tf.py | 0 tests/{ => _legacy}/backends/run_torch.py | 5 +- .../backends/test_cli_mini_exp.py | 21 +- .../{ => _legacy}/backends/test_dataloader.py | 36 +- tests/{ => _legacy}/backends/test_dbmodel.py | 209 ++-- tests/{ => _legacy}/backends/test_dbscript.py | 273 ++--- tests/{ => _legacy}/backends/test_onnx.py | 20 +- tests/{ => _legacy}/backends/test_tf.py | 
18 +- tests/{ => _legacy}/backends/test_torch.py | 20 +- .../full_wlm/test_generic_batch_launch.py | 58 +- .../full_wlm/test_generic_orc_launch_batch.py | 160 +-- tests/{ => _legacy}/full_wlm/test_mpmd.py | 18 +- .../full_wlm/test_slurm_allocation.py | 0 .../{ => _legacy}/full_wlm/test_symlinking.py | 62 +- .../full_wlm/test_wlm_helper_functions.py | 0 tests/{ => _legacy}/install/test_build.py | 0 tests/{ => _legacy}/install/test_buildenv.py | 0 tests/_legacy/install/test_builder.py | 404 +++++++ tests/{ => _legacy}/install/test_mlpackage.py | 0 .../install/test_package_retriever.py | 0 tests/{ => _legacy}/install/test_platform.py | 0 .../install/test_redisai_builder.py | 0 .../on_wlm/test_base_settings_on_wlm.py | 28 +- tests/_legacy/on_wlm/test_colocated_model.py | 207 ++++ .../on_wlm/test_containers_wlm.py | 22 +- tests/{ => _legacy}/on_wlm/test_dragon.py | 16 +- .../on_wlm/test_dragon_entrypoint.py | 0 .../on_wlm/test_generic_orc_launch.py | 76 +- tests/{ => _legacy}/on_wlm/test_het_job.py | 10 +- .../on_wlm/test_launch_errors.py | 22 +- .../on_wlm/test_launch_ompi_lsf.py | 12 +- tests/{ => _legacy}/on_wlm/test_local_step.py | 8 +- .../{ => _legacy}/on_wlm/test_preview_wlm.py | 102 +- tests/{ => _legacy}/on_wlm/test_restart.py | 10 +- .../test_simple_base_settings_on_wlm.py | 24 +- .../on_wlm/test_simple_entity_launch.py | 26 +- .../on_wlm/test_slurm_commands.py | 2 +- tests/{ => _legacy}/on_wlm/test_stop.py | 10 +- .../on_wlm/test_wlm_orc_config_settings.py | 48 +- tests/{ => _legacy}/test_alps_settings.py | 2 +- tests/{ => _legacy}/test_batch_settings.py | 0 tests/{ => _legacy}/test_cli.py | 16 +- tests/{ => _legacy}/test_collector_manager.py | 36 +- tests/{ => _legacy}/test_collector_sink.py | 0 tests/{ => _legacy}/test_collectors.py | 24 +- tests/_legacy/test_colo_model_local.py | 324 +++++ tests/{ => _legacy}/test_colo_model_lsf.py | 122 +- tests/{ => _legacy}/test_config.py | 0 tests/{ => _legacy}/test_containers.py | 38 +- tests/{ => 
_legacy}/test_controller.py | 10 +- tests/{ => _legacy}/test_controller_errors.py | 79 +- tests/{ => _legacy}/test_dbnode.py | 52 +- tests/{ => _legacy}/test_dragon_client.py | 6 +- tests/{ => _legacy}/test_dragon_installer.py | 32 +- tests/{ => _legacy}/test_dragon_launcher.py | 10 +- tests/{ => _legacy}/test_dragon_run_policy.py | 8 +- .../{ => _legacy}/test_dragon_run_request.py | 49 +- .../test_dragon_run_request_nowlm.py | 4 +- .../{ => _legacy}/test_dragon_runsettings.py | 0 tests/{ => _legacy}/test_dragon_step.py | 6 +- tests/_legacy/test_ensemble.py | 306 +++++ tests/{ => _legacy}/test_entitylist.py | 0 tests/_legacy/test_experiment.py | 372 ++++++ tests/{ => _legacy}/test_fixtures.py | 30 +- tests/_legacy/test_generator.py | 381 ++++++ tests/{ => _legacy}/test_helpers.py | 22 +- tests/{ => _legacy}/test_indirect.py | 2 +- tests/{ => _legacy}/test_interrupt.py | 26 +- tests/{ => _legacy}/test_launch_errors.py | 36 +- tests/{ => _legacy}/test_local_launch.py | 14 +- tests/{ => _legacy}/test_local_multi_run.py | 16 +- tests/{ => _legacy}/test_local_restart.py | 18 +- tests/{ => _legacy}/test_logs.py | 2 +- tests/{ => _legacy}/test_lsf_parser.py | 16 +- tests/{ => _legacy}/test_lsf_settings.py | 2 +- tests/{ => _legacy}/test_manifest.py | 94 +- tests/{ => _legacy}/test_model.py | 77 +- tests/{ => _legacy}/test_modelwriter.py | 32 +- tests/{ => _legacy}/test_mpi_settings.py | 2 +- tests/{ => _legacy}/test_multidb.py | 298 ++--- .../{ => _legacy}/test_orc_config_settings.py | 32 +- tests/{ => _legacy}/test_orchestrator.py | 205 ++-- tests/{ => _legacy}/test_output_files.py | 108 +- tests/{ => _legacy}/test_pals_settings.py | 2 +- tests/{ => _legacy}/test_pbs_parser.py | 14 +- tests/{ => _legacy}/test_pbs_settings.py | 0 tests/{ => _legacy}/test_preview.py | 368 +++--- .../test_reconnect_orchestrator.py | 41 +- tests/{ => _legacy}/test_run_settings.py | 54 + tests/{ => _legacy}/test_schema_utils.py | 0 tests/{ => _legacy}/test_serialize.py | 28 +- .../{ => 
_legacy}/test_sge_batch_settings.py | 2 +- tests/{ => _legacy}/test_shell_util.py | 0 tests/{ => _legacy}/test_slurm_get_alloc.py | 0 tests/{ => _legacy}/test_slurm_parser.py | 46 +- tests/{ => _legacy}/test_slurm_settings.py | 8 +- tests/{ => _legacy}/test_slurm_validation.py | 0 tests/{ => _legacy}/test_smartredis.py | 46 +- tests/{ => _legacy}/test_step_info.py | 8 +- tests/{ => _legacy}/test_symlinking.py | 94 +- tests/{ => _legacy}/test_telemetry_monitor.py | 331 +++--- tests/{ => _legacy}/utils/test_network.py | 0 tests/{ => _legacy}/utils/test_security.py | 0 tests/backends/test_ml_init.py | 48 + tests/{dragon => dragon_wlm}/__init__.py | 0 tests/{dragon => dragon_wlm}/channel.py | 0 tests/{dragon => dragon_wlm}/conftest.py | 3 - tests/{dragon => dragon_wlm}/feature_store.py | 0 .../test_core_machine_learning_worker.py | 6 +- .../test_device_manager.py | 0 .../test_dragon_backend.py | 2 +- tests/dragon_wlm/test_dragon_comm_utils.py | 257 ++++ .../test_dragon_ddict_utils.py | 0 .../test_environment_loader.py | 0 .../test_error_handling.py | 0 .../test_event_consumer.py | 0 .../test_featurestore.py | 0 .../test_featurestore_base.py | 0 .../test_featurestore_integration.py | 0 .../test_inference_reply.py | 0 .../test_inference_request.py | 0 .../test_protoclient.py | 0 .../test_reply_building.py | 0 .../test_request_dispatcher.py | 24 +- .../test_torch_worker.py | 0 .../test_worker_manager.py | 5 +- .../{dragon => dragon_wlm}/utils/__init__.py | 0 tests/{dragon => dragon_wlm}/utils/channel.py | 0 .../{dragon => dragon_wlm}/utils/msg_pump.py | 0 tests/{dragon => dragon_wlm}/utils/worker.py | 0 tests/mli/test_integrated_torch_worker.py | 10 +- tests/mli/test_service.py | 2 +- tests/on_wlm/test_colocated_model.py | 194 --- tests/temp_tests/steps_tests.py | 139 +++ tests/temp_tests/test_colocatedJobGroup.py | 95 ++ .../test_core/test_commands/test_command.py | 95 ++ .../test_commands/test_commandList.py | 99 ++ .../test_commands/test_launchCommands.py | 52 + 
tests/temp_tests/test_jobGroup.py | 110 ++ tests/temp_tests/test_launchable.py | 306 +++++ tests/temp_tests/test_settings/conftest.py | 61 + .../test_settings/test_alpsLauncher.py | 232 ++++ .../test_settings/test_batchSettings.py | 80 ++ tests/temp_tests/test_settings/test_common.py | 39 + .../temp_tests/test_settings/test_dispatch.py | 419 +++++++ .../test_settings/test_dragonLauncher.py | 116 ++ .../test_settings/test_launchSettings.py | 89 ++ .../test_settings/test_localLauncher.py | 169 +++ .../test_settings/test_lsfLauncher.py | 199 ++++ .../test_settings/test_lsfScheduler.py | 77 ++ .../test_settings/test_mpiLauncher.py | 304 +++++ .../test_settings/test_palsLauncher.py | 158 +++ .../test_settings/test_pbsScheduler.py | 88 ++ .../test_settings/test_slurmLauncher.py | 398 +++++++ .../test_settings/test_slurmScheduler.py | 136 +++ tests/test_application.py | 207 ++++ tests/test_colo_model_local.py | 314 ----- .../easy/correct/invalidtag.txt | 3 + .../easy/marked/invalidtag.txt | 3 + .../dir_test/dir_test_0/smartsim_params.txt | 2 +- .../dir_test/dir_test_1/smartsim_params.txt | 2 +- .../dir_test/dir_test_2/smartsim_params.txt | 2 +- .../dir_test/dir_test_3/smartsim_params.txt | 2 +- .../log_params/smartsim_params.txt | 8 +- .../to_copy_dir/{mock.txt => mock_1.txt} | 0 .../generator_files/to_copy_dir/mock_2.txt | 0 .../generator_files/to_copy_dir/mock_3.txt | 0 .../generator_files/to_symlink_dir/mock_1.txt | 0 .../to_symlink_dir/{mock2.txt => mock_2.txt} | 0 .../generator_files/to_symlink_dir/mock_3.txt | 0 tests/test_configs/send_data.py | 2 +- .../telemetry/colocatedmodel.json | 18 +- .../test_configs/telemetry/db_and_model.json | 18 +- .../telemetry/db_and_model_1run.json | 14 +- tests/test_configs/telemetry/ensembles.json | 8 +- .../test_configs/telemetry/serialmodels.json | 6 +- tests/test_configs/telemetry/telemetry.json | 150 +-- tests/test_ensemble.py | 626 ++++++---- tests/test_experiment.py | 1037 +++++++++++----- tests/test_file_operations.py | 786 
+++++++++++++ tests/test_generator.py | 675 ++++++----- tests/test_init.py | 25 +- tests/test_intervals.py | 87 ++ tests/test_launch_history.py | 205 ++++ tests/test_operations.py | 364 ++++++ tests/test_permutation_strategies.py | 203 ++++ tests/test_shell_launcher.py | 392 +++++++ 361 files changed, 20117 insertions(+), 13550 deletions(-) create mode 100644 smartsim/_core/arguments/shell.py create mode 100644 smartsim/_core/commands/__init__.py create mode 100644 smartsim/_core/commands/command.py create mode 100644 smartsim/_core/commands/command_list.py create mode 100644 smartsim/_core/commands/launch_commands.py delete mode 100644 smartsim/_core/control/controller.py delete mode 100644 smartsim/_core/control/controller_utils.py create mode 100644 smartsim/_core/control/interval.py delete mode 100644 smartsim/_core/control/jobmanager.py create mode 100644 smartsim/_core/control/launch_history.py rename smartsim/_core/control/{previewrenderer.py => preview_renderer.py} (92%) create mode 100644 smartsim/_core/dispatch.py delete mode 100644 smartsim/_core/entrypoints/colocated.py create mode 100644 smartsim/_core/entrypoints/file_operations.py delete mode 100644 smartsim/_core/entrypoints/redis.py rename smartsim/_core/entrypoints/{telemetrymonitor.py => telemetry_monitor.py} (100%) delete mode 100644 smartsim/_core/generation/modelwriter.py create mode 100644 smartsim/_core/generation/operations/operations.py create mode 100644 smartsim/_core/generation/operations/utils/helpers.py delete mode 100644 smartsim/_core/launcher/colocated.py rename smartsim/_core/launcher/dragon/{dragonBackend.py => dragon_backend.py} (96%) rename smartsim/_core/launcher/dragon/{dragonConnector.py => dragon_connector.py} (92%) rename smartsim/_core/launcher/dragon/{dragonLauncher.py => dragon_launcher.py} (71%) rename smartsim/_core/launcher/dragon/{dragonSockets.py => dragon_sockets.py} (97%) rename smartsim/_core/launcher/lsf/{lsfCommands.py => lsf_commands.py} (100%) rename 
smartsim/_core/launcher/lsf/{lsfLauncher.py => lsf_launcher.py} (96%) rename smartsim/_core/launcher/lsf/{lsfParser.py => lsf_parser.py} (100%) rename smartsim/_core/launcher/pbs/{pbsCommands.py => pbs_commands.py} (100%) rename smartsim/_core/launcher/pbs/{pbsLauncher.py => pbs_launcher.py} (96%) rename smartsim/_core/launcher/pbs/{pbsParser.py => pbs_parser.py} (100%) rename smartsim/_core/launcher/sge/{sgeCommands.py => sge_commands.py} (100%) rename smartsim/_core/launcher/sge/{sgeLauncher.py => sge_launcher.py} (93%) rename smartsim/_core/launcher/sge/{sgeParser.py => sge_parser.py} (100%) rename smartsim/_core/launcher/slurm/{slurmCommands.py => slurm_commands.py} (100%) rename smartsim/_core/launcher/slurm/{slurmLauncher.py => slurm_launcher.py} (97%) rename smartsim/_core/launcher/slurm/{slurmParser.py => slurm_parser.py} (100%) rename smartsim/_core/launcher/step/{alpsStep.py => alps_step.py} (91%) rename smartsim/_core/launcher/step/{dragonStep.py => dragon_step.py} (98%) rename smartsim/_core/launcher/step/{localStep.py => local_step.py} (86%) rename smartsim/_core/launcher/step/{lsfStep.py => lsf_step.py} (95%) rename smartsim/_core/launcher/step/{mpiStep.py => mpi_step.py} (86%) rename smartsim/_core/launcher/step/{pbsStep.py => pbs_step.py} (94%) rename smartsim/_core/launcher/step/{sgeStep.py => sge_step.py} (100%) rename smartsim/_core/launcher/step/{slurmStep.py => slurm_step.py} (89%) rename smartsim/_core/launcher/{stepInfo.py => step_info.py} (52%) rename smartsim/_core/launcher/{stepMapping.py => step_mapping.py} (100%) rename smartsim/_core/launcher/{taskManager.py => task_manager.py} (100%) rename smartsim/_core/launcher/util/{launcherUtil.py => launcher_util.py} (100%) create mode 100644 smartsim/_core/mli/infrastructure/control/dragon_util.py rename smartsim/_core/schemas/{dragonRequests.py => dragon_requests.py} (100%) rename smartsim/_core/schemas/{dragonResponses.py => dragon_responses.py} (96%) rename {tests => 
smartsim/_core/shell}/__init__.py (100%) create mode 100644 smartsim/_core/shell/shell_launcher.py create mode 100644 smartsim/_core/utils/launcher.py delete mode 100644 smartsim/_core/utils/redis.py create mode 100644 smartsim/builders/__init__.py create mode 100644 smartsim/builders/ensemble.py create mode 100644 smartsim/builders/utils/strategies.py rename smartsim/entity/{strategies.py => _mock.py} (53%) create mode 100644 smartsim/entity/application.py delete mode 100644 smartsim/entity/ensemble.py delete mode 100644 smartsim/entity/entityList.py delete mode 100644 smartsim/entity/model.py create mode 100644 smartsim/launchable/__init__.py create mode 100644 smartsim/launchable/base_job.py create mode 100644 smartsim/launchable/base_job_group.py create mode 100644 smartsim/launchable/colocated_job_group.py create mode 100644 smartsim/launchable/job.py create mode 100644 smartsim/launchable/job_group.py create mode 100644 smartsim/launchable/launchable.py create mode 100644 smartsim/launchable/mpmd_job.py create mode 100644 smartsim/launchable/mpmd_pair.py create mode 100644 smartsim/settings/arguments/__init__.py create mode 100644 smartsim/settings/arguments/batch/__init__.py create mode 100644 smartsim/settings/arguments/batch/lsf.py create mode 100644 smartsim/settings/arguments/batch/pbs.py create mode 100644 smartsim/settings/arguments/batch/slurm.py create mode 100644 smartsim/settings/arguments/batch_arguments.py create mode 100644 smartsim/settings/arguments/launch/__init__.py rename smartsim/settings/{alpsSettings.py => arguments/launch/alps.py} (63%) rename smartsim/settings/{dragonRunSettings.py => arguments/launch/dragon.py} (70%) create mode 100644 smartsim/settings/arguments/launch/local.py create mode 100644 smartsim/settings/arguments/launch/lsf.py create mode 100644 smartsim/settings/arguments/launch/mpi.py create mode 100644 smartsim/settings/arguments/launch/pals.py create mode 100644 smartsim/settings/arguments/launch/slurm.py create mode 
100644 smartsim/settings/arguments/launch_arguments.py delete mode 100644 smartsim/settings/base.py create mode 100644 smartsim/settings/base_settings.py create mode 100644 smartsim/settings/batch_command.py create mode 100644 smartsim/settings/batch_settings.py create mode 100644 smartsim/settings/common.py delete mode 100644 smartsim/settings/containers.py create mode 100644 smartsim/settings/launch_command.py create mode 100644 smartsim/settings/launch_settings.py delete mode 100644 smartsim/settings/lsfSettings.py delete mode 100644 smartsim/settings/mpiSettings.py delete mode 100644 smartsim/settings/palsSettings.py delete mode 100644 smartsim/settings/pbsSettings.py delete mode 100644 smartsim/settings/settings.py rename smartsim/settings/{sgeSettings.py => sge_settings.py} (100%) delete mode 100644 smartsim/settings/slurmSettings.py create mode 100644 smartsim/types.py create mode 100644 tests/_legacy/__init__.py rename tests/{ => _legacy}/backends/run_sklearn_onnx.py (99%) rename tests/{ => _legacy}/backends/run_tf.py (100%) rename tests/{ => _legacy}/backends/run_torch.py (97%) rename tests/{ => _legacy}/backends/test_cli_mini_exp.py (87%) rename tests/{ => _legacy}/backends/test_dataloader.py (90%) rename tests/{ => _legacy}/backends/test_dbmodel.py (81%) rename tests/{ => _legacy}/backends/test_dbscript.py (71%) rename tests/{ => _legacy}/backends/test_onnx.py (84%) rename tests/{ => _legacy}/backends/test_tf.py (88%) rename tests/{ => _legacy}/backends/test_torch.py (83%) rename tests/{ => _legacy}/full_wlm/test_generic_batch_launch.py (72%) rename tests/{ => _legacy}/full_wlm/test_generic_orc_launch_batch.py (51%) rename tests/{ => _legacy}/full_wlm/test_mpmd.py (87%) rename tests/{ => _legacy}/full_wlm/test_slurm_allocation.py (100%) rename tests/{ => _legacy}/full_wlm/test_symlinking.py (77%) rename tests/{ => _legacy}/full_wlm/test_wlm_helper_functions.py (100%) rename tests/{ => _legacy}/install/test_build.py (100%) rename tests/{ => 
_legacy}/install/test_buildenv.py (100%) create mode 100644 tests/_legacy/install/test_builder.py rename tests/{ => _legacy}/install/test_mlpackage.py (100%) rename tests/{ => _legacy}/install/test_package_retriever.py (100%) rename tests/{ => _legacy}/install/test_platform.py (100%) rename tests/{ => _legacy}/install/test_redisai_builder.py (100%) rename tests/{ => _legacy}/on_wlm/test_base_settings_on_wlm.py (74%) create mode 100644 tests/_legacy/on_wlm/test_colocated_model.py rename tests/{ => _legacy}/on_wlm/test_containers_wlm.py (88%) rename tests/{ => _legacy}/on_wlm/test_dragon.py (86%) rename tests/{ => _legacy}/on_wlm/test_dragon_entrypoint.py (100%) rename tests/{ => _legacy}/on_wlm/test_generic_orc_launch.py (63%) rename tests/{ => _legacy}/on_wlm/test_het_job.py (93%) rename tests/{ => _legacy}/on_wlm/test_launch_errors.py (84%) rename tests/{ => _legacy}/on_wlm/test_launch_ompi_lsf.py (87%) rename tests/{ => _legacy}/on_wlm/test_local_step.py (94%) rename tests/{ => _legacy}/on_wlm/test_preview_wlm.py (79%) rename tests/{ => _legacy}/on_wlm/test_restart.py (85%) rename tests/{ => _legacy}/on_wlm/test_simple_base_settings_on_wlm.py (80%) rename tests/{ => _legacy}/on_wlm/test_simple_entity_launch.py (85%) rename tests/{ => _legacy}/on_wlm/test_slurm_commands.py (97%) rename tests/{ => _legacy}/on_wlm/test_stop.py (90%) rename tests/{ => _legacy}/on_wlm/test_wlm_orc_config_settings.py (69%) rename tests/{ => _legacy}/test_alps_settings.py (98%) rename tests/{ => _legacy}/test_batch_settings.py (100%) rename tests/{ => _legacy}/test_cli.py (97%) rename tests/{ => _legacy}/test_collector_manager.py (93%) rename tests/{ => _legacy}/test_collector_sink.py (100%) rename tests/{ => _legacy}/test_collectors.py (94%) create mode 100644 tests/_legacy/test_colo_model_local.py rename tests/{ => _legacy}/test_colo_model_lsf.py (70%) rename tests/{ => _legacy}/test_config.py (100%) rename tests/{ => _legacy}/test_containers.py (85%) rename tests/{ => 
_legacy}/test_controller.py (90%) rename tests/{ => _legacy}/test_controller_errors.py (74%) rename tests/{ => _legacy}/test_dbnode.py (72%) rename tests/{ => _legacy}/test_dragon_client.py (97%) rename tests/{ => _legacy}/test_dragon_installer.py (94%) rename tests/{ => _legacy}/test_dragon_launcher.py (98%) rename tests/{ => _legacy}/test_dragon_run_policy.py (97%) rename tests/{ => _legacy}/test_dragon_run_request.py (95%) rename tests/{ => _legacy}/test_dragon_run_request_nowlm.py (97%) rename tests/{ => _legacy}/test_dragon_runsettings.py (100%) rename tests/{ => _legacy}/test_dragon_step.py (98%) create mode 100644 tests/_legacy/test_ensemble.py rename tests/{ => _legacy}/test_entitylist.py (100%) create mode 100644 tests/_legacy/test_experiment.py rename tests/{ => _legacy}/test_fixtures.py (70%) create mode 100644 tests/_legacy/test_generator.py rename tests/{ => _legacy}/test_helpers.py (88%) rename tests/{ => _legacy}/test_indirect.py (99%) rename tests/{ => _legacy}/test_interrupt.py (84%) rename tests/{ => _legacy}/test_launch_errors.py (72%) rename tests/{ => _legacy}/test_local_launch.py (84%) rename tests/{ => _legacy}/test_local_multi_run.py (80%) rename tests/{ => _legacy}/test_local_restart.py (81%) rename tests/{ => _legacy}/test_logs.py (99%) rename tests/{ => _legacy}/test_lsf_parser.py (92%) rename tests/{ => _legacy}/test_lsf_settings.py (99%) rename tests/{ => _legacy}/test_manifest.py (71%) rename tests/{ => _legacy}/test_model.py (75%) rename tests/{ => _legacy}/test_modelwriter.py (86%) rename tests/{ => _legacy}/test_mpi_settings.py (99%) rename tests/{ => _legacy}/test_multidb.py (55%) rename tests/{ => _legacy}/test_orc_config_settings.py (76%) rename tests/{ => _legacy}/test_orchestrator.py (56%) rename tests/{ => _legacy}/test_output_files.py (63%) rename tests/{ => _legacy}/test_pals_settings.py (99%) rename tests/{ => _legacy}/test_pbs_parser.py (88%) rename tests/{ => _legacy}/test_pbs_settings.py (100%) rename tests/{ => 
_legacy}/test_preview.py (78%) rename tests/{ => _legacy}/test_reconnect_orchestrator.py (68%) rename tests/{ => _legacy}/test_run_settings.py (90%) rename tests/{ => _legacy}/test_schema_utils.py (100%) rename tests/{ => _legacy}/test_serialize.py (87%) rename tests/{ => _legacy}/test_sge_batch_settings.py (98%) rename tests/{ => _legacy}/test_shell_util.py (100%) rename tests/{ => _legacy}/test_slurm_get_alloc.py (100%) rename tests/{ => _legacy}/test_slurm_parser.py (84%) rename tests/{ => _legacy}/test_slurm_settings.py (97%) rename tests/{ => _legacy}/test_slurm_validation.py (100%) rename tests/{ => _legacy}/test_smartredis.py (76%) rename tests/{ => _legacy}/test_step_info.py (89%) rename tests/{ => _legacy}/test_symlinking.py (75%) rename tests/{ => _legacy}/test_telemetry_monitor.py (81%) rename tests/{ => _legacy}/utils/test_network.py (100%) rename tests/{ => _legacy}/utils/test_security.py (100%) create mode 100644 tests/backends/test_ml_init.py rename tests/{dragon => dragon_wlm}/__init__.py (100%) rename tests/{dragon => dragon_wlm}/channel.py (100%) rename tests/{dragon => dragon_wlm}/conftest.py (99%) rename tests/{dragon => dragon_wlm}/feature_store.py (100%) rename tests/{dragon => dragon_wlm}/test_core_machine_learning_worker.py (98%) rename tests/{dragon => dragon_wlm}/test_device_manager.py (100%) rename tests/{dragon => dragon_wlm}/test_dragon_backend.py (99%) create mode 100644 tests/dragon_wlm/test_dragon_comm_utils.py rename tests/{dragon => dragon_wlm}/test_dragon_ddict_utils.py (100%) rename tests/{dragon => dragon_wlm}/test_environment_loader.py (100%) rename tests/{dragon => dragon_wlm}/test_error_handling.py (100%) rename tests/{dragon => dragon_wlm}/test_event_consumer.py (100%) rename tests/{dragon => dragon_wlm}/test_featurestore.py (100%) rename tests/{dragon => dragon_wlm}/test_featurestore_base.py (100%) rename tests/{dragon => dragon_wlm}/test_featurestore_integration.py (100%) rename tests/{dragon => 
dragon_wlm}/test_inference_reply.py (100%) rename tests/{dragon => dragon_wlm}/test_inference_request.py (100%) rename tests/{dragon => dragon_wlm}/test_protoclient.py (100%) rename tests/{dragon => dragon_wlm}/test_reply_building.py (100%) rename tests/{dragon => dragon_wlm}/test_request_dispatcher.py (95%) rename tests/{dragon => dragon_wlm}/test_torch_worker.py (100%) rename tests/{dragon => dragon_wlm}/test_worker_manager.py (98%) rename tests/{dragon => dragon_wlm}/utils/__init__.py (100%) rename tests/{dragon => dragon_wlm}/utils/channel.py (100%) rename tests/{dragon => dragon_wlm}/utils/msg_pump.py (100%) rename tests/{dragon => dragon_wlm}/utils/worker.py (100%) delete mode 100644 tests/on_wlm/test_colocated_model.py create mode 100644 tests/temp_tests/steps_tests.py create mode 100644 tests/temp_tests/test_colocatedJobGroup.py create mode 100644 tests/temp_tests/test_core/test_commands/test_command.py create mode 100644 tests/temp_tests/test_core/test_commands/test_commandList.py create mode 100644 tests/temp_tests/test_core/test_commands/test_launchCommands.py create mode 100644 tests/temp_tests/test_jobGroup.py create mode 100644 tests/temp_tests/test_launchable.py create mode 100644 tests/temp_tests/test_settings/conftest.py create mode 100644 tests/temp_tests/test_settings/test_alpsLauncher.py create mode 100644 tests/temp_tests/test_settings/test_batchSettings.py create mode 100644 tests/temp_tests/test_settings/test_common.py create mode 100644 tests/temp_tests/test_settings/test_dispatch.py create mode 100644 tests/temp_tests/test_settings/test_dragonLauncher.py create mode 100644 tests/temp_tests/test_settings/test_launchSettings.py create mode 100644 tests/temp_tests/test_settings/test_localLauncher.py create mode 100644 tests/temp_tests/test_settings/test_lsfLauncher.py create mode 100644 tests/temp_tests/test_settings/test_lsfScheduler.py create mode 100644 tests/temp_tests/test_settings/test_mpiLauncher.py create mode 100644 
tests/temp_tests/test_settings/test_palsLauncher.py create mode 100644 tests/temp_tests/test_settings/test_pbsScheduler.py create mode 100644 tests/temp_tests/test_settings/test_slurmLauncher.py create mode 100644 tests/temp_tests/test_settings/test_slurmScheduler.py create mode 100644 tests/test_application.py delete mode 100644 tests/test_colo_model_local.py create mode 100644 tests/test_configs/generator_files/easy/correct/invalidtag.txt create mode 100644 tests/test_configs/generator_files/easy/marked/invalidtag.txt rename tests/test_configs/generator_files/to_copy_dir/{mock.txt => mock_1.txt} (100%) create mode 100644 tests/test_configs/generator_files/to_copy_dir/mock_2.txt create mode 100644 tests/test_configs/generator_files/to_copy_dir/mock_3.txt create mode 100644 tests/test_configs/generator_files/to_symlink_dir/mock_1.txt rename tests/test_configs/generator_files/to_symlink_dir/{mock2.txt => mock_2.txt} (100%) create mode 100644 tests/test_configs/generator_files/to_symlink_dir/mock_3.txt create mode 100644 tests/test_file_operations.py create mode 100644 tests/test_intervals.py create mode 100644 tests/test_launch_history.py create mode 100644 tests/test_operations.py create mode 100644 tests/test_permutation_strategies.py create mode 100644 tests/test_shell_launcher.py diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 9b988520a4..5076870d7d 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -69,6 +69,14 @@ jobs: with: python-version: ${{ matrix.py_v }} + - name: Check Test Files are Marked + run: | + diff <(find tests -path tests/_legacy -prune -o -type f -name 'test_*.py' -print \ + | xargs grep -l 'pytestmark' \ + | sort) \ + <(find tests -path tests/_legacy -prune -o -type f -name 'test_*.py' -print \ + | sort) + - name: Install build-essentials for Ubuntu if: contains( matrix.os, 'ubuntu' ) run: | @@ -131,8 +139,9 @@ jobs: run: | make check-mypy - - name: Run Pylint - run: make 
check-lint + # TODO: Re-enable static analysis once API is firmed up + # - name: Run Pylint + # run: make check-lint # Run isort/black style check - name: Run isort diff --git a/.pylintrc b/.pylintrc index aa378d0399..34580db3b6 100644 --- a/.pylintrc +++ b/.pylintrc @@ -167,7 +167,7 @@ max-module-lines=1000 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. -single-line-class-stmt=no +single-line-class-stmt=yes # Allow the body of an if to be on the same line as the test if there is no # else. diff --git a/Makefile b/Makefile index 4e64033d63..b4ceef2194 100644 --- a/Makefile +++ b/Makefile @@ -164,22 +164,22 @@ tutorials-prod: # help: test - Run all tests .PHONY: test test: - @python -m pytest --ignore=tests/full_wlm/ --ignore=tests/dragon + @python -m pytest --ignore=tests/full_wlm/ --ignore=tests/dragon_wlm # help: test-verbose - Run all tests verbosely .PHONY: test-verbose test-verbose: - @python -m pytest -vv --ignore=tests/full_wlm/ --ignore=tests/dragon + @python -m pytest -vv --ignore=tests/full_wlm/ --ignore=tests/dragon_wlm # help: test-debug - Run all tests with debug output .PHONY: test-debug test-debug: - @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ --ignore=tests/dragon + @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ --ignore=tests/dragon_wlm # help: test-cov - Run all tests with coverage .PHONY: test-cov test-cov: - @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ --ignore=tests/dragon + @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ --ignore=tests/dragon_wlm # help: test-full - Run all WLM tests with Python coverage (full test suite) @@ -196,4 +196,4 @@ test-wlm: # help: test-dragon - Run dragon-specific tests .PHONY: test-dragon test-dragon: - @dragon pytest tests/dragon + @dragon pytest tests/dragon_wlm diff --git 
a/conftest.py b/conftest.py index 54a47f9e23..895fcc9adb 100644 --- a/conftest.py +++ b/conftest.py @@ -27,48 +27,50 @@ from __future__ import annotations import asyncio -from collections import defaultdict -from dataclasses import dataclass import json import os import pathlib import shutil -import subprocess import signal import socket +import subprocess import sys import tempfile import time import typing as t import uuid import warnings +from glob import glob +from os import path as osp +from collections import defaultdict +from dataclasses import dataclass from subprocess import run -import time import psutil import pytest import smartsim from smartsim import Experiment -from smartsim._core.launcher.dragon.dragonConnector import DragonConnector -from smartsim._core.launcher.dragon.dragonLauncher import DragonLauncher from smartsim._core.config import CONFIG from smartsim._core.config.config import Config +from smartsim._core.launcher.dragon.dragon_connector import DragonConnector +from smartsim._core.launcher.dragon.dragon_launcher import DragonLauncher +from smartsim._core.generation.operations.operations import ConfigureOperation, CopyOperation, SymlinkOperation +from smartsim._core.generation.generator import Generator from smartsim._core.utils.telemetry.telemetry import JobEntity -from smartsim.database import Orchestrator -from smartsim.entity import Model +from smartsim.database import FeatureStore +from smartsim.entity import Application from smartsim.error import SSConfigError, SSInternalError from smartsim.log import get_logger -from smartsim.settings import ( - AprunSettings, - DragonRunSettings, - JsrunSettings, - MpiexecSettings, - MpirunSettings, - PalsMpiexecSettings, - RunSettings, - SrunSettings, -) +# Mock imports +class AprunSettings: pass +class DragonRunSettings: pass +class JsrunSettings: pass +class MpiexecSettings: pass +class MpirunSettings: pass +class PalsMpiexecSettings: pass +class RunSettings: pass +class SrunSettings: pass logger 
= get_logger(__name__) @@ -142,7 +144,7 @@ def pytest_sessionstart( time.sleep(0.1) if CONFIG.dragon_server_path is None: - dragon_server_path = os.path.join(test_output_root, "dragon_server") + dragon_server_path = os.path.join(test_output_root, "dragon_server") os.makedirs(dragon_server_path) os.environ["SMARTSIM_DRAGON_SERVER_PATH"] = dragon_server_path @@ -184,7 +186,7 @@ def build_mpi_app() -> t.Optional[pathlib.Path]: if cc is None: return None - path_to_src = pathlib.Path(FileUtils().get_test_conf_path("mpi")) + path_to_src = pathlib.Path(FileUtils().get_test_conf_path("mpi")) path_to_out = pathlib.Path(test_output_root) / "apps" / "mpi_app" os.makedirs(path_to_out.parent, exist_ok=True) cmd = [cc, str(path_to_src / "mpi_hello.c"), "-o", str(path_to_out)] @@ -195,11 +197,12 @@ def build_mpi_app() -> t.Optional[pathlib.Path]: else: return None + @pytest.fixture(scope="session") def mpi_app_path() -> t.Optional[pathlib.Path]: """Return path to MPI app if it was built - return None if it could not or will not be built + return None if it could not or will not be built """ if not CONFIG.test_mpi: return None @@ -467,13 +470,65 @@ def check_output_dir() -> None: @pytest.fixture -def dbutils() -> t.Type[DBUtils]: - return DBUtils +def fsutils() -> t.Type[FSUtils]: + return FSUtils + +@pytest.fixture +def files(fileutils): + path_to_files = fileutils.get_test_conf_path( + osp.join("generator_files", "easy", "correct/") + ) + list_of_files_strs = glob(path_to_files + "/*") + yield [pathlib.Path(str_path) for str_path in list_of_files_strs] + + +@pytest.fixture +def directory(fileutils): + directory = fileutils.get_test_conf_path( + osp.join("generator_files", "easy", "correct/") + ) + yield [pathlib.Path(directory)] + + +@pytest.fixture(params=["files", "directory"]) +def source(request): + yield request.getfixturevalue(request.param) + + +@pytest.fixture +def mock_src(test_dir: str): + """Fixture to create a mock source path.""" + return pathlib.Path(test_dir) / 
pathlib.Path("mock_src") + + +@pytest.fixture +def mock_dest(): + """Fixture to create a mock destination path.""" + return pathlib.Path("mock_dest") + + +@pytest.fixture +def copy_operation(mock_src: pathlib.Path, mock_dest: pathlib.Path): + """Fixture to create a CopyOperation object.""" + return CopyOperation(src=mock_src, dest=mock_dest) + + +@pytest.fixture +def symlink_operation(mock_src: pathlib.Path, mock_dest: pathlib.Path): + """Fixture to create a CopyOperation object.""" + return SymlinkOperation(src=mock_src, dest=mock_dest) -class DBUtils: +@pytest.fixture +def configure_operation(mock_src: pathlib.Path, mock_dest: pathlib.Path): + """Fixture to create a Configure object.""" + return ConfigureOperation( + src=mock_src, dest=mock_dest, file_parameters={"FOO": "BAR"} + ) + +class FSUtils: @staticmethod - def get_db_configs() -> t.Dict[str, t.Any]: + def get_fs_configs() -> t.Dict[str, t.Any]: config_settings = { "enable_checkpoints": 1, "set_max_memory": "3gb", @@ -487,7 +542,7 @@ def get_db_configs() -> t.Dict[str, t.Any]: return config_settings @staticmethod - def get_smartsim_error_db_configs() -> t.Dict[str, t.Any]: + def get_smartsim_error_fs_configs() -> t.Dict[str, t.Any]: bad_configs = { "save": [ "-1", # frequency must be positive @@ -514,7 +569,7 @@ def get_smartsim_error_db_configs() -> t.Dict[str, t.Any]: return bad_configs @staticmethod - def get_type_error_db_configs() -> t.Dict[t.Union[int, str], t.Any]: + def get_type_error_fs_configs() -> t.Dict[t.Union[int, str], t.Any]: bad_configs: t.Dict[t.Union[int, str], t.Any] = { "save": [2, True, ["2"]], # frequency must be specified as a string "maxmemory": [99, True, ["99"]], # memory form must be a string @@ -535,15 +590,15 @@ def get_type_error_db_configs() -> t.Dict[t.Union[int, str], t.Any]: @staticmethod def get_config_edit_method( - db: Orchestrator, config_setting: str + fs: FeatureStore, config_setting: str ) -> t.Optional[t.Callable[..., None]]: - """Get a db configuration file edit 
method from a str""" + """Get a fs configuration file edit method from a str""" config_edit_methods: t.Dict[str, t.Callable[..., None]] = { - "enable_checkpoints": db.enable_checkpoints, - "set_max_memory": db.set_max_memory, - "set_eviction_strategy": db.set_eviction_strategy, - "set_max_clients": db.set_max_clients, - "set_max_message_size": db.set_max_message_size, + "enable_checkpoints": fs.enable_checkpoints, + "set_max_memory": fs.set_max_memory, + "set_eviction_strategy": fs.set_eviction_strategy, + "set_max_clients": fs.set_max_clients, + "set_max_message_size": fs.set_max_message_size, } return config_edit_methods.get(config_setting, None) @@ -649,21 +704,21 @@ class ColoUtils: @staticmethod def setup_test_colo( fileutils: t.Type[FileUtils], - db_type: str, + fs_type: str, exp: Experiment, application_file: str, - db_args: t.Dict[str, t.Any], + fs_args: t.Dict[str, t.Any], colo_settings: t.Optional[RunSettings] = None, - colo_model_name: str = "colocated_model", + colo_application_name: str = "colocated_application", port: t.Optional[int] = None, on_wlm: bool = False, - ) -> Model: - """Setup database needed for the colo pinning tests""" + ) -> Application: + """Setup feature store needed for the colo pinning tests""" # get test setup sr_test_script = fileutils.get_test_conf_path(application_file) - # Create an app with a colo_db which uses 1 db_cpu + # Create an app with a colo_fs which uses 1 fs_cpu if colo_settings is None: colo_settings = exp.create_run_settings( exe=sys.executable, exe_args=[sr_test_script] @@ -672,31 +727,31 @@ def setup_test_colo( colo_settings.set_tasks(1) colo_settings.set_nodes(1) - colo_model = exp.create_model(colo_model_name, colo_settings) + colo_application = exp.create_application(colo_application_name, colo_settings) - if db_type in ["tcp", "deprecated"]: - db_args["port"] = port if port is not None else _find_free_port(test_ports) - db_args["ifname"] = "lo" - if db_type == "uds" and colo_model_name is not None: + if 
fs_type in ["tcp", "deprecated"]: + fs_args["port"] = port if port is not None else _find_free_port(test_ports) + fs_args["ifname"] = "lo" + if fs_type == "uds" and colo_application_name is not None: tmp_dir = tempfile.gettempdir() socket_suffix = str(uuid.uuid4())[:7] - socket_name = f"{colo_model_name}_{socket_suffix}.socket" - db_args["unix_socket"] = os.path.join(tmp_dir, socket_name) + socket_name = f"{colo_application_name}_{socket_suffix}.socket" + fs_args["unix_socket"] = os.path.join(tmp_dir, socket_name) colocate_fun: t.Dict[str, t.Callable[..., None]] = { - "tcp": colo_model.colocate_db_tcp, - "deprecated": colo_model.colocate_db, - "uds": colo_model.colocate_db_uds, + "tcp": colo_application.colocate_fs_tcp, + "deprecated": colo_application.colocate_fs, + "uds": colo_application.colocate_fs_uds, } with warnings.catch_warnings(): - if db_type == "deprecated": - message = "`colocate_db` has been deprecated" + if fs_type == "deprecated": + message = "`colocate_fs` has been deprecated" warnings.filterwarnings("ignore", message=message) - colocate_fun[db_type](**db_args) - # assert model will launch with colocated db - assert colo_model.colocated - # Check to make sure that limit_db_cpus made it into the colo settings - return colo_model + colocate_fun[fs_type](**fs_args) + # assert application will launch with colocated fs + assert colo_application.colocated + # Check to make sure that limit_fs_cpus made it into the colo settings + return colo_application @pytest.fixture(scope="function") @@ -708,7 +763,9 @@ def global_dragon_teardown() -> None: """ if test_launcher != "dragon" or CONFIG.dragon_server_path is None: return - logger.debug(f"Tearing down Dragon infrastructure, server path: {CONFIG.dragon_server_path}") + logger.debug( + f"Tearing down Dragon infrastructure, server path: {CONFIG.dragon_server_path}" + ) dragon_connector = DragonConnector() dragon_connector.ensure_connected() dragon_connector.cleanup() @@ -744,7 +801,7 @@ def mock_sink() -> 
t.Type[MockSink]: @pytest.fixture def mock_con() -> t.Callable[[int, int], t.Iterable[t.Any]]: - """Generates mock db connection telemetry""" + """Generates mock fs connection telemetry""" def _mock_con(min: int = 1, max: int = 254) -> t.Iterable[t.Any]: for i in range(min, max): @@ -758,7 +815,7 @@ def _mock_con(min: int = 1, max: int = 254) -> t.Iterable[t.Any]: @pytest.fixture def mock_mem() -> t.Callable[[int, int], t.Iterable[t.Any]]: - """Generates mock db memory usage telemetry""" + """Generates mock fs memory usage telemetry""" def _mock_mem(min: int = 1, max: int = 1000) -> t.Iterable[t.Any]: for i in range(min, max): @@ -875,9 +932,13 @@ def num_calls(self) -> int: def details(self) -> t.List[t.Tuple[t.Tuple[t.Any, ...], t.Dict[str, t.Any]]]: return self._details -## Reuse database across tests -database_registry: t.DefaultDict[str, t.Optional[Orchestrator]] = defaultdict(lambda: None) +## Reuse feature store across tests + +feature_store_registry: t.DefaultDict[str, t.Optional[FeatureStore]] = defaultdict( + lambda: None +) + @pytest.fixture(scope="function") def local_experiment(test_dir: str) -> smartsim.Experiment: @@ -885,45 +946,48 @@ def local_experiment(test_dir: str) -> smartsim.Experiment: name = pathlib.Path(test_dir).stem return smartsim.Experiment(name, exp_path=test_dir, launcher="local") + @pytest.fixture(scope="function") def wlm_experiment(test_dir: str, wlmutils: WLMUtils) -> smartsim.Experiment: """Create a default experiment that uses the requested launcher""" name = pathlib.Path(test_dir).stem return smartsim.Experiment( - name, - exp_path=test_dir, - launcher=wlmutils.get_test_launcher() + name, exp_path=test_dir, launcher=wlmutils.get_test_launcher() ) -def _cleanup_db(name: str) -> None: - global database_registry - db = database_registry[name] - if db and db.is_active(): + +def _cleanup_fs(name: str) -> None: + global feature_store_registry + fs = feature_store_registry[name] + if fs and fs.is_active(): exp = Experiment("cleanup") 
try: - db = exp.reconnect_orchestrator(db.checkpoint_file) - exp.stop(db) + fs = exp.reconnect_feature_store(fs.checkpoint_file) + exp.stop(fs) except: pass + @dataclass class DBConfiguration: name: str launcher: str num_nodes: int - interface: t.Union[str,t.List[str]] + interface: t.Union[str, t.List[str]] hostlist: t.Optional[t.List[str]] port: int + @dataclass -class PrepareDatabaseOutput: - orchestrator: t.Optional[Orchestrator] # The actual orchestrator object - new_db: bool # True if a new database was created when calling prepare_db +class PrepareFeatureStoreOutput: + featurestore: t.Optional[FeatureStore] # The actual feature store object + new_fs: bool # True if a new feature store was created when calling prepare_fs + -# Reuse databases +# Reuse feature stores @pytest.fixture(scope="session") -def local_db() -> t.Generator[DBConfiguration, None, None]: - name = "local_db_fixture" +def local_fs() -> t.Generator[DBConfiguration, None, None]: + name = "local_fs_fixture" config = DBConfiguration( name, "local", @@ -933,30 +997,32 @@ def local_db() -> t.Generator[DBConfiguration, None, None]: _find_free_port(tuple(reversed(test_ports))), ) yield config - _cleanup_db(name) + _cleanup_fs(name) + + @pytest.fixture(scope="session") -def single_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]: +def single_fs(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]: hostlist = wlmutils.get_test_hostlist() hostlist = hostlist[-1:] if hostlist is not None else None - name = "single_db_fixture" + name = "single_fx_fixture" config = DBConfiguration( name, wlmutils.get_test_launcher(), 1, wlmutils.get_test_interface(), hostlist, - _find_free_port(tuple(reversed(test_ports))) + _find_free_port(tuple(reversed(test_ports))), ) yield config - _cleanup_db(name) + _cleanup_fs(name) @pytest.fixture(scope="session") -def clustered_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]: +def clustered_fs(wlmutils: WLMUtils) -> 
t.Generator[DBConfiguration, None, None]: hostlist = wlmutils.get_test_hostlist() hostlist = hostlist[-4:-1] if hostlist is not None else None - name = "clustered_db_fixture" + name = "clustered_fs_fixture" config = DBConfiguration( name, wlmutils.get_test_launcher(), @@ -966,14 +1032,12 @@ def clustered_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None] _find_free_port(tuple(reversed(test_ports))), ) yield config - _cleanup_db(name) + _cleanup_fs(name) @pytest.fixture -def register_new_db() -> t.Callable[[DBConfiguration], Orchestrator]: - def _register_new_db( - config: DBConfiguration - ) -> Orchestrator: +def register_new_fs() -> t.Callable[[DBConfiguration], FeatureStore]: + def _register_new_fs(config: DBConfiguration) -> FeatureStore: exp_path = pathlib.Path(test_output_root, config.name) exp_path.mkdir(exist_ok=True) exp = Experiment( @@ -981,44 +1045,40 @@ def _register_new_db( exp_path=str(exp_path), launcher=config.launcher, ) - orc = exp.create_database( + feature_store = exp.create_feature_store( port=config.port, batch=False, interface=config.interface, hosts=config.hostlist, - db_nodes=config.num_nodes + fs_nodes=config.num_nodes, ) - exp.generate(orc, overwrite=True) - exp.start(orc) - global database_registry - database_registry[config.name] = orc - return orc - return _register_new_db + exp.generate(feature_store, overwrite=True) + exp.start(feature_store) + global feature_store_registry + feature_store_registry[config.name] = feature_store + return feature_store + + return _register_new_fs @pytest.fixture(scope="function") -def prepare_db( - register_new_db: t.Callable[ - [DBConfiguration], - Orchestrator - ] -) -> t.Callable[ - [DBConfiguration], - PrepareDatabaseOutput -]: - def _prepare_db(db_config: DBConfiguration) -> PrepareDatabaseOutput: - global database_registry - db = database_registry[db_config.name] - - new_db = False - db_up = False - - if db: - db_up = db.is_active() - - if not db_up or db is None: - db = 
register_new_db(db_config) - new_db = True - - return PrepareDatabaseOutput(db, new_db) - return _prepare_db +def prepare_fs( + register_new_fs: t.Callable[[DBConfiguration], FeatureStore] +) -> t.Callable[[DBConfiguration], PrepareFeatureStoreOutput]: + def _prepare_fs(fs_config: DBConfiguration) -> PrepareFeatureStoreOutput: + global feature_store_registry + fs = feature_store_registry[fs_config.name] + + new_fs = False + fs_up = False + + if fs: + fs_up = fs.is_active() + + if not fs_up or fs is None: + fs = register_new_fs(fs_config) + new_fs = True + + return PrepareFeatureStoreOutput(fs, new_fs) + + return _prepare_fs diff --git a/doc/changelog.md b/doc/changelog.md index bca9209f7a..752957bfdc 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,6 +13,7 @@ Jump to: Description +- Merge core refactor into MLI feature branch - Implement asynchronous notifications for shared data - Quick bug fix in _validate - Add helper methods to MLI classes diff --git a/doc/installation_instructions/basic.rst b/doc/installation_instructions/basic.rst index 73fbceb253..a5db285ca8 100644 --- a/doc/installation_instructions/basic.rst +++ b/doc/installation_instructions/basic.rst @@ -45,14 +45,14 @@ ML Library Support ================== We currently support both Nvidia and AMD GPUs when using RedisAI for GPU inference. The support -for these GPUs often depends on the version of the CUDA or ROCm stack that is availble on your -machine. In _most_ cases, the versions backwards compatible. If you encounter problems, please +for these GPUs often depends on the version of the CUDA or ROCm stack that is available on your +machine. In _most_ cases, the versions are backwards compatible. If you encounter problems, please contact us and we can build the backend libraries for your desired version of CUDA and ROCm. CPU backends are provided for Apple (both Intel and Apple Silicon) and Linux (x86_64). 
Be sure to reference the table below to find which versions of the ML libraries are supported for -your particular platform. Additional, see :ref:`installation notes ` for helpful +your particular platform. Additionally, see :ref:`installation notes ` for helpful information regarding various system types before installation. Linux @@ -175,7 +175,7 @@ MacOSX .. note:: - Users have succesfully run SmartSim on Windows using Windows Subsystem for Linux + Users have successfully run SmartSim on Windows using Windows Subsystem for Linux with Nvidia support. Generally, users should follow the Linux instructions here, however we make no guarantee or offer of support. @@ -387,7 +387,7 @@ source remains at the site of the clone instead of in site-packages. pip install -e ".[dev]" # for zsh users Use the now installed ``smart`` cli to install the machine learning runtimes and -dragon. Referring to "Step 2: Build SmartSim above". +dragon. Referring to "Step 2: Build SmartSim" above. Build the SmartRedis library ============================ diff --git a/doc/tutorials/doc_examples/experiment_doc_examples/exp.py b/doc/tutorials/doc_examples/experiment_doc_examples/exp.py index b5374e7bd0..b4b4e01003 100644 --- a/doc/tutorials/doc_examples/experiment_doc_examples/exp.py +++ b/doc/tutorials/doc_examples/experiment_doc_examples/exp.py @@ -1,5 +1,5 @@ from smartsim import Experiment -from smartsim._core.control.previewrenderer import Verbosity +from smartsim._core.control.preview_renderer import Verbosity from smartsim.log import get_logger # Initialize an Experiment diff --git a/doc/tutorials/ml_inference/Inference-in-SmartSim.ipynb b/doc/tutorials/ml_inference/Inference-in-SmartSim.ipynb index 2b5f0a3a59..4afdc38955 100644 --- a/doc/tutorials/ml_inference/Inference-in-SmartSim.ipynb +++ b/doc/tutorials/ml_inference/Inference-in-SmartSim.ipynb @@ -44,8 +44,9 @@ ], "source": [ "## Installing the ML backends\n", - "from smartsim._core.utils.helpers import 
installed_redisai_backends\n", - "print(installed_redisai_backends())\n" + "# from smartsim._core.utils.helpers import installed_redisai_backends\n", + "#print(installed_redisai_backends())\n", + "# TODO: replace deprecated installed_redisai_backends" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 61e17891b3..bf721b0c99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,13 @@ force-exclude = ''' [tool.pytest.ini_options] log_cli = true log_cli_level = "debug" +testpaths = [ + "tests" +] +norecursedirs = [ + "tests/test_configs", + "tests/_legacy", +] markers = [ "group_a: fast test subset a", "group_b: fast test subset b", @@ -134,6 +141,31 @@ module = [ "smartsim.ml.torch.*", # must solve/ignore inheritance issues "watchdog", "dragon.*", + + # Ignore these modules while the core refactor is on going. Uncomment as + # needed for gradual typing + # + # FIXME: DO NOT MERGE THIS INTO DEVELOP BRANCH UNLESS THESE ARE PASSING OR + # REMOVED!! + "smartsim._core._cli.*", + "smartsim._core.control.manifest", + "smartsim._core.entrypoints.dragon_client", + "smartsim._core.launcher.launcher", + "smartsim._core.launcher.local.*", + "smartsim._core.launcher.lsf.*", + "smartsim._core.launcher.pbs.*", + "smartsim._core.launcher.sge.*", + "smartsim._core.launcher.slurm.*", + "smartsim._core.launcher.step.*", + "smartsim._core.launcher.step_info", + "smartsim._core.launcher.step_mapping", + "smartsim._core.launcher.task_manager", + "smartsim._core.utils.serialize", + "smartsim._core.utils.telemetry.*", + "smartsim.database.*", + "smartsim.settings.sge_settings", + "smartsim._core.control.controller_utils", + "smartsim.entity.dbnode", ] ignore_missing_imports = true ignore_errors = true diff --git a/smartsim/_core/__init__.py b/smartsim/_core/__init__.py index 4900787704..ee8d3cc96a 100644 --- a/smartsim/_core/__init__.py +++ b/smartsim/_core/__init__.py @@ -24,7 +24,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS 
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from .control import Controller, Manifest, previewrenderer +from .control import Manifest, preview_renderer from .generation import Generator -__all__ = ["Controller", "Manifest", "Generator", "previewrenderer"] +__all__ = ["Manifest", "Generator", "preview_renderer"] diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index ec9ef4aa29..58ef31ab8a 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -28,6 +28,7 @@ import importlib.metadata import operator import os +import platform import re import shutil import textwrap @@ -43,9 +44,10 @@ display_post_install_logs, install_dragon, ) -from smartsim._core._cli.utils import SMART_LOGGER_FORMAT +from smartsim._core._cli.utils import SMART_LOGGER_FORMAT, pip from smartsim._core._install import builder -from smartsim._core._install.buildenv import BuildEnv, DbEngine, Version_, Versioner +from smartsim._core._install.buildenv import BuildEnv, SetupError, Version_, Versioner +from smartsim._core._install.builder import BuildError from smartsim._core._install.mlpackages import ( DEFAULT_MLPACKAGE_PATH, DEFAULT_MLPACKAGES, @@ -60,7 +62,6 @@ ) from smartsim._core._install.redisaiBuilder import RedisAIBuilder from smartsim._core.config import CONFIG -from smartsim._core.utils.helpers import installed_redisai_backends from smartsim.error import SSConfigError from smartsim.log import get_logger @@ -70,79 +71,6 @@ # may be installed into a different directory. -def check_backends_install() -> bool: - """Checks if backends have already been installed. - Logs details on how to proceed forward - if the SMARTSIM_RAI_LIB environment variable is set or if - backends have already been installed. - """ - rai_path = os.environ.get("SMARTSIM_RAI_LIB", "") - installed = installed_redisai_backends() - msg = "" - - if rai_path and installed: - msg = ( - f"There is no need to build. 
backends are already built and " - f"specified in the environment at 'SMARTSIM_RAI_LIB': {CONFIG.redisai}" - ) - elif rai_path and not installed: - msg = ( - "Before running 'smart build', unset your SMARTSIM_RAI_LIB environment " - "variable with 'unset SMARTSIM_RAI_LIB'." - ) - elif not rai_path and installed: - msg = ( - "If you wish to re-run `smart build`, you must first run `smart clean`." - " The following backend(s) must be removed: " + ", ".join(installed) - ) - - if msg: - logger.error(msg) - - return not bool(msg) - - -def build_database( - build_env: BuildEnv, versions: Versioner, keydb: bool, verbose: bool -) -> None: - # check database installation - database_name = "KeyDB" if keydb else "Redis" - database_builder = builder.DatabaseBuilder( - build_env(), - jobs=build_env.JOBS, - malloc=build_env.MALLOC, - verbose=verbose, - ) - if not database_builder.is_built: - logger.info( - f"Building {database_name} version {versions.REDIS} " - f"from {versions.REDIS_URL}" - ) - database_builder.build_from_git( - versions.REDIS_URL, branch=versions.REDIS_BRANCH - ) - database_builder.cleanup() - logger.info(f"{database_name} build complete!") - else: - logger.warning( - f"{database_name} was previously built, run 'smart clobber' to rebuild" - ) - - -def build_redis_ai( - platform: Platform, - mlpackages: MLPackageCollection, - build_env: BuildEnv, - verbose: bool, -) -> None: - logger.info("Building RedisAI and backends...") - rai_builder = RedisAIBuilder( - platform, mlpackages, build_env, CONFIG.build_path, verbose - ) - rai_builder.build() - rai_builder.cleanup_build() - - def parse_requirement( requirement: str, ) -> t.Tuple[str, t.Optional[str], t.Callable[[Version_], bool]]: @@ -228,19 +156,6 @@ def _format_incompatible_python_env_message( """) -def _configure_keydb_build(versions: Versioner) -> None: - """Configure the redis versions to be used during the build operation""" - versions.REDIS = Version_("6.2.0") - versions.REDIS_URL = 
"https://github.com/EQ-Alpha/KeyDB.git" - versions.REDIS_BRANCH = "v6.2.0" - - CONFIG.conf_path = Path(CONFIG.core_path, "config", "keydb.conf") - if not CONFIG.conf_path.resolve().is_file(): - raise SSConfigError( - "Database configuration file at SMARTSIM_REDIS_CONF could not be found" - ) - - # pylint: disable-next=too-many-statements def execute( args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, / @@ -248,16 +163,11 @@ def execute( # Unpack various arguments verbose = args.v - keydb = args.keydb - device = Device.from_str(args.device.lower()) + device = Device(args.device.lower()) is_dragon_requested = args.dragon dragon_repo = args.dragon_repo dragon_version = args.dragon_version - if Path(CONFIG.build_path).exists(): - logger.warning(f"Build path already exists, removing: {CONFIG.build_path}") - shutil.rmtree(CONFIG.build_path) - # The user should never have to specify the OS and Architecture current_platform = Platform( OperatingSystem.autodetect(), Architecture.autodetect(), device @@ -289,13 +199,9 @@ def execute( env_vars = list(env.keys()) print(tabulate(env, headers=env_vars, tablefmt="github"), "\n") - if keydb: - _configure_keydb_build(versions) - if verbose: - db_name: DbEngine = "KEYDB" if keydb else "REDIS" logger.info("Version Information:") - vers = versions.as_dict(db_name=db_name) + vers = versions.as_dict() version_names = list(vers.keys()) print(tabulate(vers, headers=version_names, tablefmt="github"), "\n") @@ -324,17 +230,7 @@ def execute( else: logger.warning("Dragon installation failed") - # REDIS/KeyDB - build_database(build_env, versions, keydb, verbose) - - if (CONFIG.lib_path / "redisai.so").exists(): - logger.warning("RedisAI was previously built, run 'smart clean' to rebuild") - elif not args.skip_backends: - build_redis_ai(current_platform, mlpackages, build_env, verbose) - else: - logger.info("Skipping compilation of RedisAI and backends") - - backends = installed_redisai_backends() + backends = [] 
backends_str = ", ".join(s.capitalize() for s in backends) if backends else "No" logger.info(f"{backends_str} backend(s) available") @@ -423,9 +319,3 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: type=str, help="Path to directory with JSON files describing platform and packages", ) - parser.add_argument( - "--keydb", - action="store_true", - default=False, - help="Build KeyDB instead of Redis", - ) diff --git a/smartsim/_core/_cli/cli.py b/smartsim/_core/_cli/cli.py index 3d5c6e066e..71d0c3a398 100644 --- a/smartsim/_core/_cli/cli.py +++ b/smartsim/_core/_cli/cli.py @@ -108,7 +108,7 @@ def default_cli() -> SmartCli: menu = [ MenuItemConfig( "build", - "Build SmartSim dependencies (Redis, RedisAI, Dragon, ML runtimes)", + "Build SmartSim dependencies (Dragon, ML runtimes)", build_execute, build_parser, ), diff --git a/smartsim/_core/_cli/dbcli.py b/smartsim/_core/_cli/dbcli.py index 733c2fe4d4..b06e5984f6 100644 --- a/smartsim/_core/_cli/dbcli.py +++ b/smartsim/_core/_cli/dbcli.py @@ -28,14 +28,14 @@ import os import typing as t -from smartsim._core._cli.utils import get_db_path +from smartsim._core._cli.utils import get_fs_path def execute( _args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, / ) -> int: - if db_path := get_db_path(): - print(db_path) + if fs_path := get_fs_path(): + print(fs_path) return os.EX_OK - print("Database (Redis or KeyDB) dependencies not found") + print("Feature store(Redis or KeyDB) dependencies not found") return os.EX_SOFTWARE diff --git a/smartsim/_core/_cli/info.py b/smartsim/_core/_cli/info.py index c08fcb1a35..7fa094fbdc 100644 --- a/smartsim/_core/_cli/info.py +++ b/smartsim/_core/_cli/info.py @@ -6,7 +6,6 @@ from tabulate import tabulate -import smartsim._core._cli.utils as _utils import smartsim._core.utils.helpers as _helpers from smartsim._core._install.buildenv import BuildEnv as _BuildEnv @@ -21,7 +20,6 @@ def execute( tabulate( [ ["SmartSim", _fmt_py_pkg_version("smartsim")], - 
["SmartRedis", _fmt_py_pkg_version("smartredis")], ], headers=["Name", "Version"], tablefmt="fancy_outline", @@ -29,42 +27,31 @@ def execute( end="\n\n", ) - print("Orchestrator Configuration:") - db_path = _utils.get_db_path() - db_table = [["Installed", _fmt_installed_db(db_path)]] - if db_path: - db_table.append(["Location", str(db_path)]) - print(tabulate(db_table, tablefmt="fancy_outline"), end="\n\n") + print("Dragon Installation:") + # TODO: Fix hardcoded dragon version + dragon_version = "0.10" - print("Redis AI Configuration:") - rai_path = _helpers.redis_install_base().parent / "redisai.so" - rai_table = [["Status", _fmt_installed_redis_ai(rai_path)]] - if rai_path.is_file(): - rai_table.append(["Location", str(rai_path)]) - print(tabulate(rai_table, tablefmt="fancy_outline"), end="\n\n") + fs_table = [["Version", str(dragon_version)]] + print(tabulate(fs_table, tablefmt="fancy_outline"), end="\n\n") - print("Machine Learning Backends:") - backends = _helpers.installed_redisai_backends() + print("Machine Learning Packages:") print( tabulate( [ [ "Tensorflow", - _utils.color_bool("tensorflow" in backends), _fmt_py_pkg_version("tensorflow"), ], [ "Torch", - _utils.color_bool("torch" in backends), _fmt_py_pkg_version("torch"), ], [ "ONNX", - _utils.color_bool("onnxruntime" in backends), _fmt_py_pkg_version("onnx"), ], ], - headers=["Name", "Backend Available", "Python Package"], + headers=["Name", "Python Package"], tablefmt="fancy_outline", ), end="\n\n", @@ -72,17 +59,11 @@ def execute( return os.EX_OK -def _fmt_installed_db(db_path: t.Optional[pathlib.Path]) -> str: - if db_path is None: +def _fmt_installed_fs(fs_path: t.Optional[pathlib.Path]) -> str: + if fs_path is None: return _MISSING_DEP - db_name, _ = db_path.name.split("-", 1) - return _helpers.colorize(db_name.upper(), "green") - - -def _fmt_installed_redis_ai(rai_path: pathlib.Path) -> str: - if not rai_path.is_file(): - return _MISSING_DEP - return _helpers.colorize("Installed", "green") + 
fs_name, _ = fs_path.name.split("-", 1) + return _helpers.colorize(fs_name.upper(), "green") def _fmt_py_pkg_version(pkg_name: str) -> str: diff --git a/smartsim/_core/_cli/scripts/dragon_install.py b/smartsim/_core/_cli/scripts/dragon_install.py index 3a9358390b..7a7d75f1d2 100644 --- a/smartsim/_core/_cli/scripts/dragon_install.py +++ b/smartsim/_core/_cli/scripts/dragon_install.py @@ -22,7 +22,7 @@ logger = get_logger(__name__) DEFAULT_DRAGON_REPO = "DragonHPC/dragon" -DEFAULT_DRAGON_VERSION = "0.9" +DEFAULT_DRAGON_VERSION = "0.10" DEFAULT_DRAGON_VERSION_TAG = f"v{DEFAULT_DRAGON_VERSION}" _GH_TOKEN = "SMARTSIM_DRAGON_TOKEN" diff --git a/smartsim/_core/_cli/utils.py b/smartsim/_core/_cli/utils.py index 9c9b46cab5..ff6a2d2573 100644 --- a/smartsim/_core/_cli/utils.py +++ b/smartsim/_core/_cli/utils.py @@ -91,38 +91,15 @@ def clean(core_path: Path, _all: bool = False) -> int: lib_path = core_path / "lib" if lib_path.is_dir(): - # remove RedisAI - rai_path = lib_path / "redisai.so" - if rai_path.is_file(): - rai_path.unlink() - logger.info("Successfully removed existing RedisAI installation") - backend_path = lib_path / "backends" if backend_path.is_dir(): shutil.rmtree(backend_path, ignore_errors=True) logger.info("Successfully removed ML runtimes") - bin_path = core_path / "bin" - if bin_path.is_dir() and _all: - files_to_remove = ["redis-server", "redis-cli", "keydb-server", "keydb-cli"] - removed = False - for _file in files_to_remove: - file_path = bin_path.joinpath(_file) - - if file_path.is_file(): - removed = True - file_path.unlink() - if removed: - logger.info("Successfully removed SmartSim database installation") - return os.EX_OK -def get_db_path() -> t.Optional[Path]: - bin_path = get_install_path() / "_core" / "bin" - for option in bin_path.iterdir(): - if option.name in ("redis-cli", "keydb-cli"): - return option +def get_fs_path() -> t.Optional[Path]: return None diff --git a/smartsim/_core/_cli/validate.py b/smartsim/_core/_cli/validate.py index 
b7905b773b..0e21e01ac6 100644 --- a/smartsim/_core/_cli/validate.py +++ b/smartsim/_core/_cli/validate.py @@ -33,14 +33,8 @@ import typing as t from types import TracebackType -import numpy as np -from smartredis import Client - -from smartsim import Experiment from smartsim._core._cli.utils import SMART_LOGGER_FORMAT -from smartsim._core.types import Device -from smartsim._core.utils.helpers import installed_redisai_backends -from smartsim._core.utils.network import find_free_port +from smartsim._core._install.platform import Device from smartsim.log import get_logger logger = get_logger("Smart", fmt=SMART_LOGGER_FORMAT) @@ -53,8 +47,6 @@ if t.TYPE_CHECKING: - from multiprocessing.connection import Connection - # pylint: disable-next=unsubscriptable-object _TemporaryDirectory = tempfile.TemporaryDirectory[str] else: @@ -79,12 +71,11 @@ def __exit__( def execute( - args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, / + args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None ) -> int: """Validate the SmartSim installation works as expected given a simple experiment """ - backends = installed_redisai_backends() temp_dir = "" device = Device(args.device) try: @@ -92,21 +83,10 @@ def execute( temp_dir = ctx.enter_context(_VerificationTempDir(dir=os.getcwd())) validate_env = { "SR_LOG_LEVEL": os.environ.get("SR_LOG_LEVEL", "INFO"), - "SR_LOG_FILE": os.environ.get( - "SR_LOG_FILE", os.path.join(temp_dir, "smartredis.log") - ), } if device == Device.GPU: validate_env["CUDA_VISIBLE_DEVICES"] = "0" ctx.enter_context(_env_vars_set_to(validate_env)) - test_install( - location=temp_dir, - port=args.port, - device=device, - with_tf="tensorflow" in backends, - with_pt="torch" in backends, - with_onnx="onnxruntime" in backends, - ) except Exception as e: logger.error( "SmartSim failed to run a simple experiment!\n" @@ -127,7 +107,7 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: type=int, default=None, help=( - "The port 
on which to run the orchestrator for the mini experiment. " + "The port on which to run the feature store for the mini experiment. " "If not provided, `smart` will attempt to automatically select an " "open port" ), @@ -141,34 +121,6 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: ) -def test_install( - location: str, - port: t.Optional[int], - device: Device, - with_tf: bool, - with_pt: bool, - with_onnx: bool, -) -> None: - exp = Experiment("ValidationExperiment", exp_path=location, launcher="local") - exp.telemetry.disable() - port = find_free_port() if port is None else port - - with _make_managed_local_orc(exp, port) as client: - logger.info("Verifying Tensor Transfer") - client.put_tensor("plain-tensor", np.ones((1, 1, 3, 3))) - client.get_tensor("plain-tensor") - if with_pt: - logger.info("Verifying Torch Backend") - _test_torch_install(client, device) - if with_onnx: - logger.info("Verifying ONNX Backend") - _test_onnx_install(client, device) - if with_tf: # Run last in case TF locks an entire GPU - logger.info("Verifying TensorFlow Backend") - _test_tf_install(client, location, device) - logger.info("Success!") - - @contextlib.contextmanager def _env_vars_set_to( evars: t.Mapping[str, t.Optional[str]] @@ -188,113 +140,3 @@ def _set_or_del_env_var(var: str, val: t.Optional[str]) -> None: os.environ[var] = val else: os.environ.pop(var, None) - - -@contextlib.contextmanager -def _make_managed_local_orc( - exp: Experiment, port: int -) -> t.Generator[Client, None, None]: - """Context managed orc that will be stopped if an exception is raised""" - orc = exp.create_database(db_nodes=1, interface="lo", port=port) - exp.generate(orc) - exp.start(orc) - try: - (client_addr,) = orc.get_address() - yield Client(False, address=client_addr) - finally: - exp.stop(orc) - - -def _test_tf_install(client: Client, tmp_dir: str, device: Device) -> None: - - model_path, inputs, outputs = _build_tf_frozen_model(tmp_dir) - - client.set_model_from_file( - 
"keras-fcn", - model_path, - "TF", - device=device.value.upper(), - inputs=inputs, - outputs=outputs, - ) - client.put_tensor("keras-input", np.random.rand(1, 28, 28).astype(np.float32)) - client.run_model("keras-fcn", inputs=["keras-input"], outputs=["keras-output"]) - client.get_tensor("keras-output") - - -def _build_tf_frozen_model(tmp_dir: str) -> t.Tuple[str, t.List[str], t.List[str]]: - - from tensorflow import keras # pylint: disable=no-name-in-module - - from smartsim.ml.tf import freeze_model - - fcn = keras.Sequential( - layers=[ - keras.layers.InputLayer(input_shape=(28, 28), name="input"), - keras.layers.Flatten(input_shape=(28, 28), name="flatten"), - keras.layers.Dense(128, activation="relu", name="dense"), - keras.layers.Dense(10, activation="softmax", name="output"), - ], - name="FullyConnectedNetwork", - ) - fcn.compile( - optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] - ) - model_path, inputs, outputs = freeze_model(fcn, tmp_dir, "keras_model.pb") - return model_path, inputs, outputs - - -def _test_torch_install(client: Client, device: Device) -> None: - import torch - from torch import nn - - class Net(nn.Module): - def __init__(self) -> None: - super().__init__() - self.conv: t.Callable[..., torch.Tensor] = nn.Conv2d(1, 1, 3) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.conv(x) - - if device == Device.GPU: - device_ = torch.device("cuda") - else: - device_ = torch.device("cpu") - - net = Net() - net.to(device_) - net.eval() - - forward_input = torch.rand(1, 1, 3, 3).to(device_) - traced = torch.jit.trace( # type: ignore[no-untyped-call, unused-ignore] - net, forward_input - ) - - buffer = io.BytesIO() - torch.jit.save(traced, buffer) # type: ignore[no-untyped-call, unused-ignore] - model = buffer.getvalue() - - client.set_model("torch-nn", model, backend="TORCH", device=device.value.upper()) - client.put_tensor("torch-in", torch.rand(1, 1, 3, 3).numpy()) - client.run_model("torch-nn", 
inputs=["torch-in"], outputs=["torch-out"]) - client.get_tensor("torch-out") - - -def _test_onnx_install(client: Client, device: Device) -> None: - from skl2onnx import to_onnx - from sklearn.cluster import KMeans - - data = np.arange(20, dtype=np.float32).reshape(10, 2) - model = KMeans(n_clusters=2, n_init=10) - model.fit(data) - - kmeans = to_onnx(model, data, target_opset=11) - model = kmeans.SerializeToString() - sample = np.arange(20, dtype=np.float32).reshape(10, 2) - - client.put_tensor("onnx-input", sample) - client.set_model("onnx-kmeans", model, "ONNX", device=device.value.upper()) - client.run_model( - "onnx-kmeans", inputs=["onnx-input"], outputs=["onnx-labels", "onnx-transform"] - ) - client.get_tensor("onnx-labels") diff --git a/smartsim/_core/_install/buildenv.py b/smartsim/_core/_install/buildenv.py index bff421b129..b8c6775120 100644 --- a/smartsim/_core/_install/buildenv.py +++ b/smartsim/_core/_install/buildenv.py @@ -132,6 +132,9 @@ def get_env(var: str, default: str) -> str: return os.environ.get(var, default) +# TODO Add A Version class for the new backend + + class Versioner: """Versioner is responsible for managing all the versions within SmartSim including SmartSim itself. @@ -221,12 +224,7 @@ class BuildEnv: """Environment for building third-party dependencies BuildEnv provides a method for configuring how the third-party - dependencies within SmartSim are built, namely Redis/KeyDB - and RedisAI. - - The environment variables listed here can be set to control the - Redis build in the pip wheel build as well as the Redis and RedisAI - build executed by the CLI. + dependencies within SmartSim are built. Build tools are also checked for here and if they are not found then a SetupError is raised. 
diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index 957f2b6ef6..7b850a9158 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -35,6 +35,7 @@ from pathlib import Path from subprocess import SubprocessError +from smartsim._core._install.platform import Architecture, OperatingSystem, Platform from smartsim._core._install.utils import retrieve from smartsim._core.utils import expand_exe_path @@ -167,98 +168,3 @@ def run_command( raise BuildError(error) except (OSError, SubprocessError) as e: raise BuildError(e) from e - - -class DatabaseBuilder(Builder): - """Class to build Redis or KeyDB from Source - Supported build methods: - - from git - See buildenv.py for buildtime configuration of Redis/KeyDB - version and url. - """ - - def __init__( - self, - build_env: t.Optional[t.Dict[str, str]] = None, - malloc: str = "libc", - jobs: int = 1, - verbose: bool = False, - ) -> None: - super().__init__( - build_env or {}, - jobs=jobs, - verbose=verbose, - ) - self.malloc = malloc - - @property - def is_built(self) -> bool: - """Check if Redis or KeyDB is built""" - bin_files = {file.name for file in self.bin_path.iterdir()} - redis_files = {"redis-server", "redis-cli"} - keydb_files = {"keydb-server", "keydb-cli"} - return redis_files.issubset(bin_files) or keydb_files.issubset(bin_files) - - def build_from_git(self, git_url: str, branch: str) -> None: - """Build Redis from git - :param git_url: url from which to retrieve Redis - :param branch: branch to checkout - """ - # pylint: disable=too-many-locals - database_name = "keydb" if "KeyDB" in git_url else "redis" - database_build_path = Path(self.build_dir, database_name.lower()) - - # remove git directory if it exists as it should - # really never exist as we delete after build - redis_build_path = Path(self.build_dir, "redis") - keydb_build_path = Path(self.build_dir, "keydb") - if redis_build_path.is_dir(): - 
shutil.rmtree(str(redis_build_path)) - if keydb_build_path.is_dir(): - shutil.rmtree(str(keydb_build_path)) - - # Check database URL - if not self.is_valid_url(git_url): - raise BuildError(f"Malformed {database_name} URL: {git_url}") - - retrieve(git_url, self.build_dir / database_name, branch=branch, depth=1) - # build Redis - build_cmd = [ - self.binary_path("make"), - "-j", - str(self.jobs), - f"MALLOC={self.malloc}", - ] - self.run_command(build_cmd, cwd=str(database_build_path)) - - # move redis binaries to smartsim/smartsim/_core/bin - database_src_dir = database_build_path / "src" - server_source = database_src_dir / (database_name.lower() + "-server") - server_destination = self.bin_path / (database_name.lower() + "-server") - cli_source = database_src_dir / (database_name.lower() + "-cli") - cli_destination = self.bin_path / (database_name.lower() + "-cli") - self.copy_file(server_source, server_destination, set_exe=True) - self.copy_file(cli_source, cli_destination, set_exe=True) - - # validate install -- redis-server - core_path = Path(os.path.abspath(__file__)).parent.parent - dependency_path = os.environ.get("SMARTSIM_DEP_INSTALL_PATH", core_path) - bin_path = Path(dependency_path, "bin").resolve() - try: - database_exe = next(bin_path.glob("*-server")) - database = Path( - os.environ.get("SMARTSIM_REDIS_SERVER_EXE", database_exe) - ).resolve() - _ = expand_exe_path(str(database)) - except (TypeError, FileNotFoundError) as e: - raise BuildError("Installation of redis-server failed!") from e - - # validate install -- redis-cli - try: - redis_cli_exe = next(bin_path.glob("*-cli")) - redis_cli = Path( - os.environ.get("SMARTSIM_REDIS_CLI_EXE", redis_cli_exe) - ).resolve() - _ = expand_exe_path(str(redis_cli)) - except (TypeError, FileNotFoundError) as e: - raise BuildError("Installation of redis-cli failed!") from e diff --git a/smartsim/_core/arguments/shell.py b/smartsim/_core/arguments/shell.py new file mode 100644 index 0000000000..e4138d0ebb --- 
/dev/null +++ b/smartsim/_core/arguments/shell.py @@ -0,0 +1,42 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import typing as t +from abc import abstractmethod + +from smartsim.log import get_logger +from smartsim.settings.arguments.launch_arguments import LaunchArguments + +logger = get_logger(__name__) + + +class ShellLaunchArguments(LaunchArguments): + @abstractmethod + def format_env_vars(self, env_vars: t.Mapping[str, str | None]) -> list[str]: ... + @abstractmethod + def format_launch_args(self) -> list[str]: ... 
diff --git a/smartsim/_core/commands/__init__.py b/smartsim/_core/commands/__init__.py new file mode 100644 index 0000000000..a35efc62f8 --- /dev/null +++ b/smartsim/_core/commands/__init__.py @@ -0,0 +1,29 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from .command import Command +from .command_list import CommandList +from .launch_commands import LaunchCommands diff --git a/smartsim/_core/commands/command.py b/smartsim/_core/commands/command.py new file mode 100644 index 0000000000..0968759afd --- /dev/null +++ b/smartsim/_core/commands/command.py @@ -0,0 +1,98 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import typing as t +from collections.abc import MutableSequence +from copy import deepcopy + +from typing_extensions import Self + + +class Command(MutableSequence[str]): + """Basic container for command information""" + + def __init__(self, command: t.List[str]) -> None: + if not command: + raise TypeError("Command list cannot be empty") + if not all(isinstance(item, str) for item in command): + raise TypeError("All items in the command list must be strings") + """Command constructor""" + self._command = command + + @property + def command(self) -> t.List[str]: + """Get the command list. + Return a reference to the command list. + """ + return self._command + + @t.overload + def __getitem__(self, idx: int) -> str: ... + @t.overload + def __getitem__(self, idx: slice) -> Self: ... + def __getitem__(self, idx: t.Union[int, slice]) -> t.Union[str, Self]: + """Get the command at the specified index.""" + cmd = self._command[idx] + if isinstance(cmd, str): + return cmd + return type(self)(cmd) + + @t.overload + def __setitem__(self, idx: int, value: str) -> None: ... + @t.overload + def __setitem__(self, idx: slice, value: t.Iterable[str]) -> None: ... 
+ def __setitem__( + self, idx: t.Union[int, slice], value: t.Union[str, t.Iterable[str]] + ) -> None: + """Set the command at the specified index.""" + if isinstance(idx, int): + if not isinstance(value, str): + raise TypeError( + "Value must be of type `str` when assigning to an index" + ) + self._command[idx] = deepcopy(value) + return + if not isinstance(value, list) or not all( + isinstance(item, str) for item in value + ): + raise TypeError("Value must be a list of strings when assigning to a slice") + self._command[idx] = (deepcopy(val) for val in value) + + def __delitem__(self, idx: t.Union[int, slice]) -> None: + """Delete the command at the specified index.""" + del self._command[idx] + + def __len__(self) -> int: + """Get the length of the command list.""" + return len(self._command) + + def insert(self, idx: int, value: str) -> None: + """Insert a command at the specified index.""" + self._command.insert(idx, value) + + def __str__(self) -> str: # pragma: no cover + string = f"\nCommand: {' '.join(str(cmd) for cmd in self.command)}" + return string diff --git a/smartsim/_core/commands/command_list.py b/smartsim/_core/commands/command_list.py new file mode 100644 index 0000000000..fcffe42a2a --- /dev/null +++ b/smartsim/_core/commands/command_list.py @@ -0,0 +1,107 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t +from collections.abc import MutableSequence +from copy import deepcopy + +from .command import Command + + +class CommandList(MutableSequence[Command]): + """Container for a Sequence of Command objects""" + + def __init__(self, commands: t.Optional[t.Union[Command, t.List[Command]]] = None): + """CommandList constructor""" + if commands is None: + commands = [] + elif isinstance(commands, Command): + commands = [commands] + self._commands: t.List[Command] = list(commands) + + @property + def commands(self) -> t.List[Command]: + """Get the Command list. + Return a reference to the Command list. + """ + return self._commands + + @t.overload + def __getitem__(self, idx: int) -> Command: ... + @t.overload + def __getitem__(self, idx: slice) -> t.List[Command]: ... + def __getitem__( + self, idx: t.Union[slice, int] + ) -> t.Union[Command, t.List[Command]]: + """Get the Command at the specified index.""" + return self._commands[idx] + + @t.overload + def __setitem__(self, idx: int, value: Command) -> None: ... + @t.overload + def __setitem__(self, idx: slice, value: t.Iterable[Command]) -> None: ... 
+ def __setitem__( + self, idx: t.Union[int, slice], value: t.Union[Command, t.Iterable[Command]] + ) -> None: + """Set the Commands at the specified index.""" + if isinstance(idx, int): + if not isinstance(value, Command): + raise TypeError( + "Value must be of type `Command` when assigning to an index" + ) + self._commands[idx] = deepcopy(value) + return + if not isinstance(value, list): + raise TypeError( + "Value must be a list of Commands when assigning to a slice" + ) + for sublist in value: + if not isinstance(sublist.command, list) or not all( + isinstance(item, str) for item in sublist.command + ): + raise TypeError( + "Value sublists must be a list of Commands when assigning to a slice" + ) + self._commands[idx] = (deepcopy(val) for val in value) + + def __delitem__(self, idx: t.Union[int, slice]) -> None: + """Delete the Command at the specified index.""" + del self._commands[idx] + + def __len__(self) -> int: + """Get the length of the Command list.""" + return len(self._commands) + + def insert(self, idx: int, value: Command) -> None: + """Insert a Command at the specified index.""" + self._commands.insert(idx, value) + + def __str__(self) -> str: # pragma: no cover + string = "\n\nCommand List:\n\n" + for counter, cmd in enumerate(self.commands): + string += f"CommandList index {counter} value:" + string += f"{cmd}\n\n" + return string diff --git a/smartsim/_core/commands/launch_commands.py b/smartsim/_core/commands/launch_commands.py new file mode 100644 index 0000000000..74303ac942 --- /dev/null +++ b/smartsim/_core/commands/launch_commands.py @@ -0,0 +1,51 @@ +from .command_list import CommandList + + +class LaunchCommands: + """Container for aggregating prelaunch commands (e.g. 
file + system operations), launch commands, and postlaunch commands + """ + + def __init__( + self, + prelaunch_commands: CommandList, + launch_commands: CommandList, + postlaunch_commands: CommandList, + ) -> None: + """LaunchCommand constructor""" + self._prelaunch_commands = prelaunch_commands + self._launch_commands = launch_commands + self._postlaunch_commands = postlaunch_commands + + @property + def prelaunch_command(self) -> CommandList: + """Get the prelaunch command list. + Return a reference to the command list. + """ + return self._prelaunch_commands + + @property + def launch_command(self) -> CommandList: + """Get the launch command list. + Return a reference to the command list. + """ + return self._launch_commands + + @property + def postlaunch_command(self) -> CommandList: + """Get the postlaunch command list. + Return a reference to the command list. + """ + return self._postlaunch_commands + + def __str__(self) -> str: # pragma: no cover + string = "\n\nPrelaunch Command List:\n" + for pre_cmd in self.prelaunch_command: + string += f"{pre_cmd}\n" + string += "\n\nLaunch Command List:\n" + for launch_cmd in self.launch_command: + string += f"{launch_cmd}\n" + string += "\n\nPostlaunch Command List:\n" + for post_cmd in self.postlaunch_command: + string += f"{post_cmd}\n" + return string diff --git a/smartsim/_core/config/config.py b/smartsim/_core/config/config.py index c8b4ff17b9..b4cbae8d22 100644 --- a/smartsim/_core/config/config.py +++ b/smartsim/_core/config/config.py @@ -40,22 +40,6 @@ # These values can be set through environment variables to # override the default behavior of SmartSim. 
# -# SMARTSIM_RAI_LIB -# - Path to the RAI shared library -# - Default: /smartsim/smartsim/_core/lib/redisai.so -# -# SMARTSIM_REDIS_CONF -# - Path to the redis.conf file -# - Default: /SmartSim/smartsim/_core/config/redis.conf -# -# SMARTSIM_REDIS_SERVER_EXE -# - Path to the redis-server executable -# - Default: /SmartSim/smartsim/_core/bin/redis-server -# -# SMARTSIM_REDIS_CLI_EXE -# - Path to the redis-cli executable -# - Default: /SmartSim/smartsim/_core/bin/redis-cli -# # SMARTSIM_LOG_LEVEL # - Log level for SmartSim # - Default: info @@ -94,77 +78,13 @@ class Config: def __init__(self) -> None: # SmartSim/smartsim/_core self.core_path = Path(os.path.abspath(__file__)).parent.parent - # TODO: Turn this into a property. Need to modify the configuration - # of KeyDB vs Redis at build time - self.conf_dir = self.core_path / "config" - self.conf_path = self.conf_dir / "redis.conf" - - @property - def dependency_path(self) -> Path: - return Path( - os.environ.get("SMARTSIM_DEP_INSTALL_PATH", str(self.core_path)) - ).resolve() - - @property - def lib_path(self) -> Path: - return Path(self.dependency_path, "lib") - - @property - def bin_path(self) -> Path: - return Path(self.dependency_path, "bin") - @property - def build_path(self) -> Path: - return Path(self.dependency_path, "build") - - @property - def redisai(self) -> str: - rai_path = self.lib_path / "redisai.so" - redisai = Path(os.environ.get("SMARTSIM_RAI_LIB", rai_path)).resolve() - if not redisai.is_file(): - raise SSConfigError( - "RedisAI dependency not found. 
Build with `smart` cli " - "or specify SMARTSIM_RAI_LIB" - ) - return str(redisai) - - @property - def database_conf(self) -> str: - conf = Path(os.environ.get("SMARTSIM_REDIS_CONF", self.conf_path)).resolve() - if not conf.is_file(): - raise SSConfigError( - "Database configuration file at SMARTSIM_REDIS_CONF could not be found" - ) - return str(conf) + dependency_path = os.environ.get("SMARTSIM_DEP_INSTALL_PATH", self.core_path) - @property - def database_exe(self) -> str: - try: - database_exe = next(self.bin_path.glob("*-server")) - database = Path( - os.environ.get("SMARTSIM_REDIS_SERVER_EXE", database_exe) - ).resolve() - exe = expand_exe_path(str(database)) - return exe - except (TypeError, FileNotFoundError) as e: - raise SSConfigError( - "Specified database binary at SMARTSIM_REDIS_SERVER_EXE " - "could not be used" - ) from e - - @property - def database_cli(self) -> str: - try: - redis_cli_exe = next(self.bin_path.glob("*-cli")) - redis_cli = Path( - os.environ.get("SMARTSIM_REDIS_CLI_EXE", redis_cli_exe) - ).resolve() - exe = expand_exe_path(str(redis_cli)) - return exe - except (TypeError, FileNotFoundError) as e: - raise SSConfigError( - "Specified Redis binary at SMARTSIM_REDIS_CLI_EXE could not be used" - ) from e + self.lib_path = Path(dependency_path, "lib").resolve() + self.bin_path = Path(dependency_path, "bin").resolve() + self.conf_path = Path(dependency_path, "config") + self.conf_dir = Path(self.core_path, "config") @property def database_file_parse_trials(self) -> int: diff --git a/smartsim/_core/control/__init__.py b/smartsim/_core/control/__init__.py index 0acd80650c..ba3af1440f 100644 --- a/smartsim/_core/control/__init__.py +++ b/smartsim/_core/control/__init__.py @@ -24,5 +24,4 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from .controller import Controller from .manifest import Manifest diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py deleted file mode 100644 index 0b943ee905..0000000000 --- a/smartsim/_core/control/controller.py +++ /dev/null @@ -1,956 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -from __future__ import annotations - -import itertools -import os -import os.path as osp -import pathlib -import pickle -import signal -import subprocess -import sys -import threading -import time -import typing as t - -from smartredis import Client, ConfigOptions - -from smartsim._core.utils.network import get_ip_from_host - -from ..._core.launcher.step import Step -from ..._core.utils.helpers import ( - SignalInterceptionStack, - unpack_colo_db_identifier, - unpack_db_identifier, -) -from ..._core.utils.redis import ( - db_is_active, - set_ml_model, - set_script, - shutdown_db_node, -) -from ...database import Orchestrator -from ...entity import Ensemble, EntitySequence, Model, SmartSimEntity -from ...error import ( - LauncherError, - SmartSimError, - SSDBIDConflictError, - SSInternalError, - SSUnsupportedError, -) -from ...log import get_logger -from ...servertype import CLUSTERED, STANDALONE -from ...status import TERMINAL_STATUSES, SmartSimStatus -from ..config import CONFIG -from ..launcher import ( - DragonLauncher, - LocalLauncher, - LSFLauncher, - PBSLauncher, - SGELauncher, - SlurmLauncher, -) -from ..launcher.launcher import Launcher -from ..utils import check_cluster_status, create_cluster, serialize -from .controller_utils import _AnonymousBatchJob, _look_up_launched_data -from .job import Job -from .jobmanager import JobManager -from .manifest import LaunchedManifest, LaunchedManifestBuilder, Manifest - -if t.TYPE_CHECKING: - from types import FrameType - - from ..utils.serialize import TStepLaunchMetaData - - -logger = get_logger(__name__) - -# job manager lock -JM_LOCK = threading.RLock() - - -class Controller: - """The controller module provides an interface between the - smartsim entities created in the experiment and the - underlying workload manager or run framework. 
- """ - - def __init__(self, launcher: str = "local") -> None: - """Initialize a Controller - - :param launcher: the type of launcher being used - """ - self._jobs = JobManager(JM_LOCK) - self.init_launcher(launcher) - self._telemetry_monitor: t.Optional[subprocess.Popen[bytes]] = None - - def start( - self, - exp_name: str, - exp_path: str, - manifest: Manifest, - block: bool = True, - kill_on_interrupt: bool = True, - ) -> None: - """Start the passed SmartSim entities - - This function should not be called directly, but rather - through the experiment interface. - - The controller will start the job-manager thread upon - execution of all jobs. - """ - # launch a telemetry monitor to track job progress - if CONFIG.telemetry_enabled: - self._start_telemetry_monitor(exp_path) - - self._jobs.kill_on_interrupt = kill_on_interrupt - - # register custom signal handler for ^C (SIGINT) - SignalInterceptionStack.get(signal.SIGINT).push_unique( - self._jobs.signal_interrupt - ) - launched = self._launch(exp_name, exp_path, manifest) - - # start the job manager thread if not already started - if not self._jobs.actively_monitoring: - self._jobs.start() - - serialize.save_launch_manifest( - launched.map(_look_up_launched_data(self._launcher)) - ) - - # block until all non-database jobs are complete - if block: - # poll handles its own keyboard interrupt as - # it may be called separately - self.poll(5, True, kill_on_interrupt=kill_on_interrupt) - - @property - def active_orchestrator_jobs(self) -> t.Dict[str, Job]: - """Return active orchestrator jobs.""" - return {**self._jobs.db_jobs} - - @property - def orchestrator_active(self) -> bool: - with JM_LOCK: - if len(self._jobs.db_jobs) > 0: - return True - return False - - def poll( - self, interval: int, verbose: bool, kill_on_interrupt: bool = True - ) -> None: - """Poll running jobs and receive logging output of job status - - :param interval: number of seconds to wait before polling again - :param verbose: set verbosity - 
:param kill_on_interrupt: flag for killing jobs when SIGINT is received - """ - self._jobs.kill_on_interrupt = kill_on_interrupt - to_monitor = self._jobs.jobs - while len(to_monitor) > 0: - time.sleep(interval) - - # acquire lock to avoid "dictionary changed during iteration" error - # without having to copy dictionary each time. - if verbose: - with JM_LOCK: - for job in to_monitor.values(): - logger.info(job) - - def finished( - self, entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> bool: - """Return a boolean indicating wether a job has finished or not - - :param entity: object launched by SmartSim. - :returns: bool - :raises ValueError: if entity has not been launched yet - """ - try: - if isinstance(entity, Orchestrator): - raise TypeError("Finished() does not support Orchestrator instances") - if isinstance(entity, EntitySequence): - return all(self.finished(ent) for ent in entity.entities) - if not isinstance(entity, SmartSimEntity): - raise TypeError( - f"Argument was of type {type(entity)} not derived " - "from SmartSimEntity or EntitySequence" - ) - - return self._jobs.is_finished(entity) - except KeyError: - raise ValueError( - f"Entity {entity.name} has not been launched in this experiment" - ) from None - - def stop_entity( - self, entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> None: - """Stop an instance of an entity - - This function will also update the status of the job in - the jobmanager so that the job appears as "cancelled". 
- - :param entity: entity to be stopped - """ - with JM_LOCK: - job = self._jobs[entity.name] - if job.status not in TERMINAL_STATUSES: - logger.info( - " ".join( - ("Stopping model", entity.name, "with job name", str(job.name)) - ) - ) - status = self._launcher.stop(job.name) - - job.set_status( - status.status, - status.launcher_status, - status.returncode, - error=status.error, - output=status.output, - ) - self._jobs.move_to_completed(job) - - def stop_db(self, db: Orchestrator) -> None: - """Stop an orchestrator - - :param db: orchestrator to be stopped - """ - if db.batch: - self.stop_entity(db) - else: - with JM_LOCK: - for node in db.entities: - for host_ip, port in itertools.product( - (get_ip_from_host(host) for host in node.hosts), db.ports - ): - retcode, _, _ = shutdown_db_node(host_ip, port) - # Sometimes the DB will not shutdown (unless we force NOSAVE) - if retcode != 0: - self.stop_entity(node) - continue - - job = self._jobs[node.name] - job.set_status( - SmartSimStatus.STATUS_CANCELLED, - "", - 0, - output=None, - error=None, - ) - self._jobs.move_to_completed(job) - - db.reset_hosts() - - def stop_entity_list(self, entity_list: EntitySequence[SmartSimEntity]) -> None: - """Stop an instance of an entity list - - :param entity_list: entity list to be stopped - """ - - if entity_list.batch: - self.stop_entity(entity_list) - else: - for entity in entity_list.entities: - self.stop_entity(entity) - - def get_jobs(self) -> t.Dict[str, Job]: - """Return a dictionary of completed job data - - :returns: dict[str, Job] - """ - with JM_LOCK: - return self._jobs.completed - - def get_entity_status( - self, entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> SmartSimStatus: - """Get the status of an entity - - :param entity: entity to get status of - :raises TypeError: if not SmartSimEntity | EntitySequence - :return: status of entity - """ - if not isinstance(entity, (SmartSimEntity, EntitySequence)): - raise TypeError( - "Argument must be 
of type SmartSimEntity or EntitySequence, " - f"not {type(entity)}" - ) - return self._jobs.get_status(entity) - - def get_entity_list_status( - self, entity_list: EntitySequence[SmartSimEntity] - ) -> t.List[SmartSimStatus]: - """Get the statuses of an entity list - - :param entity_list: entity list containing entities to - get statuses of - :raises TypeError: if not EntitySequence - :return: list of SmartSimStatus statuses - """ - if not isinstance(entity_list, EntitySequence): - raise TypeError( - f"Argument was of type {type(entity_list)} not EntitySequence" - ) - if entity_list.batch: - return [self.get_entity_status(entity_list)] - statuses = [] - for entity in entity_list.entities: - statuses.append(self.get_entity_status(entity)) - return statuses - - def init_launcher(self, launcher: str) -> None: - """Initialize the controller with a specific type of launcher. - SmartSim currently supports slurm, pbs(pro), lsf, - and local launching - - :param launcher: which launcher to initialize - :raises SSUnsupportedError: if a string is passed that is not - a supported launcher - :raises TypeError: if no launcher argument is provided. 
- """ - launcher_map: t.Dict[str, t.Type[Launcher]] = { - "slurm": SlurmLauncher, - "pbs": PBSLauncher, - "pals": PBSLauncher, - "lsf": LSFLauncher, - "local": LocalLauncher, - "dragon": DragonLauncher, - "sge": SGELauncher, - } - - if launcher is not None: - launcher = launcher.lower() - if launcher in launcher_map: - # create new instance of the launcher - self._launcher = launcher_map[launcher]() - self._jobs.set_launcher(self._launcher) - else: - raise SSUnsupportedError("Launcher type not supported: " + launcher) - else: - raise TypeError("Must provide a 'launcher' argument") - - @staticmethod - def symlink_output_files( - job_step: Step, entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> None: - """Create symlinks for entity output files that point to the output files - under the .smartsim directory - - :param job_step: Job step instance - :param entity: Entity instance - """ - historical_out, historical_err = map(pathlib.Path, job_step.get_output_files()) - entity_out = pathlib.Path(entity.path) / f"{entity.name}.out" - entity_err = pathlib.Path(entity.path) / f"{entity.name}.err" - - # check if there is already a link to a previous run - if entity_out.is_symlink() or entity_err.is_symlink(): - entity_out.unlink() - entity_err.unlink() - - historical_err.touch() - historical_out.touch() - - if historical_err.exists() and historical_out.exists(): - entity_out.symlink_to(historical_out) - entity_err.symlink_to(historical_err) - else: - raise FileNotFoundError( - f"Output files for {entity.name} could not be found. " - "Symlinking files failed." 
- ) - - def _launch( - self, exp_name: str, exp_path: str, manifest: Manifest - ) -> LaunchedManifest[t.Tuple[str, Step]]: - """Main launching function of the controller - - Orchestrators are always launched first so that the - address of the database can be given to following entities - - :param exp_name: The name of the launching experiment - :param exp_path: path to location of ``Experiment`` directory if generated - :param manifest: Manifest of deployables to launch - """ - - manifest_builder = LaunchedManifestBuilder[t.Tuple[str, Step]]( - exp_name=exp_name, - exp_path=exp_path, - launcher_name=str(self._launcher), - ) - # Loop over deployables to launch and launch multiple orchestrators - for orchestrator in manifest.dbs: - for key in self._jobs.get_db_host_addresses(): - _, db_id = unpack_db_identifier(key, "_") - if orchestrator.db_identifier == db_id: - raise SSDBIDConflictError( - f"Database identifier {orchestrator.db_identifier}" - " has already been used. Pass in a unique" - " name for db_identifier" - ) - - if orchestrator.num_shards > 1 and isinstance( - self._launcher, LocalLauncher - ): - raise SmartSimError( - "Local launcher does not support multi-host orchestrators" - ) - self._launch_orchestrator(orchestrator, manifest_builder) - - if self.orchestrator_active: - self._set_dbobjects(manifest) - - # create all steps prior to launch - steps: t.List[ - t.Tuple[Step, t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]]] - ] = [] - - symlink_substeps: t.List[ - t.Tuple[Step, t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]]] - ] = [] - - for elist in manifest.ensembles: - ens_telem_dir = manifest_builder.run_telemetry_subdirectory / "ensemble" - if elist.batch: - batch_step, substeps = self._create_batch_job_step(elist, ens_telem_dir) - manifest_builder.add_ensemble( - elist, [(batch_step.name, step) for step in substeps] - ) - - # symlink substeps to maintain directory structure - for substep, substep_entity in zip(substeps, elist.models): 
- symlink_substeps.append((substep, substep_entity)) - - steps.append((batch_step, elist)) - else: - # if ensemble is to be run as separate job steps, aka not in a batch - job_steps = [ - (self._create_job_step(e, ens_telem_dir / elist.name), e) - for e in elist.entities - ] - manifest_builder.add_ensemble( - elist, [(step.name, step) for step, _ in job_steps] - ) - steps.extend(job_steps) - # models themselves cannot be batch steps. If batch settings are - # attached, wrap them in an anonymous batch job step - for model in manifest.models: - model_telem_dir = manifest_builder.run_telemetry_subdirectory / "model" - if model.batch_settings: - anon_entity_list = _AnonymousBatchJob(model) - batch_step, substeps = self._create_batch_job_step( - anon_entity_list, model_telem_dir - ) - manifest_builder.add_model(model, (batch_step.name, batch_step)) - - symlink_substeps.append((substeps[0], model)) - steps.append((batch_step, model)) - else: - job_step = self._create_job_step(model, model_telem_dir) - manifest_builder.add_model(model, (job_step.name, job_step)) - steps.append((job_step, model)) - - # launch and symlink steps - for step, entity in steps: - self._launch_step(step, entity) - self.symlink_output_files(step, entity) - - # symlink substeps to maintain directory structure - for substep, entity in symlink_substeps: - self.symlink_output_files(substep, entity) - - return manifest_builder.finalize() - - def _launch_orchestrator( - self, - orchestrator: Orchestrator, - manifest_builder: LaunchedManifestBuilder[t.Tuple[str, Step]], - ) -> None: - """Launch an Orchestrator instance - - This function will launch the Orchestrator instance and - if on WLM, find the nodes where it was launched and - set them in the JobManager - - :param orchestrator: orchestrator to launch - :param manifest_builder: An `LaunchedManifestBuilder` to record the - names and `Step`s of the launched orchestrator - """ - orchestrator.remove_stale_files() - orc_telem_dir = 
manifest_builder.run_telemetry_subdirectory / "database" - - # if the orchestrator was launched as a batch workload - if orchestrator.batch: - orc_batch_step, substeps = self._create_batch_job_step( - orchestrator, orc_telem_dir - ) - manifest_builder.add_database( - orchestrator, [(orc_batch_step.name, step) for step in substeps] - ) - - self._launch_step(orc_batch_step, orchestrator) - self.symlink_output_files(orc_batch_step, orchestrator) - - # symlink substeps to maintain directory structure - for substep, substep_entity in zip(substeps, orchestrator.entities): - self.symlink_output_files(substep, substep_entity) - - # if orchestrator was run on existing allocation, locally, or in allocation - else: - db_steps = [ - (self._create_job_step(db, orc_telem_dir / orchestrator.name), db) - for db in orchestrator.entities - ] - manifest_builder.add_database( - orchestrator, [(step.name, step) for step, _ in db_steps] - ) - for db_step in db_steps: - self._launch_step(*db_step) - self.symlink_output_files(*db_step) - - # wait for orchestrator to spin up - self._orchestrator_launch_wait(orchestrator) - - # set the jobs in the job manager to provide SSDB variable to entities - # if _host isnt set within each - self._jobs.set_db_hosts(orchestrator) - - # create the database cluster - if orchestrator.num_shards > 2: - num_trials = 5 - cluster_created = False - while not cluster_created: - try: - create_cluster(orchestrator.hosts, orchestrator.ports) - check_cluster_status(orchestrator.hosts, orchestrator.ports) - num_shards = orchestrator.num_shards - logger.info(f"Database cluster created with {num_shards} shards") - cluster_created = True - except SSInternalError: - if num_trials > 0: - logger.debug( - "Cluster creation failed, attempting again in five seconds." 
- ) - num_trials -= 1 - time.sleep(5) - else: - # surface SSInternalError as we have no way to recover - raise - self._save_orchestrator(orchestrator) - logger.debug(f"Orchestrator launched on nodes: {orchestrator.hosts}") - - def _launch_step( - self, - job_step: Step, - entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]], - ) -> None: - """Use the launcher to launch a job step - - :param job_step: a job step instance - :param entity: entity instance - :raises SmartSimError: if launch fails - """ - # attempt to retrieve entity name in JobManager.completed - completed_job = self._jobs.completed.get(entity.name, None) - - # if completed job DNE and is the entity name is not - # running in JobManager.jobs or JobManager.db_jobs, - # launch the job - if completed_job is None and ( - entity.name not in self._jobs.jobs and entity.name not in self._jobs.db_jobs - ): - try: - job_id = self._launcher.run(job_step) - except LauncherError as e: - msg = f"An error occurred when launching {entity.name} \n" - msg += "Check error and output files for details.\n" - msg += f"{entity}" - logger.error(msg) - raise SmartSimError(f"Job step {entity.name} failed to launch") from e - - # if the completed job does exist and the entity passed in is the same - # that has ran and completed, relaunch the entity. - elif completed_job is not None and completed_job.entity is entity: - try: - job_id = self._launcher.run(job_step) - except LauncherError as e: - msg = f"An error occurred when launching {entity.name} \n" - msg += "Check error and output files for details.\n" - msg += f"{entity}" - logger.error(msg) - raise SmartSimError(f"Job step {entity.name} failed to launch") from e - - # the entity is using a duplicate name of an existing entity in - # the experiment, throw an error - else: - raise SSUnsupportedError("SmartSim entities cannot have duplicate names.") - - # a job step is a task if it is not managed by a workload manager (i.e. 
Slurm) - # but is rather started, monitored, and exited through the Popen interface - # in the taskmanager - is_task = not job_step.managed - - if self._jobs.query_restart(entity.name): - logger.debug(f"Restarting {entity.name}") - self._jobs.restart_job(job_step.name, job_id, entity.name, is_task) - else: - logger.debug(f"Launching {entity.name}") - self._jobs.add_job(job_step.name, job_id, entity, is_task) - - def _create_batch_job_step( - self, - entity_list: t.Union[Orchestrator, Ensemble, _AnonymousBatchJob], - telemetry_dir: pathlib.Path, - ) -> t.Tuple[Step, t.List[Step]]: - """Use launcher to create batch job step - - :param entity_list: EntityList to launch as batch - :param telemetry_dir: Path to a directory in which the batch job step - may write telemetry events - :return: batch job step instance and a list of run steps to be - executed within the batch job - """ - if not entity_list.batch_settings: - raise ValueError( - "EntityList must have batch settings to be launched as batch" - ) - - telemetry_dir = telemetry_dir / entity_list.name - batch_step = self._launcher.create_step( - entity_list.name, entity_list.path, entity_list.batch_settings - ) - batch_step.meta["entity_type"] = str(type(entity_list).__name__).lower() - batch_step.meta["status_dir"] = str(telemetry_dir) - - substeps = [] - for entity in entity_list.entities: - # tells step creation not to look for an allocation - entity.run_settings.in_batch = True - step = self._create_job_step(entity, telemetry_dir) - substeps.append(step) - batch_step.add_to_batch(step) - return batch_step, substeps - - def _create_job_step( - self, entity: SmartSimEntity, telemetry_dir: pathlib.Path - ) -> Step: - """Create job steps for all entities with the launcher - - :param entity: an entity to create a step for - :param telemetry_dir: Path to a directory in which the job step - may write telemetry events - :return: the job step - """ - # get SSDB, SSIN, SSOUT and add to entity run settings - if 
isinstance(entity, Model): - self._prep_entity_client_env(entity) - - step = self._launcher.create_step(entity.name, entity.path, entity.run_settings) - - step.meta["entity_type"] = str(type(entity).__name__).lower() - step.meta["status_dir"] = str(telemetry_dir / entity.name) - - return step - - def _prep_entity_client_env(self, entity: Model) -> None: - """Retrieve all connections registered to this entity - - :param entity: The entity to retrieve connections from - """ - - client_env: t.Dict[str, t.Union[str, int, float, bool]] = {} - address_dict = self._jobs.get_db_host_addresses() - - for db_id, addresses in address_dict.items(): - db_name, _ = unpack_db_identifier(db_id, "_") - if addresses: - # Cap max length of SSDB - client_env[f"SSDB{db_name}"] = ",".join(addresses[:128]) - - # Retrieve num_shards to append to client env - client_env[f"SR_DB_TYPE{db_name}"] = ( - CLUSTERED if len(addresses) > 1 else STANDALONE - ) - - if entity.incoming_entities: - client_env["SSKEYIN"] = ",".join( - [in_entity.name for in_entity in entity.incoming_entities] - ) - if entity.query_key_prefixing(): - client_env["SSKEYOUT"] = entity.name - - # Set address to local if it's a colocated model - if entity.colocated and entity.run_settings.colocated_db_settings is not None: - db_name_colo = entity.run_settings.colocated_db_settings["db_identifier"] - assert isinstance(db_name_colo, str) - for key in address_dict: - _, db_id = unpack_db_identifier(key, "_") - if db_name_colo == db_id: - raise SSDBIDConflictError( - f"Database identifier {db_name_colo}" - " has already been used. 
Pass in a unique" - " name for db_identifier" - ) - - db_name_colo = unpack_colo_db_identifier(db_name_colo) - if colo_cfg := entity.run_settings.colocated_db_settings: - port = colo_cfg.get("port", None) - socket = colo_cfg.get("unix_socket", None) - if socket and port: - raise SSInternalError( - "Co-located was configured for both TCP/IP and UDS" - ) - if port: - client_env[f"SSDB{db_name_colo}"] = f"127.0.0.1:{str(port)}" - elif socket: - client_env[f"SSDB{db_name_colo}"] = f"unix://{socket}" - else: - raise SSInternalError( - "Colocated database was not configured for either TCP or UDS" - ) - client_env[f"SR_DB_TYPE{db_name_colo}"] = STANDALONE - - entity.run_settings.update_env(client_env) - - def _save_orchestrator(self, orchestrator: Orchestrator) -> None: - """Save the orchestrator object via pickle - - This function saves the orchestrator information to a pickle - file that can be imported by subsequent experiments to reconnect - to the orchestrator. - - :param orchestrator: Orchestrator configuration to be saved - """ - - if not orchestrator.is_active(): - raise Exception("Orchestrator is not running") - - # Extract only the db_jobs associated with this particular orchestrator - if orchestrator.batch: - job_names = [orchestrator.name] - else: - job_names = [dbnode.name for dbnode in orchestrator.entities] - db_jobs = { - name: job for name, job in self._jobs.db_jobs.items() if name in job_names - } - - # Extract the associated steps - steps = [ - self._launcher.step_mapping[db_job.name] for db_job in db_jobs.values() - ] - - orc_data = {"db": orchestrator, "db_jobs": db_jobs, "steps": steps} - - with open(orchestrator.checkpoint_file, "wb") as pickle_file: - pickle.dump(orc_data, pickle_file) - - def _orchestrator_launch_wait(self, orchestrator: Orchestrator) -> None: - """Wait for the orchestrator instances to run - - In the case where the orchestrator is launched as a batch - through a WLM, we wait for the orchestrator to exit the - queue before 
proceeding so new launched entities can - be launched with SSDB address - - :param orchestrator: orchestrator instance - :raises SmartSimError: if launch fails or manually stopped by user - """ - if orchestrator.batch: - logger.info("Orchestrator launched as a batch") - logger.info("While queued, SmartSim will wait for Orchestrator to run") - logger.info("CTRL+C interrupt to abort and cancel launch") - - ready = False - while not ready: - try: - time.sleep(CONFIG.jm_interval) - # manually trigger job update if JM not running - if not self._jobs.actively_monitoring: - self._jobs.check_jobs() - - # _jobs.get_status acquires JM lock for main thread, no need for locking - statuses = self.get_entity_list_status(orchestrator) - if all(stat == SmartSimStatus.STATUS_RUNNING for stat in statuses): - ready = True - # TODO: Add a node status check - elif any(stat in TERMINAL_STATUSES for stat in statuses): - self.stop_db(orchestrator) - msg = "Orchestrator failed during startup" - msg += f" See {orchestrator.path} for details" - raise SmartSimError(msg) - else: - logger.debug("Waiting for orchestrator instances to spin up...") - except KeyboardInterrupt: - logger.info("Orchestrator launch cancelled - requesting to stop") - self.stop_db(orchestrator) - - # re-raise keyboard interrupt so the job manager will display - # any running and un-killed jobs as this method is only called - # during launch and we handle all keyboard interrupts during - # launch explicitly - raise - - def reload_saved_db( - self, checkpoint_file: t.Union[str, os.PathLike[str]] - ) -> Orchestrator: - with JM_LOCK: - - if not osp.exists(checkpoint_file): - raise FileNotFoundError( - f"The SmartSim database config file {os.fspath(checkpoint_file)} " - "cannot be found." 
- ) - - try: - with open(checkpoint_file, "rb") as pickle_file: - db_config = pickle.load(pickle_file) - except (OSError, IOError) as e: - msg = "Database checkpoint corrupted" - raise SmartSimError(msg) from e - - err_message = ( - "The SmartSim database checkpoint is incomplete or corrupted. " - ) - if not "db" in db_config: - raise SmartSimError( - err_message + "Could not find the orchestrator object." - ) - - if not "db_jobs" in db_config: - raise SmartSimError( - err_message + "Could not find database job objects." - ) - - if not "steps" in db_config: - raise SmartSimError( - err_message + "Could not find database job objects." - ) - orc: Orchestrator = db_config["db"] - - # TODO check that each db_object is running - - job_steps = zip(db_config["db_jobs"].values(), db_config["steps"]) - try: - for db_job, step in job_steps: - self._jobs.db_jobs[db_job.ename] = db_job - self._launcher.add_step_to_mapping_table(db_job.name, step) - if step.task_id: - self._launcher.task_manager.add_existing(int(step.task_id)) - except LauncherError as e: - raise SmartSimError("Failed to reconnect orchestrator") from e - - # start job manager if not already started - if not self._jobs.actively_monitoring: - self._jobs.start() - - return orc - - def _set_dbobjects(self, manifest: Manifest) -> None: - if not manifest.has_db_objects: - return - - address_dict = self._jobs.get_db_host_addresses() - for ( - db_id, - db_addresses, - ) in address_dict.items(): - db_name, name = unpack_db_identifier(db_id, "_") - - hosts = list({address.split(":")[0] for address in db_addresses}) - ports = list({int(address.split(":")[-1]) for address in db_addresses}) - - if not db_is_active(hosts=hosts, ports=ports, num_shards=len(db_addresses)): - raise SSInternalError("Cannot set DB Objects, DB is not running") - - os.environ[f"SSDB{db_name}"] = db_addresses[0] - - os.environ[f"SR_DB_TYPE{db_name}"] = ( - CLUSTERED if len(db_addresses) > 1 else STANDALONE - ) - - options = 
ConfigOptions.create_from_environment(name) - client = Client(options, logger_name="SmartSim") - - for model in manifest.models: - if not model.colocated: - for db_model in model.db_models: - set_ml_model(db_model, client) - for db_script in model.db_scripts: - set_script(db_script, client) - - for ensemble in manifest.ensembles: - for db_model in ensemble.db_models: - set_ml_model(db_model, client) - for db_script in ensemble.db_scripts: - set_script(db_script, client) - for entity in ensemble.models: - if not entity.colocated: - # Set models which could belong only - # to the entities and not to the ensemble - # but avoid duplicates - for db_model in entity.db_models: - if db_model not in ensemble.db_models: - set_ml_model(db_model, client) - for db_script in entity.db_scripts: - if db_script not in ensemble.db_scripts: - set_script(db_script, client) - - def _start_telemetry_monitor(self, exp_dir: str) -> None: - """Spawns a telemetry monitor process to keep track of the life times - of the processes launched through this controller. 
- - :param exp_dir: An experiment directory - """ - if ( - self._telemetry_monitor is None - or self._telemetry_monitor.returncode is not None - ): - logger.debug("Starting telemetry monitor process") - cmd = [ - sys.executable, - "-m", - "smartsim._core.entrypoints.telemetrymonitor", - "-exp_dir", - exp_dir, - "-frequency", - str(CONFIG.telemetry_frequency), - "-cooldown", - str(CONFIG.telemetry_cooldown), - ] - # pylint: disable-next=consider-using-with - self._telemetry_monitor = subprocess.Popen( - cmd, - stderr=sys.stderr, - stdout=sys.stdout, - cwd=str(pathlib.Path(__file__).parent.parent.parent), - shell=False, - ) diff --git a/smartsim/_core/control/controller_utils.py b/smartsim/_core/control/controller_utils.py deleted file mode 100644 index 37ae9aebfb..0000000000 --- a/smartsim/_core/control/controller_utils.py +++ /dev/null @@ -1,77 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import annotations - -import pathlib -import typing as t - -from ..._core.launcher.step import Step -from ...entity import EntityList, Model -from ...error import SmartSimError -from ..launcher.launcher import Launcher - -if t.TYPE_CHECKING: - from ..utils.serialize import TStepLaunchMetaData - - -class _AnonymousBatchJob(EntityList[Model]): - @staticmethod - def _validate(model: Model) -> None: - if model.batch_settings is None: - msg = "Unable to create _AnonymousBatchJob without batch_settings" - raise SmartSimError(msg) - - def __init__(self, model: Model) -> None: - self._validate(model) - super().__init__(model.name, model.path) - self.entities = [model] - self.batch_settings = model.batch_settings - - def _initialize_entities(self, **kwargs: t.Any) -> None: ... 
- - -def _look_up_launched_data( - launcher: Launcher, -) -> t.Callable[[t.Tuple[str, Step]], "TStepLaunchMetaData"]: - def _unpack_launched_data(data: t.Tuple[str, Step]) -> "TStepLaunchMetaData": - # NOTE: we cannot assume that the name of the launched step - # ``launched_step_name`` is equal to the name of the step referring to - # the entity ``step.name`` as is the case when an entity list is - # launched as a batch job - launched_step_name, step = data - launched_step_map = launcher.step_mapping[launched_step_name] - out_file, err_file = step.get_output_files() - return ( - launched_step_map.step_id, - launched_step_map.task_id, - launched_step_map.managed, - out_file, - err_file, - pathlib.Path(step.meta.get("status_dir", step.cwd)), - ) - - return _unpack_launched_data diff --git a/smartsim/_core/control/interval.py b/smartsim/_core/control/interval.py new file mode 100644 index 0000000000..e35b1c694c --- /dev/null +++ b/smartsim/_core/control/interval.py @@ -0,0 +1,112 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import time +import typing as t + +Seconds = t.NewType("Seconds", float) + + +class SynchronousTimeInterval: + """A utility class to represent and synchronously block the execution of a + thread for an interval of time. + """ + + def __init__(self, delta: float | None) -> None: + """Initialize a new `SynchronousTimeInterval` interval + + :param delta: The difference in time the interval represents in + seconds. If `None`, the interval will represent an infinite amount + of time. 
+ :raises ValueError: The `delta` is negative + """ + if delta is not None and delta < 0: + raise ValueError("Timeout value cannot be less than 0") + if delta is None: + delta = float("inf") + self._delta = Seconds(delta) + """The amount of time, in seconds, the interval spans.""" + self._start = time.perf_counter() + """The time of the creation of the interval""" + + @property + def delta(self) -> Seconds: + """The difference in time the interval represents + + :returns: The difference in time the interval represents + """ + return self._delta + + @property + def elapsed(self) -> Seconds: + """The amount of time that has passed since the interval was created + + :returns: The amount of time that has passed since the interval was + created + """ + return Seconds(time.perf_counter() - self._start) + + @property + def remaining(self) -> Seconds: + """The amount of time remaining in the interval + + :returns: The amount of time remaining in the interval + """ + return Seconds(max(self.delta - self.elapsed, 0)) + + @property + def expired(self) -> bool: + """The amount of time remaining in interval + + :returns: The amount of time left in the interval + """ + return self.remaining <= 0 + + @property + def infinite(self) -> bool: + """Return true if the timeout interval is infinitely long + + :returns: `True` if the delta is infinite, `False` otherwise + """ + return self.remaining == float("inf") + + def new_interval(self) -> SynchronousTimeInterval: + """Make a new timeout with the same interval + + :returns: The new time interval + """ + return type(self)(self.delta) + + def block(self) -> None: + """Block the thread until the timeout completes + + :raises RuntimeError: The thread would be blocked forever + """ + if self.remaining == float("inf"): + raise RuntimeError("Cannot block thread forever") + time.sleep(self.remaining) diff --git a/smartsim/_core/control/job.py b/smartsim/_core/control/job.py index 6941d7607a..91609349ad 100644 --- 
a/smartsim/_core/control/job.py +++ b/smartsim/_core/control/job.py @@ -29,8 +29,8 @@ import typing as t from dataclasses import dataclass -from ...entity import EntitySequence, SmartSimEntity -from ...status import SmartSimStatus +from ...entity import SmartSimEntity +from ...status import JobStatus @dataclass(frozen=True) @@ -47,8 +47,7 @@ class _JobKey: class JobEntity: """An entity containing run-time SmartSimEntity metadata. The run-time metadata - is required to perform telemetry collection. The `JobEntity` satisfies the core - API necessary to use a `JobManager` to manage retrieval of managed step updates. + is required to perform telemetry collection. """ def __init__(self) -> None: @@ -76,9 +75,9 @@ def __init__(self) -> None: """Flag indicating if the entity has completed execution""" @property - def is_db(self) -> bool: - """Returns `True` if the entity represents a database or database shard""" - return self.type in ["orchestrator", "dbnode"] + def is_fs(self) -> bool: + """Returns `True` if the entity represents a feature store or feature store shard""" + return self.type in ["featurestore", "fsnode"] @property def is_managed(self) -> bool: @@ -112,13 +111,13 @@ def check_completion_status(self) -> None: self._is_complete = True @staticmethod - def _map_db_metadata(entity_dict: t.Dict[str, t.Any], entity: "JobEntity") -> None: - """Map DB-specific properties from a runtime manifest onto a `JobEntity` + def _map_fs_metadata(entity_dict: t.Dict[str, t.Any], entity: "JobEntity") -> None: + """Map FS-specific properties from a runtime manifest onto a `JobEntity` :param entity_dict: The raw dictionary deserialized from manifest JSON :param entity: The entity instance to modify """ - if entity.is_db: + if entity.is_fs: # add collectors if they're configured to be enabled in the manifest entity.collectors = { "client": entity_dict.get("client_file", ""), @@ -184,47 +183,42 @@ def from_manifest( cls._map_standard_metadata( entity_type, entity_dict, entity, 
exp_dir, raw_experiment ) - cls._map_db_metadata(entity_dict, entity) + cls._map_fs_metadata(entity_dict, entity) return entity class Job: - """Keep track of various information for the controller. - In doing so, continuously add various fields of information - that is queryable by the user through interface methods in - the controller class. + """Keep track of various information. + In doing so, continuously add various fields of information. """ def __init__( self, job_name: str, job_id: t.Optional[str], - entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity], JobEntity], + entity: t.Union[SmartSimEntity, JobEntity], launcher: str, - is_task: bool, ) -> None: """Initialize a Job. :param job_name: Name of the job step :param job_id: The id associated with the job - :param entity: The SmartSim entity(list) associated with the job + :param entity: The SmartSim entity associated with the job :param launcher: Launcher job was started with - :param is_task: process monitored by TaskManager (True) or the WLM (True) """ self.name = job_name self.jid = job_id self.entity = entity - self.status = SmartSimStatus.STATUS_NEW + self.status = JobStatus.NEW # status before smartsim status mapping is applied self.raw_status: t.Optional[str] = None self.returncode: t.Optional[int] = None # output is only populated if it's system related (e.g. 
cmd failed immediately) self.output: t.Optional[str] = None self.error: t.Optional[str] = None # same as output - self.hosts: t.List[str] = [] # currently only used for DB jobs + self.hosts: t.List[str] = [] # currently only used for FS jobs self.launched_with = launcher - self.is_task = is_task self.start_time = time.time() self.history = History() @@ -235,7 +229,7 @@ def ename(self) -> str: def set_status( self, - new_status: SmartSimStatus, + new_status: JobStatus, raw_status: str, returncode: t.Optional[int], error: t.Optional[str] = None, @@ -263,23 +257,19 @@ def record_history(self) -> None: """Record the launching history of a job.""" self.history.record(self.jid, self.status, self.returncode, self.elapsed) - def reset( - self, new_job_name: str, new_job_id: t.Optional[str], is_task: bool - ) -> None: + def reset(self, new_job_name: str, new_job_id: t.Optional[str]) -> None: """Reset the job in order to be able to restart it. :param new_job_name: name of the new job step :param new_job_id: new job id to launch under - :param is_task: process monitored by TaskManager (True) or the WLM (True) """ self.name = new_job_name self.jid = new_job_id - self.status = SmartSimStatus.STATUS_NEW + self.status = JobStatus.NEW self.returncode = None self.output = None self.error = None self.hosts = [] - self.is_task = is_task self.start_time = time.time() self.history.new_run() @@ -299,7 +289,7 @@ def error_report(self) -> str: warning += f"Job status at failure: {self.status} \n" warning += f"Launcher status at failure: {self.raw_status} \n" warning += f"Job returncode: {self.returncode} \n" - warning += f"Error and output file located at: {self.entity.path}" + # warning += f"Error and output file located at: {self.entity.path}" return warning def __str__(self) -> str: @@ -327,14 +317,14 @@ def __init__(self, runs: int = 0) -> None: """ self.runs = runs self.jids: t.Dict[int, t.Optional[str]] = {} - self.statuses: t.Dict[int, SmartSimStatus] = {} + self.statuses: 
t.Dict[int, JobStatus] = {} self.returns: t.Dict[int, t.Optional[int]] = {} self.job_times: t.Dict[int, float] = {} def record( self, job_id: t.Optional[str], - status: SmartSimStatus, + status: JobStatus, returncode: t.Optional[int], job_time: float, ) -> None: diff --git a/smartsim/_core/control/jobmanager.py b/smartsim/_core/control/jobmanager.py deleted file mode 100644 index 1bc24cf9af..0000000000 --- a/smartsim/_core/control/jobmanager.py +++ /dev/null @@ -1,364 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- - -import itertools -import time -import typing as t -from collections import ChainMap -from threading import RLock, Thread -from types import FrameType - -from ...database import Orchestrator -from ...entity import DBNode, EntitySequence, SmartSimEntity -from ...log import ContextThread, get_logger -from ...status import TERMINAL_STATUSES, SmartSimStatus -from ..config import CONFIG -from ..launcher import Launcher, LocalLauncher -from ..utils.network import get_ip_from_host -from .job import Job, JobEntity - -logger = get_logger(__name__) - - -class JobManager: - """The JobManager maintains a mapping between user defined entities - and the steps launched through the launcher. The JobManager - holds jobs according to entity type. - - The JobManager is threaded and runs during the course of an experiment - to update the statuses of Jobs. - - The JobManager and Controller share a single instance of a launcher - object that allows both the Controller and launcher access to the - wlm to query information about jobs that the user requests. - """ - - def __init__(self, lock: RLock, launcher: t.Optional[Launcher] = None) -> None: - """Initialize a Jobmanager - - :param launcher: a Launcher object to manage jobs - """ - self.monitor: t.Optional[Thread] = None - - # active jobs - self.jobs: t.Dict[str, Job] = {} - self.db_jobs: t.Dict[str, Job] = {} - - # completed jobs - self.completed: t.Dict[str, Job] = {} - - self.actively_monitoring = False # on/off flag - self._launcher = launcher # reference to launcher - self._lock = lock # thread lock - - self.kill_on_interrupt = True # flag for killing jobs on SIGINT - - def start(self) -> None: - """Start a thread for the job manager""" - self.monitor = ContextThread(name="JobManager", daemon=True, target=self.run) - self.monitor.start() - - def run(self) -> None: - """Start the JobManager thread to continually check - the status of all jobs. 
Whichever launcher is selected - by the user will be responsible for returning statuses - that progress the state of the job. - - The interval of the checks is controlled by - smartsim.constats.TM_INTERVAL and should be set to values - above 20 for congested, multi-user systems - - The job manager thread will exit when no jobs are left - or when the main thread dies - """ - logger.debug("Starting Job Manager") - self.actively_monitoring = True - while self.actively_monitoring: - self._thread_sleep() - self.check_jobs() # update all job statuses at once - for _, job in self().items(): - # if the job has errors then output the report - # this should only output once - if job.returncode is not None and job.status in TERMINAL_STATUSES: - if int(job.returncode) != 0: - logger.warning(job) - logger.warning(job.error_report()) - self.move_to_completed(job) - else: - # job completed without error - logger.info(job) - self.move_to_completed(job) - - # if no more jobs left to actively monitor - if not self(): - self.actively_monitoring = False - logger.debug("Sleeping, no jobs to monitor") - - def move_to_completed(self, job: Job) -> None: - """Move job to completed queue so that its no longer - actively monitored by the job manager - - :param job: job instance we are transitioning - """ - with self._lock: - self.completed[job.ename] = job - job.record_history() - - # remove from actively monitored jobs - if job.ename in self.db_jobs: - del self.db_jobs[job.ename] - elif job.ename in self.jobs: - del self.jobs[job.ename] - - def __getitem__(self, entity_name: str) -> Job: - """Return the job associated with the name of the entity - from which it was created. 
- - :param entity_name: The name of the entity of a job - :returns: the Job associated with the entity_name - """ - with self._lock: - entities = ChainMap(self.db_jobs, self.jobs, self.completed) - return entities[entity_name] - - def __call__(self) -> t.Dict[str, Job]: - """Returns dictionary all jobs for () operator - - :returns: Dictionary of all jobs - """ - all_jobs = {**self.jobs, **self.db_jobs} - return all_jobs - - def __contains__(self, key: str) -> bool: - try: - self[key] # pylint: disable=pointless-statement - return True - except KeyError: - return False - - def add_job( - self, - job_name: str, - job_id: t.Optional[str], - entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity], JobEntity], - is_task: bool = True, - ) -> None: - """Add a job to the job manager which holds specific jobs by type. - - :param job_name: name of the job step - :param job_id: job step id created by launcher - :param entity: entity that was launched on job step - :param is_task: process monitored by TaskManager (True) or the WLM (True) - """ - launcher = str(self._launcher) - # all operations here should be atomic - job = Job(job_name, job_id, entity, launcher, is_task) - if isinstance(entity, (DBNode, Orchestrator)): - self.db_jobs[entity.name] = job - elif isinstance(entity, JobEntity) and entity.is_db: - self.db_jobs[entity.name] = job - else: - self.jobs[entity.name] = job - - def is_finished(self, entity: SmartSimEntity) -> bool: - """Detect if a job has completed - - :param entity: entity to check - :return: True if finished - """ - with self._lock: - job = self[entity.name] # locked operation - if entity.name in self.completed: - if job.status in TERMINAL_STATUSES: - return True - return False - - def check_jobs(self) -> None: - """Update all jobs in jobmanager - - Update all jobs returncode, status, error and output - through one call to the launcher. 
- - """ - with self._lock: - jobs = self().values() - job_name_map = {job.name: job.ename for job in jobs} - - # returns (job step name, StepInfo) tuples - if self._launcher: - step_names = list(job_name_map.keys()) - statuses = self._launcher.get_step_update(step_names) - for job_name, status in statuses: - job = self[job_name_map[job_name]] - - if status: - # uses abstract step interface - job.set_status( - status.status, - status.launcher_status, - status.returncode, - error=status.error, - output=status.output, - ) - - def get_status( - self, - entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]], - ) -> SmartSimStatus: - """Return the status of a job. - - :param entity: SmartSimEntity or EntitySequence instance - :returns: a SmartSimStatus status - """ - with self._lock: - if entity.name in self.completed: - return self.completed[entity.name].status - - if entity.name in self: - job: Job = self[entity.name] # locked - return job.status - - return SmartSimStatus.STATUS_NEVER_STARTED - - def set_launcher(self, launcher: Launcher) -> None: - """Set the launcher of the job manager to a specific launcher instance - - :param launcher: child of Launcher - """ - self._launcher = launcher - - def query_restart(self, entity_name: str) -> bool: - """See if the job just started should be restarted or not. - - :param entity_name: name of entity to check for a job for - :return: if job should be restarted instead of started - """ - if entity_name in self.completed: - return True - return False - - def restart_job( - self, - job_name: str, - job_id: t.Optional[str], - entity_name: str, - is_task: bool = True, - ) -> None: - """Function to reset a job to record history and be - ready to launch again. 
- - :param job_name: new job step name - :param job_id: new job id - :param entity_name: name of the entity of the job - :param is_task: process monitored by TaskManager (True) or the WLM (True) - - """ - with self._lock: - job = self.completed[entity_name] - del self.completed[entity_name] - job.reset(job_name, job_id, is_task) - - if isinstance(job.entity, (DBNode, Orchestrator)): - self.db_jobs[entity_name] = job - else: - self.jobs[entity_name] = job - - def get_db_host_addresses(self) -> t.Dict[str, t.List[str]]: - """Retrieve the list of hosts for the database - for corresponding database identifiers - - :return: dictionary of host ip addresses - """ - - address_dict: t.Dict[str, t.List[str]] = {} - for db_job in self.db_jobs.values(): - addresses = [] - if isinstance(db_job.entity, (DBNode, Orchestrator)): - db_entity = db_job.entity - for combine in itertools.product(db_job.hosts, db_entity.ports): - ip_addr = get_ip_from_host(combine[0]) - addresses.append(":".join((ip_addr, str(combine[1])))) - - dict_entry: t.List[str] = address_dict.get(db_entity.db_identifier, []) - dict_entry.extend(addresses) - address_dict[db_entity.db_identifier] = dict_entry - - return address_dict - - def set_db_hosts(self, orchestrator: Orchestrator) -> None: - """Set the DB hosts in db_jobs so future entities can query this - - :param orchestrator: orchestrator instance - """ - # should only be called during launch in the controller - - with self._lock: - if orchestrator.batch: - self.db_jobs[orchestrator.name].hosts = orchestrator.hosts - - else: - for dbnode in orchestrator.entities: - if not dbnode.is_mpmd: - self.db_jobs[dbnode.name].hosts = [dbnode.host] - else: - self.db_jobs[dbnode.name].hosts = dbnode.hosts - - def signal_interrupt(self, signo: int, _frame: t.Optional[FrameType]) -> None: - """Custom handler for whenever SIGINT is received""" - if not signo: - logger.warning("Received SIGINT with no signal number") - if self.actively_monitoring and len(self) > 0: - if 
self.kill_on_interrupt: - for _, job in self().items(): - if job.status not in TERMINAL_STATUSES and self._launcher: - self._launcher.stop(job.name) - else: - logger.warning("SmartSim process interrupted before resource cleanup") - logger.warning("You may need to manually stop the following:") - - for job_name, job in self().items(): - if job.is_task: - # this will be the process id - logger.warning(f"Task {job_name} with id: {job.jid}") - else: - logger.warning( - f"Job {job_name} with {job.launched_with} id: {job.jid}" - ) - - def _thread_sleep(self) -> None: - """Sleep the job manager for a specific constant - set for the launcher type. - """ - local_jm_interval = 2 - if isinstance(self._launcher, (LocalLauncher)): - time.sleep(local_jm_interval) - else: - time.sleep(CONFIG.jm_interval) - - def __len__(self) -> int: - # number of active jobs - return len(self.db_jobs) + len(self.jobs) diff --git a/smartsim/_core/control/launch_history.py b/smartsim/_core/control/launch_history.py new file mode 100644 index 0000000000..e7f04a4ffa --- /dev/null +++ b/smartsim/_core/control/launch_history.py @@ -0,0 +1,96 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import dataclasses +import typing as t + +from smartsim._core.utils import helpers as _helpers + +if t.TYPE_CHECKING: + from smartsim._core.utils.launcher import LauncherProtocol + from smartsim.types import LaunchedJobID + + +@dataclasses.dataclass(frozen=True) +class LaunchHistory: + """A cache to manage and quickly look up which launched job ids were + issued by which launcher + """ + + _id_to_issuer: dict[LaunchedJobID, LauncherProtocol[t.Any]] = dataclasses.field( + default_factory=dict + ) + + def save_launch( + self, launcher: LauncherProtocol[t.Any], id_: LaunchedJobID + ) -> None: + """Save a launcher and a launch job id that it issued for later + reference. 
+ + :param launcher: A launcher that started a job and issued an id for + that job + :param id_: The id of the launched job started by the launcher + :raises ValueError: An id of equal value has already been saved + """ + if id_ in self._id_to_issuer: + raise ValueError("An ID of that value has already been saved") + self._id_to_issuer[id_] = launcher + + def iter_past_launchers(self) -> t.Iterable[LauncherProtocol[t.Any]]: + """Iterate over the unique launcher instances stored in history + + :returns: An iterator over unique launcher instances + """ + return _helpers.unique(self._id_to_issuer.values()) + + def group_by_launcher( + self, ids: t.Collection[LaunchedJobID] | None = None, unknown_ok: bool = False + ) -> dict[LauncherProtocol[t.Any], set[LaunchedJobID]]: + """Return a mapping of launchers to launched job ids issued by that + launcher. + + :param ids: The subset launch ids to group by common launchers. + :param unknown_ok: If set to `True` and the history is unable to + determine which launcher instance issued a requested launched job + id, the history will silently omit the id from the returned + mapping. If set to `False` a `ValueError` will be raised instead. + Set to `False` by default. + :raises ValueError: An unknown launch id was requested to be grouped by + launcher, and `unknown_ok` is set to `False`. + :returns: A mapping of launchers to collections of launched job ids + that were issued by that launcher. 
+ """ + if ids is None: + ids = self._id_to_issuer + launchers_to_launched = _helpers.group_by(self._id_to_issuer.get, ids) + unknown = launchers_to_launched.get(None, []) + if unknown and not unknown_ok: + formatted_unknown = ", ".join(unknown) + msg = f"IDs {formatted_unknown} could not be mapped back to a launcher" + raise ValueError(msg) + return {k: set(v) for k, v in launchers_to_launched.items() if k is not None} diff --git a/smartsim/_core/control/manifest.py b/smartsim/_core/control/manifest.py index fd5770f187..20d302f624 100644 --- a/smartsim/_core/control/manifest.py +++ b/smartsim/_core/control/manifest.py @@ -29,8 +29,9 @@ import typing as t from dataclasses import dataclass, field -from ...database import Orchestrator -from ...entity import DBNode, Ensemble, EntitySequence, Model, SmartSimEntity +from ...builders import Ensemble +from ...database import FeatureStore +from ...entity import Application, FSNode, SmartSimEntity from ...error import SmartSimError from ..config import CONFIG from ..utils import helpers as _helpers @@ -38,7 +39,7 @@ _T = t.TypeVar("_T") _U = t.TypeVar("_U") -_AtomicLaunchableT = t.TypeVar("_AtomicLaunchableT", Model, DBNode) +_AtomicLaunchableT = t.TypeVar("_AtomicLaunchableT", Application, FSNode) if t.TYPE_CHECKING: import os @@ -47,41 +48,38 @@ class Manifest: """This class is used to keep track of all deployables generated by an experiment. Different types of deployables (i.e. different - `SmartSimEntity`-derived objects or `EntitySequence`-derived objects) can + `SmartSimEntity`-derived objects) can be accessed by using the corresponding accessor. 
- Instances of ``Model``, ``Ensemble`` and ``Orchestrator`` + Instances of ``Application``, ``Ensemble`` and ``FeatureStore`` can all be passed as arguments """ - def __init__( - self, *args: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> None: + def __init__(self, *args: t.Union[SmartSimEntity]) -> None: self._deployables = list(args) self._check_types(self._deployables) self._check_names(self._deployables) - self._check_entity_lists_nonempty() @property - def dbs(self) -> t.List[Orchestrator]: - """Return a list of Orchestrator instances in Manifest + def fss(self) -> t.List[FeatureStore]: + """Return a list of FeatureStore instances in Manifest - :raises SmartSimError: if user added to databases to manifest - :return: List of orchestrator instances + :raises SmartSimError: if user added to feature stores to manifest + :return: List of feature store instances """ - dbs = [item for item in self._deployables if isinstance(item, Orchestrator)] - return dbs + fss = [item for item in self._deployables if isinstance(item, FeatureStore)] + return fss @property - def models(self) -> t.List[Model]: - """Return Model instances in Manifest + def applications(self) -> t.List[Application]: + """Return Application instances in Manifest - :return: model instances + :return: application instances """ - _models: t.List[Model] = [ - item for item in self._deployables if isinstance(item, Model) + _applications: t.List[Application] = [ + item for item in self._deployables if isinstance(item, Application) ] - return _models + return _applications @property def ensembles(self) -> t.List[Ensemble]: @@ -91,20 +89,6 @@ def ensembles(self) -> t.List[Ensemble]: """ return [e for e in self._deployables if isinstance(e, Ensemble)] - @property - def all_entity_lists(self) -> t.List[EntitySequence[SmartSimEntity]]: - """All entity lists, including ensembles and - exceptional ones like Orchestrator - - :return: list of entity lists - """ - _all_entity_lists: 
t.List[EntitySequence[SmartSimEntity]] = list(self.ensembles) - - for db in self.dbs: - _all_entity_lists.append(db) - - return _all_entity_lists - @property def has_deployable(self) -> bool: """ @@ -127,24 +111,16 @@ def _check_names(deployables: t.List[t.Any]) -> None: @staticmethod def _check_types(deployables: t.List[t.Any]) -> None: for deployable in deployables: - if not isinstance(deployable, (SmartSimEntity, EntitySequence)): + if not isinstance(deployable, SmartSimEntity): raise TypeError( - f"Entity has type {type(deployable)}, not " - + "SmartSimEntity or EntitySequence" + f"Entity has type {type(deployable)}, not " + "SmartSimEntity" ) - def _check_entity_lists_nonempty(self) -> None: - """Check deployables for sanity before launching""" - - for entity_list in self.all_entity_lists: - if len(entity_list) < 1: - raise ValueError(f"{entity_list.name} is empty. Nothing to launch.") - def __str__(self) -> str: output = "" e_header = "=== Ensembles ===\n" - m_header = "=== Models ===\n" - db_header = "=== Database ===\n" + a_header = "=== Applications ===\n" + fs_header = "=== Feature Stores ===\n" if self.ensembles: output += e_header @@ -157,38 +133,38 @@ def __str__(self) -> str: output += f"{str(ensemble.batch_settings)}\n" output += "\n" - if self.models: - output += m_header - for model in self.models: - output += f"{model.name}\n" - if model.batch_settings: - output += f"{model.batch_settings}\n" - output += f"{model.run_settings}\n" - if model.params: - output += f"Parameters: \n{_helpers.fmt_dict(model.params)}\n" + if self.applications: + output += a_header + for application in self.applications: + output += f"{application.name}\n" + if application.batch_settings: + output += f"{application.batch_settings}\n" + output += f"{application.run_settings}\n" + if application.params: + output += f"Parameters: \n{_helpers.fmt_dict(application.params)}\n" output += "\n" - for adb in self.dbs: - output += db_header - output += f"Shards: {adb.num_shards}\n" - 
output += f"Port: {str(adb.ports[0])}\n" - output += f"Network: {adb._interfaces}\n" - output += f"Batch Launch: {adb.batch}\n" - if adb.batch: - output += f"{str(adb.batch_settings)}\n" + for afs in self.fss: + output += fs_header + output += f"Shards: {afs.num_shards}\n" + output += f"Port: {str(afs.ports[0])}\n" + output += f"Network: {afs._interfaces}\n" + output += f"Batch Launch: {afs.batch}\n" + if afs.batch: + output += f"{str(afs.batch_settings)}\n" output += "\n" return output @property - def has_db_objects(self) -> bool: - """Check if any entity has DBObjects to set""" - ents: t.Iterable[t.Union[Model, Ensemble]] = itertools.chain( - self.models, + def has_fs_objects(self) -> bool: + """Check if any entity has FSObjects to set""" + ents: t.Iterable[t.Union[Application, Ensemble]] = itertools.chain( + self.applications, self.ensembles, (member for ens in self.ensembles for member in ens.entities), ) - return any(any(ent.db_models) or any(ent.db_scripts) for ent in ents) + return any(any(ent.fs_models) or any(ent.fs_scripts) for ent in ents) class _LaunchedManifestMetadata(t.NamedTuple): @@ -215,14 +191,15 @@ class LaunchedManifest(t.Generic[_T]): """Immutable manifest mapping launched entities or collections of launched entities to other pieces of external data. This is commonly used to map a launch-able entity to its constructed ``Step`` instance without assuming - that ``step.name == job.name`` or querying the ``JobManager`` which itself - can be ephemeral. + that ``step.name == job.name``. """ metadata: _LaunchedManifestMetadata - models: t.Tuple[t.Tuple[Model, _T], ...] - ensembles: t.Tuple[t.Tuple[Ensemble, t.Tuple[t.Tuple[Model, _T], ...]], ...] - databases: t.Tuple[t.Tuple[Orchestrator, t.Tuple[t.Tuple[DBNode, _T], ...]], ...] + applications: t.Tuple[t.Tuple[Application, _T], ...] + ensembles: t.Tuple[t.Tuple[Ensemble, t.Tuple[t.Tuple[Application, _T], ...]], ...] 
+ featurestores: t.Tuple[ + t.Tuple[FeatureStore, t.Tuple[t.Tuple[FSNode, _T], ...]], ... + ] def map(self, func: t.Callable[[_T], _U]) -> "LaunchedManifest[_U]": def _map_entity_data( @@ -233,14 +210,14 @@ def _map_entity_data( return LaunchedManifest( metadata=self.metadata, - models=_map_entity_data(func, self.models), + applications=_map_entity_data(func, self.applications), ensembles=tuple( - (ens, _map_entity_data(func, model_data)) - for ens, model_data in self.ensembles + (ens, _map_entity_data(func, application_data)) + for ens, application_data in self.ensembles ), - databases=tuple( - (db_, _map_entity_data(func, node_data)) - for db_, node_data in self.databases + featurestores=tuple( + (fs_, _map_entity_data(func, node_data)) + for fs_, node_data in self.featurestores ), ) @@ -257,11 +234,13 @@ class LaunchedManifestBuilder(t.Generic[_T]): launcher_name: str run_id: str = field(default_factory=_helpers.create_short_id_str) - _models: t.List[t.Tuple[Model, _T]] = field(default_factory=list, init=False) - _ensembles: t.List[t.Tuple[Ensemble, t.Tuple[t.Tuple[Model, _T], ...]]] = field( + _applications: t.List[t.Tuple[Application, _T]] = field( default_factory=list, init=False ) - _databases: t.List[t.Tuple[Orchestrator, t.Tuple[t.Tuple[DBNode, _T], ...]]] = ( + _ensembles: t.List[t.Tuple[Ensemble, t.Tuple[t.Tuple[Application, _T], ...]]] = ( + field(default_factory=list, init=False) + ) + _featurestores: t.List[t.Tuple[FeatureStore, t.Tuple[t.Tuple[FSNode, _T], ...]]] = ( field(default_factory=list, init=False) ) @@ -273,14 +252,14 @@ def exp_telemetry_subdirectory(self) -> pathlib.Path: def run_telemetry_subdirectory(self) -> pathlib.Path: return _format_run_telemetry_path(self.exp_path, self.exp_name, self.run_id) - def add_model(self, model: Model, data: _T) -> None: - self._models.append((model, data)) + def add_application(self, application: Application, data: _T) -> None: + self._applications.append((application, data)) def add_ensemble(self, ens: 
Ensemble, data: t.Sequence[_T]) -> None: self._ensembles.append((ens, self._entities_to_data(ens.entities, data))) - def add_database(self, db_: Orchestrator, data: t.Sequence[_T]) -> None: - self._databases.append((db_, self._entities_to_data(db_.entities, data))) + def add_feature_store(self, fs_: FeatureStore, data: t.Sequence[_T]) -> None: + self._featurestores.append((fs_, self._entities_to_data(fs_.entities, data))) @staticmethod def _entities_to_data( @@ -303,9 +282,9 @@ def finalize(self) -> LaunchedManifest[_T]: self.exp_path, self.launcher_name, ), - models=tuple(self._models), + applications=tuple(self._applications), ensembles=tuple(self._ensembles), - databases=tuple(self._databases), + featurestores=tuple(self._featurestores), ) diff --git a/smartsim/_core/control/previewrenderer.py b/smartsim/_core/control/preview_renderer.py similarity index 92% rename from smartsim/_core/control/previewrenderer.py rename to smartsim/_core/control/preview_renderer.py index 857a703973..17d9ceac15 100644 --- a/smartsim/_core/control/previewrenderer.py +++ b/smartsim/_core/control/preview_renderer.py @@ -33,10 +33,10 @@ import jinja2.utils as u from jinja2 import pass_eval_context -from ..._core.config import CONFIG -from ..._core.control import Manifest from ...error.errors import PreviewFormatError from ...log import get_logger +from ..config import CONFIG +from . 
import Manifest from .job import Job logger = get_logger(__name__) @@ -65,7 +65,7 @@ def as_toggle(_eval_ctx: u.F, value: bool) -> str: @pass_eval_context def get_ifname(_eval_ctx: u.F, value: t.List[str]) -> str: - """Extract Network Interface from orchestrator run settings.""" + """Extract Network Interface from feature store run settings.""" if value: for val in value: if "ifname=" in val: @@ -75,12 +75,12 @@ def get_ifname(_eval_ctx: u.F, value: t.List[str]) -> str: @pass_eval_context -def get_dbtype(_eval_ctx: u.F, value: str) -> str: - """Extract data base type.""" +def get_fstype(_eval_ctx: u.F, value: str) -> str: + """Extract feature store type.""" if value: if "-cli" in value: - db_type, _ = value.split("/")[-1].split("-", 1) - return db_type + fs_type, _ = value.split("/")[-1].split("-", 1) + return fs_type return "" @@ -112,7 +112,7 @@ def render( verbosity_level: Verbosity = Verbosity.INFO, output_format: Format = Format.PLAINTEXT, output_filename: t.Optional[str] = None, - active_dbjobs: t.Optional[t.Dict[str, Job]] = None, + active_fsjobs: t.Optional[t.Dict[str, Job]] = None, ) -> str: """ Render the template from the supplied entities. @@ -133,7 +133,7 @@ def render( env.filters["as_toggle"] = as_toggle env.filters["get_ifname"] = get_ifname - env.filters["get_dbtype"] = get_dbtype + env.filters["get_fstype"] = get_fstype env.filters["is_list"] = is_list env.globals["Verbosity"] = Verbosity @@ -150,7 +150,7 @@ def render( rendered_preview = tpl.render( exp_entity=exp, - active_dbjobs=active_dbjobs, + active_dbjobs=active_fsjobs, manifest=manifest, config=CONFIG, verbosity_level=verbosity_level, diff --git a/smartsim/_core/dispatch.py b/smartsim/_core/dispatch.py new file mode 100644 index 0000000000..be096366df --- /dev/null +++ b/smartsim/_core/dispatch.py @@ -0,0 +1,389 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +import dataclasses +import os +import pathlib +import typing as t + +from typing_extensions import Self, TypeAlias, TypeVarTuple, Unpack + +from smartsim._core.utils import helpers +from smartsim.error import errors +from smartsim.types import LaunchedJobID + +if t.TYPE_CHECKING: + from smartsim._core.arguments.shell import ShellLaunchArguments + from smartsim._core.utils.launcher import LauncherProtocol + from smartsim.experiment import Experiment + from smartsim.settings.arguments import LaunchArguments + + +_Ts = TypeVarTuple("_Ts") + + +WorkingDirectory: TypeAlias = pathlib.Path +"""A working directory represented as a string or PathLike object""" + +_DispatchableT = t.TypeVar("_DispatchableT", bound="LaunchArguments") +"""Any type of luanch arguments, typically used when the type bound by the type +argument is a key a `Dispatcher` dispatch registry +""" +_LaunchableT = t.TypeVar("_LaunchableT") +"""Any type, typically used to bind to a type accepted as the input parameter +to the to the `LauncherProtocol.start` method +""" + +EnvironMappingType: TypeAlias = t.Mapping[str, "str | None"] +"""A mapping of user provided mapping of environment variables in which to run +a job +""" +FormatterType: TypeAlias = t.Callable[ + [ + _DispatchableT, + t.Sequence[str], + WorkingDirectory, + EnvironMappingType, + pathlib.Path, + pathlib.Path, + ], + _LaunchableT, +] +"""A callable that is capable of formatting the components of a job into a type +capable of being launched by a launcher. +""" +_LaunchConfigType: TypeAlias = """_LauncherAdapter[ + t.Sequence[str], + WorkingDirectory, + EnvironMappingType, + pathlib.Path, + pathlib.Path]""" + +"""A launcher adapater that has configured a launcher to launch the components +of a job with some pre-determined launch settings +""" +_UnkownType: TypeAlias = t.NoReturn +"""A type alias for a bottom type. 
Use this to inform a user that the parameter +a parameter should never be set or a callable will never return +""" + + +@t.final +class Dispatcher: + """A class capable of deciding which launcher type should be used to launch + a given settings type. + + The `Dispatcher` class maintains a type safe API for adding and retrieving + a settings type into the underlying mapping. It does this through two main + methods: `Dispatcher.dispatch` and `Dispatcher.get_dispatch`. + + `Dispatcher.dispatch` takes in a dispatchable type, a launcher type that is + capable of launching a launchable type and formatting function that maps an + instance of the dispatchable type to an instance of the launchable type. + The dispatcher will then take these components and then enter them into its + dispatch registry. `Dispatcher.dispatch` can also be used as a decorator, + to automatically add a dispatchable type dispatch to a dispatcher at type + creation time. + + `Dispatcher.get_dispatch` takes a dispatchable type or instance as a + parameter, and will attempt to look up, in its dispatch registry, how to + dispatch that type. It will then return an object that can configure a + launcher of the expected launcher type. If the dispatchable type was never + registered a `TypeError` will be raised. + """ + + def __init__( + self, + *, + dispatch_registry: ( + t.Mapping[type[LaunchArguments], _DispatchRegistration[t.Any, t.Any]] | None + ) = None, + ) -> None: + """Initialize a new `Dispatcher` + + :param dispatch_registry: A pre-configured dispatch registry that the + dispatcher should use. This registry is not type checked and is + used blindly. This registry is shallow copied, meaning that adding + into the original registry after construction will not mutate the + state of the registry. 
+ """ + self._dispatch_registry = ( + dict(dispatch_registry) if dispatch_registry is not None else {} + ) + + def copy(self) -> Self: + """Create a shallow copy of the Dispatcher""" + return type(self)(dispatch_registry=self._dispatch_registry) + + @t.overload + def dispatch( # Signature when used as a decorator + self, + args: None = ..., + *, + with_format: FormatterType[_DispatchableT, _LaunchableT], + to_launcher: type[LauncherProtocol[_LaunchableT]], + allow_overwrite: bool = ..., + ) -> t.Callable[[type[_DispatchableT]], type[_DispatchableT]]: ... + @t.overload + def dispatch( # Signature when used as a method + self, + args: type[_DispatchableT], + *, + with_format: FormatterType[_DispatchableT, _LaunchableT], + to_launcher: type[LauncherProtocol[_LaunchableT]], + allow_overwrite: bool = ..., + ) -> None: ... + def dispatch( # Actual implementation + self, + args: type[_DispatchableT] | None = None, + *, + with_format: FormatterType[_DispatchableT, _LaunchableT], + to_launcher: type[LauncherProtocol[_LaunchableT]], + allow_overwrite: bool = False, + ) -> t.Callable[[type[_DispatchableT]], type[_DispatchableT]] | None: + """A type safe way to add a mapping of settings type to launcher type + to handle a settings instance at launch time. 
+ """ + err_msg: str | None = None + if getattr(to_launcher, "_is_protocol", False): + err_msg = f"Cannot dispatch to protocol class `{to_launcher.__name__}`" + elif getattr(to_launcher, "__abstractmethods__", frozenset()): + err_msg = f"Cannot dispatch to abstract class `{to_launcher.__name__}`" + if err_msg is not None: + raise TypeError(err_msg) + + def register(args_: type[_DispatchableT], /) -> type[_DispatchableT]: + if args_ in self._dispatch_registry and not allow_overwrite: + launcher_type = self._dispatch_registry[args_].launcher_type + raise TypeError( + f"{args_.__name__} has already been registered to be " + f"launched with {launcher_type}" + ) + self._dispatch_registry[args_] = _DispatchRegistration( + with_format, to_launcher + ) + return args_ + + if args is not None: + register(args) + return None + return register + + def get_dispatch( + self, args: _DispatchableT | type[_DispatchableT] + ) -> _DispatchRegistration[_DispatchableT, _UnkownType]: + """Find a type of launcher that is registered as being able to launch a + settings instance of the provided type + """ + if not isinstance(args, type): + args = type(args) + dispatch_ = self._dispatch_registry.get(args, None) + if dispatch_ is None: + raise TypeError( + f"No dispatch for `{args.__name__}` has been registered " + f"has been registered with {type(self).__name__} `{self}`" + ) + # Note the sleight-of-hand here: we are secretly casting a type of + # `_DispatchRegistration[Any, Any]` -> + # `_DispatchRegistration[_DispatchableT, _LaunchableT]`. + # where `_LaunchableT` is unbound! + # + # This is safe to do if all entries in the mapping were added using a + # type safe method (e.g. `Dispatcher.dispatch`), but if a user were to + # supply a custom dispatch registry or otherwise modify the registry + # this is not necessarily 100% type safe!! 
+ return dispatch_ + + +@t.final +@dataclasses.dataclass(frozen=True) +class _DispatchRegistration(t.Generic[_DispatchableT, _LaunchableT]): + """An entry into the `Dispatcher`'s dispatch registry. This class is simply + a wrapper around a launcher and how to format a `_DispatchableT` instance + to be launched by the afore mentioned launcher. + """ + + formatter: FormatterType[_DispatchableT, _LaunchableT] + launcher_type: type[LauncherProtocol[_LaunchableT]] + + def _is_compatible_launcher(self, launcher: LauncherProtocol[t.Any]) -> bool: + # Disabling because we want to match the type of the dispatch + # *exactly* as specified by the user + # pylint: disable-next=unidiomatic-typecheck + return type(launcher) is self.launcher_type + + def create_new_launcher_configuration( + self, for_experiment: Experiment, with_arguments: _DispatchableT + ) -> _LaunchConfigType: + """Create a new instance of a launcher for an experiment that the + provided settings were set to dispatch, and configure it with the + provided launch settings. + + :param for_experiment: The experiment responsible creating the launcher + :param with_settings: The settings with which to configure the newly + created launcher + :returns: A configured launcher + """ + launcher = self.launcher_type.create(for_experiment) + return self.create_adapter_from_launcher(launcher, with_arguments) + + def create_adapter_from_launcher( + self, launcher: LauncherProtocol[_LaunchableT], arguments: _DispatchableT + ) -> _LaunchConfigType: + """Creates configured launcher from an existing launcher using the + provided settings. + + :param launcher: A launcher that the type of `settings` has been + configured to dispatch to. + :param settings: A settings with which to configure the launcher. + :returns: A configured launcher. 
+ """ + if not self._is_compatible_launcher(launcher): + raise TypeError( + f"Cannot create launcher adapter from launcher `{launcher}` " + f"of type `{type(launcher)}`; expected launcher of type " + f"exactly `{self.launcher_type}`" + ) + + def format_( + exe: t.Sequence[str], + path: pathlib.Path, + env: EnvironMappingType, + out: pathlib.Path, + err: pathlib.Path, + ) -> _LaunchableT: + return self.formatter(arguments, exe, path, env, out, err) + + return _LauncherAdapter(launcher, format_) + + def configure_first_compatible_launcher( + self, + with_arguments: _DispatchableT, + from_available_launchers: t.Iterable[LauncherProtocol[t.Any]], + ) -> _LaunchConfigType: + """Configure the first compatible adapter launch to launch with the + provided settings. Launchers are iterated and discarded from the + iterator until the iterator is exhausted. + + :param with_settings: The settings with which to configure the launcher + :param from_available_launchers: An iterable that yields launcher instances + :raises errors.LauncherNotFoundError: No compatible launcher was + yielded from the provided iterator. + :returns: A launcher configured with the provided settings. + """ + launcher = helpers.first(self._is_compatible_launcher, from_available_launchers) + if launcher is None: + raise errors.LauncherNotFoundError( + f"No launcher of exactly type `{self.launcher_type.__name__}` " + "could be found from provided launchers" + ) + return self.create_adapter_from_launcher(launcher, with_arguments) + + +@t.final +class _LauncherAdapter(t.Generic[Unpack[_Ts]]): + """The launcher adapter is an adapter class takes a launcher that is + capable of launching some type `LaunchableT` and a function with a generic + argument list that returns a `LaunchableT`. The launcher adapter will then + provide `start` method that will have the same argument list as the + provided function and launch the output through the provided launcher. 
+ + For example, the launcher adapter could be used like so: + + .. highlight:: python + .. code-block:: python + + class SayHelloLauncher(LauncherProtocol[str]): + ... + def start(self, title: str): + ... + print(f"Hello, {title}") + ... + ... + + @dataclasses.dataclass + class Person: + name: str + honorific: str + + def full_title(self) -> str: + return f"{honorific}. {self.name}" + + mark = Person("Jim", "Mr") + sally = Person("Sally", "Ms") + matt = Person("Matt", "Dr") + hello_person_launcher = _LauncherAdapter(SayHelloLauncher, + Person.full_title) + + hello_person_launcher.start(mark) # prints: "Hello, Mr. Mark" + hello_person_launcher.start(sally) # prints: "Hello, Ms. Sally" + hello_person_launcher.start(matt) # prints: "Hello, Dr. Matt" + """ + + def __init__( + self, + launcher: LauncherProtocol[_LaunchableT], + map_: t.Callable[[Unpack[_Ts]], _LaunchableT], + ) -> None: + """Initialize a launcher adapter + + :param launcher: The launcher instance this class should wrap + :param map_: A callable with arguments for the new `start` method that + can translate them into the expected launching type for the wrapped + launcher. + """ + # NOTE: We need to cast off the `_LaunchableT` -> `Any` in the + # `__init__` method signature to hide the transform from users of + # this class. If possible, this type should not be exposed to + # users of this class! 
+ self._adapt: t.Callable[[Unpack[_Ts]], t.Any] = map_ + self._adapted_launcher: LauncherProtocol[t.Any] = launcher + + def start(self, *args: Unpack[_Ts]) -> LaunchedJobID: + """Start a new job through the wrapped launcher using the custom + `start` signature + + :param args: The custom start arguments + :returns: The launched job id provided by the wrapped launcher + """ + payload = self._adapt(*args) + return self._adapted_launcher.start(payload) + + +DEFAULT_DISPATCHER: t.Final = Dispatcher() +"""A global `Dispatcher` instance that SmartSim automatically configures to +launch its built in launchables +""" + +# Disabling because we want this to look and feel like a top level function, +# but don't want to have a second copy of the nasty overloads +# pylint: disable-next=invalid-name +dispatch: t.Final = DEFAULT_DISPATCHER.dispatch +"""Function that can be used as a decorator to add a dispatch registration into +`DEFAULT_DISPATCHER`. +""" diff --git a/smartsim/_core/entrypoints/colocated.py b/smartsim/_core/entrypoints/colocated.py deleted file mode 100644 index 508251fe06..0000000000 --- a/smartsim/_core/entrypoints/colocated.py +++ /dev/null @@ -1,348 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024 Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. 
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import argparse -import os -import signal -import socket -import sys -import tempfile -import typing as t -from pathlib import Path -from subprocess import STDOUT -from types import FrameType - -import filelock -import psutil -from smartredis import Client, ConfigOptions -from smartredis.error import RedisConnectionError, RedisReplyError - -from smartsim._core.utils.network import current_ip -from smartsim.error import SSInternalError -from smartsim.log import get_logger - -logger = get_logger(__name__) - -DBPID = None - -# kill is not catchable -SIGNALS = [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT, signal.SIGABRT] - - -def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None: - if not signo: - logger.warning("Received signal with no signo") - cleanup() - - -def launch_db_model(client: Client, db_model: t.List[str]) -> str: - """Parse options to launch model on local cluster - - :param client: SmartRedis client connected to local DB - :param db_model: List of arguments defining the model - :return: Name of model - """ - parser = argparse.ArgumentParser("Set ML model on DB") - parser.add_argument("--name", type=str) - parser.add_argument("--file", type=str) - 
parser.add_argument("--backend", type=str) - parser.add_argument("--device", type=str) - parser.add_argument("--devices_per_node", type=int, default=1) - parser.add_argument("--first_device", type=int, default=0) - parser.add_argument("--batch_size", type=int, default=0) - parser.add_argument("--min_batch_size", type=int, default=0) - parser.add_argument("--min_batch_timeout", type=int, default=0) - parser.add_argument("--tag", type=str, default="") - parser.add_argument("--inputs", nargs="+", default=None) - parser.add_argument("--outputs", nargs="+", default=None) - args = parser.parse_args(db_model) - - inputs = None - outputs = None - - if args.inputs: - inputs = list(args.inputs) - if args.outputs: - outputs = list(args.outputs) - - name = str(args.name) - - # devices_per_node being greater than one only applies to GPU devices - if args.devices_per_node > 1 and args.device.lower() == "gpu": - client.set_model_from_file_multigpu( - name=name, - model_file=args.file, - backend=args.backend, - first_gpu=args.first_device, - num_gpus=args.devices_per_node, - batch_size=args.batch_size, - min_batch_size=args.min_batch_size, - min_batch_timeout=args.min_batch_timeout, - tag=args.tag, - inputs=inputs, - outputs=outputs, - ) - else: - client.set_model_from_file( - name=name, - model_file=args.file, - backend=args.backend, - device=args.device, - batch_size=args.batch_size, - min_batch_size=args.min_batch_size, - min_batch_timeout=args.min_batch_timeout, - tag=args.tag, - inputs=inputs, - outputs=outputs, - ) - - return name - - -def launch_db_script(client: Client, db_script: t.List[str]) -> str: - """Parse options to launch script on local cluster - - :param client: SmartRedis client connected to local DB - :param db_model: List of arguments defining the script - :return: Name of model - """ - parser = argparse.ArgumentParser("Set script on DB") - parser.add_argument("--name", type=str) - parser.add_argument("--func", type=str) - parser.add_argument("--file", 
type=str) - parser.add_argument("--backend", type=str) - parser.add_argument("--device", type=str) - parser.add_argument("--devices_per_node", type=int, default=1) - parser.add_argument("--first_device", type=int, default=0) - args = parser.parse_args(db_script) - - if args.file and args.func: - raise ValueError("Both file and func cannot be provided.") - - if args.func: - func = args.func.replace("\\n", "\n") - if args.devices_per_node > 1 and args.device.lower() == "gpu": - client.set_script_multigpu( - args.name, func, args.first_device, args.devices_per_node - ) - else: - client.set_script(args.name, func, args.device) - elif args.file: - if args.devices_per_node > 1 and args.device.lower() == "gpu": - client.set_script_from_file_multigpu( - args.name, args.file, args.first_device, args.devices_per_node - ) - else: - client.set_script_from_file(args.name, args.file, args.device) - else: - raise ValueError("No file or func provided.") - - return str(args.name) - - -def main( - network_interface: str, - db_cpus: int, - command: t.List[str], - db_models: t.List[t.List[str]], - db_scripts: t.List[t.List[str]], - db_identifier: str, -) -> None: - # pylint: disable=too-many-statements - global DBPID # pylint: disable=global-statement - - lo_address = current_ip("lo") - ip_addresses = [] - if network_interface: - try: - ip_addresses = [ - current_ip(interface) for interface in network_interface.split(",") - ] - except ValueError as e: - logger.warning(e) - - if all(lo_address == ip_address for ip_address in ip_addresses) or not ip_addresses: - cmd = command + [f"--bind {lo_address}"] - else: - # bind to both addresses if the user specified a network - # address that exists and is not the loopback address - cmd = command + [f"--bind {lo_address} {' '.join(ip_addresses)}"] - # pin source address to avoid random selection by Redis - cmd += [f"--bind-source-addr {lo_address}"] - - # we generally want to catch all exceptions here as - # if this process dies, the 
application will most likely fail - try: - hostname = socket.gethostname() - filename = ( - f"colo_orc_{hostname}.log" - if os.getenv("SMARTSIM_LOG_LEVEL") == "debug" - else os.devnull - ) - with open(filename, "w", encoding="utf-8") as file: - process = psutil.Popen(cmd, stdout=file.fileno(), stderr=STDOUT) - DBPID = process.pid - # printing to stdout shell file for extraction - print(f"__PID__{DBPID}__PID__", flush=True) - - except Exception as e: - cleanup() - logger.error(f"Failed to start database process: {str(e)}") - raise SSInternalError("Colocated process failed to start") from e - - try: - logger.debug( - "\n\nColocated database information\n" - f"\n\tIP Address(es): {' '.join(ip_addresses + [lo_address])}" - f"\n\tCommand: {' '.join(cmd)}\n\n" - f"\n\t# of Database CPUs: {db_cpus}" - f"\n\tDatabase Identifier: {db_identifier}" - ) - except Exception as e: - cleanup() - logger.error(f"Failed to start database process: {str(e)}") - raise SSInternalError("Colocated process failed to start") from e - - def launch_models(client: Client, db_models: t.List[t.List[str]]) -> None: - for i, db_model in enumerate(db_models): - logger.debug("Uploading model") - model_name = launch_db_model(client, db_model) - logger.debug(f"Added model {model_name} ({i+1}/{len(db_models)})") - - def launch_db_scripts(client: Client, db_scripts: t.List[t.List[str]]) -> None: - for i, db_script in enumerate(db_scripts): - logger.debug("Uploading script") - script_name = launch_db_script(client, db_script) - logger.debug(f"Added script {script_name} ({i+1}/{len(db_scripts)})") - - try: - if db_models or db_scripts: - try: - options = ConfigOptions.create_from_environment(db_identifier) - client = Client(options, logger_name="SmartSim") - launch_models(client, db_models) - launch_db_scripts(client, db_scripts) - except (RedisConnectionError, RedisReplyError) as ex: - raise SSInternalError( - "Failed to set model or script, could not connect to database" - ) from ex - # Make sure we 
don't keep this around - del client - - except Exception as e: - cleanup() - logger.error(f"Colocated database process failed: {str(e)}") - raise SSInternalError("Colocated entrypoint raised an error") from e - - -def cleanup() -> None: - try: - logger.debug("Cleaning up colocated database") - # attempt to stop the database process - db_proc = psutil.Process(DBPID) - db_proc.terminate() - - except psutil.NoSuchProcess: - logger.warning("Couldn't find database process to kill.") - - except OSError as e: - logger.warning(f"Failed to clean up colocated database gracefully: {str(e)}") - finally: - if LOCK.is_locked: - LOCK.release() - - if os.path.exists(LOCK.lock_file): - os.remove(LOCK.lock_file) - - -def register_signal_handlers() -> None: - for sig in SIGNALS: - signal.signal(sig, handle_signal) - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - prefix_chars="+", description="SmartSim Process Launcher" - ) - arg_parser.add_argument( - "+ifname", type=str, help="Network Interface name", default="" - ) - arg_parser.add_argument( - "+lockfile", type=str, help="Filename to create for single proc per host" - ) - arg_parser.add_argument( - "+db_cpus", type=int, default=2, help="Number of CPUs to use for DB" - ) - - arg_parser.add_argument( - "+db_identifier", type=str, default="", help="Database Identifier" - ) - - arg_parser.add_argument("+command", nargs="+", help="Command to run") - arg_parser.add_argument( - "+db_model", - nargs="+", - action="append", - default=[], - help="Model to set on DB", - ) - arg_parser.add_argument( - "+db_script", - nargs="+", - action="append", - default=[], - help="Script to set on DB", - ) - - os.environ["PYTHONUNBUFFERED"] = "1" - - try: - parsed_args = arg_parser.parse_args() - tmp_lockfile = Path(tempfile.gettempdir()) / parsed_args.lockfile - - LOCK = filelock.FileLock(tmp_lockfile) - LOCK.acquire(timeout=0.1) - logger.debug(f"Starting colocated database on host: {socket.gethostname()}") - - # make sure to 
register the cleanup before we start - # the proecss so our signaller will be able to stop - # the database process. - register_signal_handlers() - - main( - parsed_args.ifname, - parsed_args.db_cpus, - parsed_args.command, - parsed_args.db_model, - parsed_args.db_script, - parsed_args.db_identifier, - ) - - # gracefully exit the processes in the distributed application that - # we do not want to have start a colocated process. Only one process - # per node should be running. - except filelock.Timeout: - sys.exit(0) diff --git a/smartsim/_core/entrypoints/dragon.py b/smartsim/_core/entrypoints/dragon.py index 92ebd735fb..b0b941d104 100644 --- a/smartsim/_core/entrypoints/dragon.py +++ b/smartsim/_core/entrypoints/dragon.py @@ -40,8 +40,8 @@ import zmq.auth.thread from smartsim._core.config import get_config -from smartsim._core.launcher.dragon import dragonSockets -from smartsim._core.launcher.dragon.dragonBackend import DragonBackend +from smartsim._core.launcher.dragon import dragon_sockets +from smartsim._core.launcher.dragon.dragon_backend import DragonBackend from smartsim._core.schemas import ( DragonBootstrapRequest, DragonBootstrapResponse, @@ -164,12 +164,12 @@ def run( dragon_pid: int, ) -> None: logger.debug(f"Opening socket {dragon_head_address}") - dragon_head_socket = dragonSockets.get_secure_socket(zmq_context, zmq.REP, True) + dragon_head_socket = dragon_sockets.get_secure_socket(zmq_context, zmq.REP, True) dragon_head_socket.bind(dragon_head_address) dragon_backend = DragonBackend(pid=dragon_pid) backend_updater = start_updater(dragon_backend, None) - server = dragonSockets.as_server(dragon_head_socket) + server = dragon_sockets.as_server(dragon_head_socket) logger.debug(f"Listening to {dragon_head_address}") @@ -236,14 +236,14 @@ def execute_entrypoint(args: DragonEntrypointArgs) -> int: else: dragon_head_address += ":5555" - zmq_authenticator = dragonSockets.get_authenticator(zmq_context, timeout=-1) + zmq_authenticator = 
dragon_sockets.get_authenticator(zmq_context, timeout=-1) logger.debug("Getting launcher socket") - launcher_socket = dragonSockets.get_secure_socket(zmq_context, zmq.REQ, False) + launcher_socket = dragon_sockets.get_secure_socket(zmq_context, zmq.REQ, False) logger.debug(f"Connecting launcher socket to: {args.launching_address}") launcher_socket.connect(args.launching_address) - client = dragonSockets.as_client(launcher_socket) + client = dragon_sockets.as_client(launcher_socket) logger.debug( f"Sending bootstrap request to launcher_socket with {dragon_head_address}" @@ -297,7 +297,7 @@ def cleanup() -> None: def register_signal_handlers() -> None: # make sure to register the cleanup before the start # the process so our signaller will be able to stop - # the database process. + # the feature store process. for sig in SIGNALS: signal.signal(sig, handle_signal) diff --git a/smartsim/_core/entrypoints/dragon_client.py b/smartsim/_core/entrypoints/dragon_client.py index e998ddce19..0131124121 100644 --- a/smartsim/_core/entrypoints/dragon_client.py +++ b/smartsim/_core/entrypoints/dragon_client.py @@ -37,7 +37,7 @@ import zmq -from smartsim._core.launcher.dragon.dragonConnector import DragonConnector +from smartsim._core.launcher.dragon.dragon_connector import DragonConnector from smartsim._core.schemas import ( DragonHandshakeRequest, DragonRequest, diff --git a/smartsim/_core/entrypoints/file_operations.py b/smartsim/_core/entrypoints/file_operations.py new file mode 100644 index 0000000000..69d7f7565e --- /dev/null +++ b/smartsim/_core/entrypoints/file_operations.py @@ -0,0 +1,293 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024 Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import argparse +import base64 +import functools +import os +import pathlib +import pickle +import shutil +import typing as t +from typing import Callable + +from ...log import get_logger + +logger = get_logger(__name__) + +"""Run file operations move, remove, symlink, copy, and configure +using command line arguments. 
+""" + + +def _abspath(input_path: str) -> pathlib.Path: + """Helper function to check that paths are absolute""" + path = pathlib.Path(input_path) + if not path.is_absolute(): + raise ValueError(f"Path `{path}` must be absolute.") + return path + + +def _make_substitution( + tag_name: str, replacement: str | int | float, tag_delimiter: str +) -> Callable[[str], str]: + """Helper function to replace tags""" + return lambda s: s.replace( + f"{tag_delimiter}{tag_name}{tag_delimiter}", str(replacement) + ) + + +def _prepare_param_dict(param_dict: str) -> dict[str, t.Any]: + """Decode and deserialize a base64-encoded parameter dictionary. + + This function takes a base64-encoded string representation of a dictionary, + decodes it, and then deserializes it using pickle. It performs validation + to ensure the resulting object is a non-empty dictionary. + """ + decoded_dict = base64.b64decode(param_dict) + deserialized_dict = pickle.loads(decoded_dict) + if not isinstance(deserialized_dict, dict): + raise TypeError("param dict is not a valid dictionary") + if not deserialized_dict: + raise ValueError("param dictionary is empty") + return deserialized_dict + + +def _replace_tags_in( + item: str, + substitutions: t.Sequence[Callable[[str], str]], +) -> str: + """Helper function to derive the lines in which to make the substitutions""" + return functools.reduce(lambda a, fn: fn(a), substitutions, item) + + +def _process_file( + substitutions: t.Sequence[Callable[[str], str]], + source: pathlib.Path, + destination: pathlib.Path, +) -> None: + """ + Process a source file by replacing tags with specified substitutions and + write the result to a destination file. 
+ """ + # Set the lines to iterate over + with open(source, "r+", encoding="utf-8") as file_stream: + lines = [_replace_tags_in(line, substitutions) for line in file_stream] + # write configured file to destination specified + with open(destination, "w+", encoding="utf-8") as file_stream: + file_stream.writelines(lines) + + +def move(parsed_args: argparse.Namespace) -> None: + """Move a source file or directory to another location. If dest is an + existing directory or a symlink to a directory, then the srouce will + be moved inside that directory. The destination path in that directory + must not already exist. If dest is an existing file, it will be overwritten. + + Sample usage: + .. highlight:: bash + .. code-block:: bash + python -m smartsim._core.entrypoints.file_operations \ + move /absolute/file/source/path /absolute/file/dest/path + + /absolute/file/source/path: File or directory to be moved + /absolute/file/dest/path: Path to a file or directory location + """ + shutil.move(parsed_args.source, parsed_args.dest) + + +def remove(parsed_args: argparse.Namespace) -> None: + """Remove a file or directory. + + Sample usage: + .. highlight:: bash + .. code-block:: bash + python -m smartsim._core.entrypoints.file_operations \ + remove /absolute/file/path + + /absolute/file/path: Path to the file or directory to be deleted + """ + if os.path.isdir(parsed_args.to_remove): + os.rmdir(parsed_args.to_remove) + else: + os.remove(parsed_args.to_remove) + + +def copy(parsed_args: argparse.Namespace) -> None: + """Copy the contents from the source file into the dest file. + If source is a directory, copy the entire directory tree source to dest. + + Sample usage: + .. highlight:: bash + .. 
code-block:: bash + python -m smartsim._core.entrypoints.file_operations copy \ + /absolute/file/source/path /absolute/file/dest/path \ + --dirs_exist_ok + + /absolute/file/source/path: Path to directory, or path to file to + copy to a new location + /absolute/file/dest/path: Path to destination directory or path to + destination file + --dirs_exist_ok: if the flag is included, the copying operation will + continue if the destination directory and files already exist, + and will be overwritten by corresponding files. If the flag is + not included and the destination file already exists, a + FileExistsError will be raised + """ + if os.path.isdir(parsed_args.source): + shutil.copytree( + parsed_args.source, + parsed_args.dest, + dirs_exist_ok=parsed_args.dirs_exist_ok, + ) + else: + shutil.copy(parsed_args.source, parsed_args.dest) + + +def symlink(parsed_args: argparse.Namespace) -> None: + """ + Create a symbolic link pointing to the existing source file + named link. + + Sample usage: + .. highlight:: bash + .. code-block:: bash + python -m smartsim._core.entrypoints.file_operations \ + symlink /absolute/file/source/path /absolute/file/dest/path + + /absolute/file/source/path: the existing source path + /absolute/file/dest/path: target name where the symlink will be created. + """ + os.symlink(parsed_args.source, parsed_args.dest) + + +def configure(parsed_args: argparse.Namespace) -> None: + """Set, search and replace the tagged parameters for the + configure_file operation within tagged files attached to an entity. + + User-formatted files can be attached using the `configure_file` argument. + These files will be modified during ``Application`` generation to replace + tagged sections in the user-formatted files with values from the `params` + initializer argument used during ``Application`` creation: + + Sample usage: + .. highlight:: bash + ..
code-block:: bash + python -m smartsim._core.entrypoints.file_operations \ + configure_file /absolute/file/source/path /absolute/file/dest/path \ + tag_delimiter param_dict + + /absolute/file/source/path: The tagged files upon which the search and replace + operations are to be performed + /absolute/file/dest/path: The destination for configured files to be + written to. + tag_delimiter: tag for the configure_file operation to search for, defaults to + semi-colon e.g. ";" + param_dict: A dict of parameter names and values set for the file + + """ + tag_delimiter = parsed_args.tag_delimiter + param_dict = _prepare_param_dict(parsed_args.param_dict) + + substitutions = tuple( + _make_substitution(k, v, tag_delimiter) for k, v in param_dict.items() + ) + if parsed_args.source.is_dir(): + for dirpath, _, filenames in os.walk(parsed_args.source): + new_dir_dest = dirpath.replace( + str(parsed_args.source), str(parsed_args.dest), 1 + ) + os.makedirs(new_dir_dest, exist_ok=True) + for file_name in filenames: + src_file = os.path.join(dirpath, file_name) + dst_file = os.path.join(new_dir_dest, file_name) + _process_file(substitutions, src_file, dst_file) + else: + dst_file = parsed_args.dest / os.path.basename(parsed_args.source) + _process_file(substitutions, parsed_args.source, dst_file) + + +def get_parser() -> argparse.ArgumentParser: + """Instantiate a parser to process command line arguments + + :returns: An argument parser ready to accept required command generator parameters + """ + arg_parser = argparse.ArgumentParser(description="Command Generator") + + subparsers = arg_parser.add_subparsers(help="file_operations") + + # Subparser for move op + move_parser = subparsers.add_parser("move") + move_parser.set_defaults(func=move) + move_parser.add_argument("source", type=_abspath) + move_parser.add_argument("dest", type=_abspath) + + # Subparser for remove op + remove_parser = subparsers.add_parser("remove") + remove_parser.set_defaults(func=remove) + 
remove_parser.add_argument("to_remove", type=_abspath) + + # Subparser for copy op + copy_parser = subparsers.add_parser("copy") + copy_parser.set_defaults(func=copy) + copy_parser.add_argument("source", type=_abspath) + copy_parser.add_argument("dest", type=_abspath) + copy_parser.add_argument("--dirs_exist_ok", action="store_true") + + # Subparser for symlink op + symlink_parser = subparsers.add_parser("symlink") + symlink_parser.set_defaults(func=symlink) + symlink_parser.add_argument("source", type=_abspath) + symlink_parser.add_argument("dest", type=_abspath) + + # Subparser for configure op + configure_parser = subparsers.add_parser("configure") + configure_parser.set_defaults(func=configure) + configure_parser.add_argument("source", type=_abspath) + configure_parser.add_argument("dest", type=_abspath) + configure_parser.add_argument("tag_delimiter", type=str, default=";") + configure_parser.add_argument("param_dict", type=str) + + return arg_parser + + +def parse_arguments() -> argparse.Namespace: + """Parse the command line arguments + + :returns: the parsed command line arguments + """ + parser = get_parser() + parsed_args = parser.parse_args() + return parsed_args + + +if __name__ == "__main__": + os.environ["PYTHONUNBUFFERED"] = "1" + + args = parse_arguments() + args.func(args) diff --git a/smartsim/_core/entrypoints/indirect.py b/smartsim/_core/entrypoints/indirect.py index 1f445ac4a1..38dc9a7ec3 100644 --- a/smartsim/_core/entrypoints/indirect.py +++ b/smartsim/_core/entrypoints/indirect.py @@ -61,7 +61,7 @@ def main( :param cmd: a base64 encoded cmd to execute :param entity_type: `SmartSimEntity` entity class. 
Valid values - include: orchestrator, dbnode, ensemble, model + include: feature store, fsnode, ensemble, application :param cwd: working directory to execute the cmd from :param status_dir: path to the output directory for status updates """ @@ -233,7 +233,7 @@ def get_parser() -> argparse.ArgumentParser: logger.debug("Starting indirect step execution") # make sure to register the cleanup before the start the process - # so our signaller will be able to stop the database process. + # so our signaller will be able to stop the feature store process. register_signal_handlers() rc = main( diff --git a/smartsim/_core/entrypoints/redis.py b/smartsim/_core/entrypoints/redis.py deleted file mode 100644 index c4d8cbbd63..0000000000 --- a/smartsim/_core/entrypoints/redis.py +++ /dev/null @@ -1,192 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024 Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import argparse -import json -import os -import signal -import textwrap -import typing as t -from subprocess import PIPE, STDOUT -from types import FrameType - -import psutil - -from smartsim._core.utils.network import current_ip -from smartsim.entity.dbnode import LaunchedShardData -from smartsim.log import get_logger - -logger = get_logger(__name__) - -""" -Redis/KeyDB entrypoint script -""" - -DBPID: t.Optional[int] = None - -# kill is not catchable -SIGNALS = [signal.SIGINT, signal.SIGQUIT, signal.SIGTERM, signal.SIGABRT] - - -def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None: - if not signo: - logger.warning("Received signal with no signo") - cleanup() - - -def build_bind_args(source_addr: str, *addrs: str) -> t.Tuple[str, ...]: - return ( - "--bind", - source_addr, - *addrs, - # pin source address to avoid random selection by Redis - "--bind-source-addr", - source_addr, - ) - - -def build_cluster_args(shard_data: LaunchedShardData) -> t.Tuple[str, ...]: - if cluster_conf_file := shard_data.cluster_conf_file: - return ("--cluster-enabled", "yes", "--cluster-config-file", cluster_conf_file) - return () - - -def print_summary( - cmd: t.List[str], network_interface: str, shard_data: LaunchedShardData -) -> None: - print( - textwrap.dedent(f"""\ - ----------- Running Command ---------- - COMMAND: {' '.join(cmd)} - IPADDRESS: {shard_data.hostname} - NETWORK: {network_interface} - SMARTSIM_ORC_SHARD_INFO: 
{json.dumps(shard_data.to_dict())} - -------------------------------------- - - --------------- Output --------------- - - """), - flush=True, - ) - - -def main(args: argparse.Namespace) -> int: - global DBPID # pylint: disable=global-statement - - src_addr, *bind_addrs = (current_ip(net_if) for net_if in args.ifname.split(",")) - shard_data = LaunchedShardData( - name=args.name, hostname=src_addr, port=args.port, cluster=args.cluster - ) - - cmd = [ - args.orc_exe, - args.conf_file, - *args.rai_module, - "--port", - str(args.port), - *build_cluster_args(shard_data), - *build_bind_args(src_addr, *bind_addrs), - ] - - print_summary(cmd, args.ifname, shard_data) - - try: - process = psutil.Popen(cmd, stdout=PIPE, stderr=STDOUT) - DBPID = process.pid - - for line in iter(process.stdout.readline, b""): - print(line.decode("utf-8").rstrip(), flush=True) - except Exception: - cleanup() - logger.error("Database process starter raised an exception", exc_info=True) - return 1 - return 0 - - -def cleanup() -> None: - logger.debug("Cleaning up database instance") - try: - # attempt to stop the database process - if DBPID is not None: - psutil.Process(DBPID).terminate() - except psutil.NoSuchProcess: - logger.warning("Couldn't find database process to kill.") - except OSError as e: - logger.warning(f"Failed to clean up database gracefully: {str(e)}") - - -if __name__ == "__main__": - os.environ["PYTHONUNBUFFERED"] = "1" - - parser = argparse.ArgumentParser( - prefix_chars="+", description="SmartSim Process Launcher" - ) - parser.add_argument( - "+orc-exe", type=str, help="Path to the orchestrator executable", required=True - ) - parser.add_argument( - "+conf-file", - type=str, - help="Path to the orchestrator configuration file", - required=True, - ) - parser.add_argument( - "+rai-module", - nargs="+", - type=str, - help=( - "Command for the orcestrator to load the Redis AI module with " - "symbols seperated by whitespace" - ), - required=True, - ) - parser.add_argument( - 
"+name", type=str, help="Name to identify the shard", required=True - ) - parser.add_argument( - "+port", - type=int, - help="The port on which to launch the shard of the orchestrator", - required=True, - ) - parser.add_argument( - "+ifname", type=str, help="Network Interface name", required=True - ) - parser.add_argument( - "+cluster", - action="store_true", - help="Specify if this orchestrator shard is part of a cluster", - ) - - args_ = parser.parse_args() - - # make sure to register the cleanup before the start - # the process so our signaller will be able to stop - # the database process. - for sig in SIGNALS: - signal.signal(sig, handle_signal) - - raise SystemExit(main(args_)) diff --git a/smartsim/_core/entrypoints/telemetrymonitor.py b/smartsim/_core/entrypoints/telemetry_monitor.py similarity index 100% rename from smartsim/_core/entrypoints/telemetrymonitor.py rename to smartsim/_core/entrypoints/telemetry_monitor.py diff --git a/smartsim/_core/generation/generator.py b/smartsim/_core/generation/generator.py index 8706cf5686..1cc1670655 100644 --- a/smartsim/_core/generation/generator.py +++ b/smartsim/_core/generation/generator.py @@ -25,314 +25,296 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import pathlib -import shutil +import subprocess import typing as t +from collections import namedtuple from datetime import datetime -from distutils import dir_util # pylint: disable=deprecated-module -from logging import DEBUG, INFO -from os import mkdir, path, symlink -from os.path import join, relpath -from tabulate import tabulate - -from ...database import Orchestrator -from ...entity import Ensemble, Model, TaggedFilesHierarchy +from ...entity import entity +from ...launchable import Job from ...log import get_logger -from ..control import Manifest -from .modelwriter import ModelWriter +from ..commands import Command, CommandList +from .operations.operations import ( + ConfigureOperation, + CopyOperation, + FileSysOperationSet, + GenerationContext, + SymlinkOperation, +) logger = get_logger(__name__) logger.propagate = False +@t.runtime_checkable +class _GenerableProtocol(t.Protocol): + """Protocol to ensure that an entity supports both file operations + and parameters.""" + + files: FileSysOperationSet + # TODO change when file_parameters taken off Application during Ensemble refactor ticket + file_parameters: t.Mapping[str, str] + + +Job_Path = namedtuple("Job_Path", ["run_path", "out_path", "err_path"]) +"""Namedtuple that stores a Job's run directory, output file path, and +error file path.""" + + class Generator: - """The primary job of the generator is to create the file structure - for a SmartSim experiment. The Generator is responsible for reading - and writing into configuration files as well. + """The Generator class creates the directory structure for a SmartSim Job by building + and executing file operation commands. """ - def __init__( - self, gen_path: str, overwrite: bool = False, verbose: bool = True - ) -> None: - """Initialize a generator object - - if overwrite is true, replace any existing - configured models within an ensemble if there - is a name collision. Also replace any and all directories - for the experiment with fresh copies. 
Otherwise, if overwrite - is false, raises EntityExistsError when there is a name - collision between entities. - - :param gen_path: Path in which files need to be generated - :param overwrite: toggle entity replacement - :param verbose: Whether generation information should be logged to std out - """ - self._writer = ModelWriter() - self.gen_path = gen_path - self.overwrite = overwrite - self.log_level = DEBUG if not verbose else INFO - - @property - def log_file(self) -> str: - """Returns the location of the file - summarizing the parameters used for the last generation - of all generated entities. - - :returns: path to file with parameter settings - """ - return join(self.gen_path, "smartsim_params.txt") + run_directory = "run" + """The name of the directory storing run-related files.""" + log_directory = "log" + """The name of the directory storing log-related files.""" - def generate_experiment(self, *args: t.Any) -> None: - """Run ensemble and experiment file structure generation + def __init__(self, root: pathlib.Path) -> None: + """Initialize a Generator object - Generate the file structure for a SmartSim experiment. This - includes the writing and configuring of input files for a - model. + The Generator class is responsible for constructing a Job's directory, performing + the following tasks: - To have files or directories present in the created entity - directories, such as datasets or input files, call - ``entity.attach_generator_files`` prior to generation. See - ``entity.attach_generator_files`` for more information on - what types of files can be included. + - Creating the run and log directories + - Generating the output and error files + - Building the parameter settings file + - Managing symlinking, copying, and configuration of attached files - Tagged model files are read, checked for input variables to - configure, and written. Input variables to configure are - specified with a tag within the input file itself. 
- The default tag is surronding an input value with semicolons. - e.g. ``THERMO=;90;`` + :param root: The base path for job-related files and directories + """ + self.root = root + """The root directory under which all generated files and directories will be placed.""" + + def _build_job_base_path(self, job: Job, job_index: int) -> pathlib.Path: + """Build and return a Job's base directory. The path is created by combining the + root directory with the Job type (derived from the class name), + the name attribute of the Job, and an index to differentiate between multiple + Job runs. + + :param job: Job object + :param job_index: Job index + :returns: The built file path for the Job + """ + job_type = f"{job.__class__.__name__.lower()}s" + job_path = self.root / f"{job_type}/{job.name}-{job_index}" + return pathlib.Path(job_path) + + def _build_job_run_path(self, job: Job, job_index: int) -> pathlib.Path: + """Build and return a Job's run directory. The path is formed by combining + the base directory with the `run_directory` class-level constant, which specifies + the name of the Job's run folder. + + :param job: Job object + :param job_index: Job index + :returns: The built file path for the Job run folder + """ + path = self._build_job_base_path(job, job_index) / self.run_directory + return pathlib.Path(path) + + def _build_job_log_path(self, job: Job, job_index: int) -> pathlib.Path: + """Build and return a Job's log directory. The path is formed by combining + the base directory with the `log_directory` class-level constant, which specifies + the name of the Job's log folder. 
+ :param job: Job object + :param job_index: Job index + :returns: The built file path for the Job run folder """ - generator_manifest = Manifest(*args) + path = self._build_job_base_path(job, job_index) / self.log_directory + return pathlib.Path(path) - self._gen_exp_dir() - self._gen_orc_dir(generator_manifest.dbs) - self._gen_entity_list_dir(generator_manifest.ensembles) - self._gen_entity_dirs(generator_manifest.models) + @staticmethod + def _build_log_file_path(log_path: pathlib.Path) -> pathlib.Path: + """Build and return a parameters file summarizing the parameters + used for the generation of the entity. - def set_tag(self, tag: str, regex: t.Optional[str] = None) -> None: - """Set the tag used for tagging input files + :param log_path: Path to log directory + :returns: The built file path an entities params file + """ + return pathlib.Path(log_path) / "smartsim_params.txt" - Set a tag or a regular expression for the - generator to look for when configuring new models. + @staticmethod + def _build_out_file_path(log_path: pathlib.Path, job_name: str) -> pathlib.Path: + """Build and return the path to the output file. The path is created by combining + the Job's log directory with the job name and appending the `.out` extension. - For example, a tag might be ``;`` where the - expression being replaced in the model configuration - file would look like ``;expression;`` + :param log_path: Path to log directory + :param job_name: Name of the Job + :returns: Path to the output file + """ + out_file_path = log_path / f"{job_name}.out" + return out_file_path - A full regular expression might tag specific - model configurations such that the configuration - files don't need to be tagged manually. + @staticmethod + def _build_err_file_path(log_path: pathlib.Path, job_name: str) -> pathlib.Path: + """Build and return the path to the error file. The path is created by combining + the Job's log directory with the job name and appending the `.err` extension. 
- :param tag: A string of characters that signify - the string to be changed. Defaults to ``;`` - :param regex: full regex for the modelwriter to search for + :param log_path: Path to log directory + :param job_name: Name of the Job + :returns: Path to the error file """ - self._writer.set_tag(tag, regex) + err_file_path = log_path / f"{job_name}.err" + return err_file_path + + def generate_job(self, job: Job, job_index: int) -> Job_Path: + """Build and return the Job's run directory, output file, and error file. - def _gen_exp_dir(self) -> None: - """Create the directory for an experiment if it does not - already exist. + This method creates the Job's run and log directories, generates the + `smartsim_params.txt` file to log parameters used for the Job, and sets + up the output and error files for Job execution information. If files are + attached to the Job's entity, it builds file operation commands and executes + them. + + :param job: Job object + :param job_index: Job index + :return: Job's run directory, error file and out file. """ - if path.isfile(self.gen_path): - raise FileExistsError( - f"Experiment directory could not be created. {self.gen_path} exists" - ) - if not path.isdir(self.gen_path): - # keep exists ok for race conditions on NFS - pathlib.Path(self.gen_path).mkdir(exist_ok=True, parents=True) - else: - logger.log( - level=self.log_level, msg="Working in previously created experiment" - ) - - # The log_file only keeps track of the last generation - # this is to avoid gigantic files in case the user repeats - # generation several times. 
The information is anyhow - # redundant, as it is also written in each entity's dir - with open(self.log_file, mode="w", encoding="utf-8") as log_file: + job_path = self._build_job_run_path(job, job_index) + log_path = self._build_job_log_path(job, job_index) + + out_file = self._build_out_file_path(log_path, job.entity.name) + err_file = self._build_err_file_path(log_path, job.entity.name) + + cmd_list = self._build_commands(job.entity, job_path, log_path) + + self._execute_commands(cmd_list) + + with open( + self._build_log_file_path(log_path), mode="w", encoding="utf-8" + ) as log_file: dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") log_file.write(f"Generation start date and time: {dt_string}\n") - def _gen_orc_dir(self, orchestrator_list: t.List[Orchestrator]) -> None: - """Create the directory that will hold the error, output and - configuration files for the orchestrator. - - :param orchestrator: Orchestrator instance + return Job_Path(job_path, out_file, err_file) + + @classmethod + def _build_commands( + cls, + entity: entity.SmartSimEntity, + job_path: pathlib.Path, + log_path: pathlib.Path, + ) -> CommandList: + """Build file operation commands for a Job's entity. + + This method constructs commands for copying, symlinking, and writing tagged files + associated with the Job's entity. This method builds the constructs the commands to + generate the Job's run and log directory. It aggregates these commands into a CommandList + to return. + + :param job: Job object + :param job_path: The file path for the Job run folder + :param log_path: The file path for the Job log folder + :return: A CommandList containing the file operation commands """ - # Loop through orchestrators - for orchestrator in orchestrator_list: - orc_path = path.join(self.gen_path, orchestrator.name) + context = GenerationContext(job_path) + cmd_list = CommandList() - orchestrator.set_path(orc_path) - # Always remove orchestrator files if present. 
- if path.isdir(orc_path): - shutil.rmtree(orc_path, ignore_errors=True) - pathlib.Path(orc_path).mkdir(exist_ok=self.overwrite, parents=True) + cls._append_mkdir_commands(cmd_list, job_path, log_path) - def _gen_entity_list_dir(self, entity_lists: t.List[Ensemble]) -> None: - """Generate directories for Ensemble instances + if isinstance(entity, _GenerableProtocol): + cls._append_file_operations(cmd_list, entity, context) - :param entity_lists: list of Ensemble instances - """ + return cmd_list - if not entity_lists: - return - - for elist in entity_lists: - elist_dir = path.join(self.gen_path, elist.name) - if path.isdir(elist_dir): - if self.overwrite: - shutil.rmtree(elist_dir) - mkdir(elist_dir) - else: - mkdir(elist_dir) - elist.path = elist_dir - - self._gen_entity_dirs(list(elist.models), entity_list=elist) - - def _gen_entity_dirs( - self, - entities: t.List[Model], - entity_list: t.Optional[Ensemble] = None, + @classmethod + def _append_mkdir_commands( + cls, cmd_list: CommandList, job_path: pathlib.Path, log_path: pathlib.Path ) -> None: - """Generate directories for Entity instances + """Append file operation Commands (mkdir) for a Job's run and log directory. 
- :param entities: list of Model instances - :param entity_list: Ensemble instance - :raises EntityExistsError: if a directory already exists for an - entity by that name + :param cmd_list: A CommandList object containing the commands to be executed + :param job_path: The file path for the Job run folder + :param log_path: The file path for the Job log folder """ - if not entities: - return - - for entity in entities: - if entity_list: - dst = path.join(self.gen_path, entity_list.name, entity.name) - else: - dst = path.join(self.gen_path, entity.name) - - if path.isdir(dst): - if self.overwrite: - shutil.rmtree(dst) - else: - error = ( - f"Directory for entity {entity.name} " - f"already exists in path {dst}" - ) - raise FileExistsError(error) - pathlib.Path(dst).mkdir(exist_ok=True) - entity.path = dst - - self._copy_entity_files(entity) - self._link_entity_files(entity) - self._write_tagged_entity_files(entity) - - def _write_tagged_entity_files(self, entity: Model) -> None: - """Read, configure and write the tagged input files for - a Model instance within an ensemble. This function - specifically deals with the tagged files attached to - an Ensemble. 
- - :param entity: a Model instance - """ - if entity.files: - to_write = [] - - def _build_tagged_files(tagged: TaggedFilesHierarchy) -> None: - """Using a TaggedFileHierarchy, reproduce the tagged file - directory structure - - :param tagged: a TaggedFileHierarchy to be built as a - directory structure - """ - for file in tagged.files: - dst_path = path.join(entity.path, tagged.base, path.basename(file)) - shutil.copyfile(file, dst_path) - to_write.append(dst_path) - - for tagged_dir in tagged.dirs: - mkdir( - path.join( - entity.path, tagged.base, path.basename(tagged_dir.base) - ) - ) - _build_tagged_files(tagged_dir) - - if entity.files.tagged_hierarchy: - _build_tagged_files(entity.files.tagged_hierarchy) - - # write in changes to configurations - if isinstance(entity, Model): - files_to_params = self._writer.configure_tagged_model_files( - to_write, entity.params - ) - self._log_params(entity, files_to_params) - - def _log_params( - self, entity: Model, files_to_params: t.Dict[str, t.Dict[str, str]] + cmd_list.append(cls._mkdir_file(job_path)) + cmd_list.append(cls._mkdir_file(log_path)) + + @classmethod + def _append_file_operations( + cls, + cmd_list: CommandList, + entity: _GenerableProtocol, + context: GenerationContext, ) -> None: - """Log which files were modified during generation + """Append file operation Commands (copy, symlink, configure) for all + files attached to the entity. 
+ + :param cmd_list: A CommandList object containing the commands to be executed + :param entity: The Job's attached entity + :param context: A GenerationContext object that holds the Job's run directory + """ + copy_ret = cls._copy_files(entity.files.copy_operations, context) + cmd_list.extend(copy_ret) + + symlink_ret = cls._symlink_files(entity.files.symlink_operations, context) + cmd_list.extend(symlink_ret) + + configure_ret = cls._configure_files(entity.files.configure_operations, context) + cmd_list.extend(configure_ret) - and what values were set to the parameters + @classmethod + def _execute_commands(cls, cmd_list: CommandList) -> None: + """Execute a list of commands using subprocess. - :param entity: the model being generated - :param files_to_params: a dict connecting each file to its parameter settings + This helper function iterates through each command in the provided CommandList + and executes them using the subprocess module. + + :param cmd_list: A CommandList object containing the commands to be executed """ - used_params: t.Dict[str, str] = {} - file_to_tables: t.Dict[str, str] = {} - for file, params in files_to_params.items(): - used_params.update(params) - table = tabulate(params.items(), headers=["Name", "Value"]) - file_to_tables[relpath(file, self.gen_path)] = table - - if used_params: - used_params_str = ", ".join( - [f"{name}={value}" for name, value in used_params.items()] - ) - logger.log( - level=self.log_level, - msg=f"Configured model {entity.name} with params {used_params_str}", - ) - file_table = tabulate( - file_to_tables.items(), - headers=["File name", "Parameters"], - ) - log_entry = f"Model name: {entity.name}\n{file_table}\n\n" - with open(self.log_file, mode="a", encoding="utf-8") as logfile: - logfile.write(log_entry) - with open( - join(entity.path, "smartsim_params.txt"), mode="w", encoding="utf-8" - ) as local_logfile: - local_logfile.write(log_entry) - - else: - logger.log( - level=self.log_level, - msg=f"Configured 
model {entity.name} with no parameters", - ) + for cmd in cmd_list: + subprocess.run(cmd.command) @staticmethod - def _copy_entity_files(entity: Model) -> None: - """Copy the entity files and directories attached to this entity. + def _mkdir_file(file_path: pathlib.Path) -> Command: + """Build a Command to create the directory along with any + necessary parent directories. - :param entity: Model + :param file_path: The directory path to be created + :return: A Command object to execute the directory creation """ - if entity.files: - for to_copy in entity.files.copy: - dst_path = path.join(entity.path, path.basename(to_copy)) - if path.isdir(to_copy): - dir_util.copy_tree(to_copy, entity.path) - else: - shutil.copyfile(to_copy, dst_path) + cmd = Command(["mkdir", "-p", str(file_path)]) + return cmd @staticmethod - def _link_entity_files(entity: Model) -> None: - """Symlink the entity files attached to this entity. + def _copy_files( + files: list[CopyOperation], context: GenerationContext + ) -> CommandList: + """Build commands to copy files/directories from specified source paths + to an optional destination in the run directory. + + :param files: A list of CopyOperation objects + :param context: A GenerationContext object that holds the Job's run directory + :return: A CommandList containing the copy commands + """ + return CommandList([file.format(context) for file in files]) - :param entity: Model + @staticmethod + def _symlink_files( + files: list[SymlinkOperation], context: GenerationContext + ) -> CommandList: + """Build commands to symlink files/directories from specified source paths + to an optional destination in the run directory. 
+ + :param files: A list of SymlinkOperation objects + :param context: A GenerationContext object that holds the Job's run directory + :return: A CommandList containing the symlink commands + """ + return CommandList([file.format(context) for file in files]) + + @staticmethod + def _configure_files( + files: list[ConfigureOperation], + context: GenerationContext, + ) -> CommandList: + """Build commands to configure files/directories from specified source paths + to an optional destination in the run directory. + + :param files: A list of ConfigurationOperation objects + :param context: A GenerationContext object that holds the Job's run directory + :return: A CommandList containing the configuration commands """ - if entity.files: - for to_link in entity.files.link: - dst_path = path.join(entity.path, path.basename(to_link)) - symlink(to_link, dst_path) + return CommandList([file.format(context) for file in files]) diff --git a/smartsim/_core/generation/modelwriter.py b/smartsim/_core/generation/modelwriter.py deleted file mode 100644 index 2998d4e354..0000000000 --- a/smartsim/_core/generation/modelwriter.py +++ /dev/null @@ -1,158 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. 
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import collections -import re -import typing as t - -from smartsim.error.errors import SmartSimError - -from ...error import ParameterWriterError -from ...log import get_logger - -logger = get_logger(__name__) - - -class ModelWriter: - def __init__(self) -> None: - self.tag = ";" - self.regex = "(;[^;]+;)" - self.lines: t.List[str] = [] - - def set_tag(self, tag: str, regex: t.Optional[str] = None) -> None: - """Set the tag for the modelwriter to search for within - tagged files attached to an entity. - - :param tag: tag for the modelwriter to search for, - defaults to semi-colon e.g. ";" - :param regex: full regex for the modelwriter to search for, - defaults to "(;.+;)" - """ - if regex: - self.regex = regex - else: - self.tag = tag - self.regex = "".join(("(", tag, ".+", tag, ")")) - - def configure_tagged_model_files( - self, - tagged_files: t.List[str], - params: t.Dict[str, str], - make_missing_tags_fatal: bool = False, - ) -> t.Dict[str, t.Dict[str, str]]: - """Read, write and configure tagged files attached to a Model - instance. 
- - :param tagged_files: list of paths to tagged files - :param params: model parameters - :param make_missing_tags_fatal: raise an error if a tag is missing - :returns: A dict connecting each file to its parameter settings - """ - files_to_tags: t.Dict[str, t.Dict[str, str]] = {} - for tagged_file in tagged_files: - self._set_lines(tagged_file) - used_tags = self._replace_tags(params, make_missing_tags_fatal) - self._write_changes(tagged_file) - files_to_tags[tagged_file] = used_tags - - return files_to_tags - - def _set_lines(self, file_path: str) -> None: - """Set the lines for the modelwrtter to iterate over - - :param file_path: path to the newly created and tagged file - :raises ParameterWriterError: if the newly created file cannot be read - """ - try: - with open(file_path, "r+", encoding="utf-8") as file_stream: - self.lines = file_stream.readlines() - except (IOError, OSError) as e: - raise ParameterWriterError(file_path) from e - - def _write_changes(self, file_path: str) -> None: - """Write the ensemble-specific changes - - :raises ParameterWriterError: if the newly created file cannot be read - """ - try: - with open(file_path, "w+", encoding="utf-8") as file_stream: - for line in self.lines: - file_stream.write(line) - except (IOError, OSError) as e: - raise ParameterWriterError(file_path, read=False) from e - - def _replace_tags( - self, params: t.Dict[str, str], make_fatal: bool = False - ) -> t.Dict[str, str]: - """Replace the tagged parameters within the file attached to this - model. 
The tag defaults to ";" - - :param model: The model instance - :param make_fatal: (Optional) Set to True to force a fatal error - if a tag is not matched - :returns: A dict of parameter names and values set for the file - """ - edited = [] - unused_tags: t.DefaultDict[str, t.List[int]] = collections.defaultdict(list) - used_params: t.Dict[str, str] = {} - for i, line in enumerate(self.lines, 1): - while search := re.search(self.regex, line): - tagged_line = search.group(0) - previous_value = self._get_prev_value(tagged_line) - if self._is_ensemble_spec(tagged_line, params): - new_val = str(params[previous_value]) - line = re.sub(self.regex, new_val, line, 1) - used_params[previous_value] = new_val - - # if a tag is found but is not in this model's configurations - # put in placeholder value - else: - tag = tagged_line.split(self.tag)[1] - unused_tags[tag].append(i) - line = re.sub(self.regex, previous_value, line) - break - edited.append(line) - - for tag, value in unused_tags.items(): - missing_tag_message = f"Unused tag {tag} on line(s): {str(value)}" - if make_fatal: - raise SmartSimError(missing_tag_message) - logger.warning(missing_tag_message) - self.lines = edited - return used_params - - def _is_ensemble_spec( - self, tagged_line: str, model_params: t.Dict[str, str] - ) -> bool: - split_tag = tagged_line.split(self.tag) - prev_val = split_tag[1] - if prev_val in model_params.keys(): - return True - return False - - def _get_prev_value(self, tagged_line: str) -> str: - split_tag = tagged_line.split(self.tag) - return split_tag[1] diff --git a/smartsim/_core/generation/operations/operations.py b/smartsim/_core/generation/operations/operations.py new file mode 100644 index 0000000000..48ccc6c7b2 --- /dev/null +++ b/smartsim/_core/generation/operations/operations.py @@ -0,0 +1,280 @@ +import base64 +import os +import pathlib +import pickle +import sys +import typing as t +from dataclasses import dataclass, field + +from ...commands import Command +from 
.utils.helpers import check_src_and_dest_path + +# pylint: disable-next=invalid-name +entry_point_path = "smartsim._core.entrypoints.file_operations" +"""Path to file operations module""" + +# pylint: disable-next=invalid-name +copy_cmd = "copy" +"""Copy file operation command""" +# pylint: disable-next=invalid-name +symlink_cmd = "symlink" +"""Symlink file operation command""" +# pylint: disable-next=invalid-name +configure_cmd = "configure" +"""Configure file operation command""" + +# pylint: disable-next=invalid-name +default_tag = ";" +"""Default configure tag""" + + +def _create_dest_path(job_run_path: pathlib.Path, dest: pathlib.Path) -> str: + """Combine the job run path and destination path. Return as a string for + entry point consumption. + + :param job_run_path: Job run path + :param dest: Destination path + :return: Combined path + """ + return str(job_run_path / dest) + + +def _check_run_path(run_path: pathlib.Path) -> None: + """Validate that the provided run path is of type pathlib.Path + + :param run_path: The run path to be checked + :raises TypeError: If either run path is not an instance of pathlib.Path + :raises ValueError: If the run path is not a directory + """ + if not isinstance(run_path, pathlib.Path): + raise TypeError( + f"The Job's run path must be of type pathlib.Path, not {type(run_path).__name__}" + ) + if not run_path.is_absolute(): + raise ValueError(f"The Job's run path must be absolute.") + + +class GenerationContext: + """Context for file system generation operations.""" + + def __init__(self, job_run_path: pathlib.Path): + """Initialize a GenerationContext object + + :param job_run_path: Job's run path + """ + _check_run_path(job_run_path) + self.job_run_path = job_run_path + """The Job run path""" + + +class GenerationProtocol(t.Protocol): + """Protocol for Generation Operations.""" + + def format(self, context: GenerationContext) -> Command: + """Return a formatted Command.""" + + +class CopyOperation(GenerationProtocol): + 
"""Copy Operation""" + + def __init__( + self, src: pathlib.Path, dest: t.Optional[pathlib.Path] = None + ) -> None: + """Initialize a CopyOperation object + + :param src: Path to source + :param dest: Path to destination + """ + check_src_and_dest_path(src, dest) + self.src = src + """Path to source""" + self.dest = dest or pathlib.Path(src.name) + """Path to destination""" + + def format(self, context: GenerationContext) -> Command: + """Create Command to invoke copy file system entry point + + :param context: Context for copy operation + :return: Copy Command + """ + final_dest = _create_dest_path(context.job_run_path, self.dest) + return Command( + [ + sys.executable, + "-m", + entry_point_path, + copy_cmd, + str(self.src), + final_dest, + "--dirs_exist_ok", + ] + ) + + +class SymlinkOperation(GenerationProtocol): + """Symlink Operation""" + + def __init__( + self, src: pathlib.Path, dest: t.Optional[pathlib.Path] = None + ) -> None: + """Initialize a SymlinkOperation object + + :param src: Path to source + :param dest: Path to destination + """ + check_src_and_dest_path(src, dest) + self.src = src + """Path to source""" + self.dest = dest or pathlib.Path(src.name) + """Path to destination""" + + def format(self, context: GenerationContext) -> Command: + """Create Command to invoke symlink file system entry point + + :param context: Context for symlink operation + :return: Symlink Command + """ + normalized_path = os.path.normpath(self.src) + parent_dir = os.path.dirname(normalized_path) + final_dest = _create_dest_path(context.job_run_path, self.dest) + new_dest = os.path.join(final_dest, parent_dir) + return Command( + [ + sys.executable, + "-m", + entry_point_path, + symlink_cmd, + str(self.src), + new_dest, + ] + ) + + +class ConfigureOperation(GenerationProtocol): + """Configure Operation""" + + def __init__( + self, + src: pathlib.Path, + file_parameters: t.Mapping[str, str], + dest: t.Optional[pathlib.Path] = None, + tag: t.Optional[str] = None, + ) -> 
None: + """Initialize a ConfigureOperation + + :param src: Path to source + :param file_parameters: File parameters to find and replace + :param dest: Path to destination + :param tag: Tag to use for find and replacement + """ + check_src_and_dest_path(src, dest) + self.src = src + """Path to source""" + self.dest = dest or pathlib.Path(src.name) + """Path to destination""" + pickled_dict = pickle.dumps(file_parameters) + encoded_dict = base64.b64encode(pickled_dict).decode("ascii") + self.file_parameters = encoded_dict + """File parameters to find and replace""" + self.tag = tag if tag else default_tag + """Tag to use for find and replacement""" + + def format(self, context: GenerationContext) -> Command: + """Create Command to invoke configure file system entry point + + :param context: Context for configure operation + :return: Configure Command + """ + final_dest = _create_dest_path(context.job_run_path, self.dest) + return Command( + [ + sys.executable, + "-m", + entry_point_path, + configure_cmd, + str(self.src), + final_dest, + self.tag, + self.file_parameters, + ] + ) + + +GenerationProtocolT = t.TypeVar("GenerationProtocolT", bound=GenerationProtocol) + + +@dataclass +class FileSysOperationSet: + """Dataclass to represent a set of file system operation objects""" + + operations: list[GenerationProtocol] = field(default_factory=list) + """Set of file system objects that match the GenerationProtocol""" + + def add_copy( + self, src: pathlib.Path, dest: t.Optional[pathlib.Path] = None + ) -> None: + """Add a copy operation to the operations list + + :param src: Path to source + :param dest: Path to destination + """ + self.operations.append(CopyOperation(src, dest)) + + def add_symlink( + self, src: pathlib.Path, dest: t.Optional[pathlib.Path] = None + ) -> None: + """Add a symlink operation to the operations list + + :param src: Path to source + :param dest: Path to destination + """ + self.operations.append(SymlinkOperation(src, dest)) + + def 
add_configuration( + self, + src: pathlib.Path, + file_parameters: t.Mapping[str, str], + dest: t.Optional[pathlib.Path] = None, + tag: t.Optional[str] = None, + ) -> None: + """Add a configure operation to the operations list + + :param src: Path to source + :param file_parameters: File parameters to find and replace + :param dest: Path to destination + :param tag: Tag to use for find and replacement + """ + self.operations.append(ConfigureOperation(src, file_parameters, dest, tag)) + + @property + def copy_operations(self) -> list[CopyOperation]: + """Property to get the list of copy files. + + :return: List of CopyOperation objects + """ + return self._filter(CopyOperation) + + @property + def symlink_operations(self) -> list[SymlinkOperation]: + """Property to get the list of symlink files. + + :return: List of SymlinkOperation objects + """ + return self._filter(SymlinkOperation) + + @property + def configure_operations(self) -> list[ConfigureOperation]: + """Property to get the list of configure files. + + :return: List of ConfigureOperation objects + """ + return self._filter(ConfigureOperation) + + def _filter(self, type_: type[GenerationProtocolT]) -> list[GenerationProtocolT]: + """Filters the operations list to include only instances of the + specified type. + + :param type: The type of operations to filter + :return: A list of operations that are instances of the specified type + """ + return [x for x in self.operations if isinstance(x, type_)] diff --git a/smartsim/_core/generation/operations/utils/helpers.py b/smartsim/_core/generation/operations/utils/helpers.py new file mode 100644 index 0000000000..9d99b0e8bf --- /dev/null +++ b/smartsim/_core/generation/operations/utils/helpers.py @@ -0,0 +1,27 @@ +import pathlib +import typing as t + + +def check_src_and_dest_path( + src: pathlib.Path, dest: t.Union[pathlib.Path, None] +) -> None: + """Validate that the provided source and destination paths are + of type pathlib.Path. 
Additionally, validate that destination is a + relative Path and source is a absolute Path. + + :param src: The source path to check + :param dest: The destination path to check + :raises TypeError: If either src or dest is not of type pathlib.Path + :raises ValueError: If source is not an absolute Path or if destination is not + a relative Path + """ + if not isinstance(src, pathlib.Path): + raise TypeError(f"src must be of type pathlib.Path, not {type(src).__name__}") + if dest is not None and not isinstance(dest, pathlib.Path): + raise TypeError( + f"dest must be of type pathlib.Path or None, not {type(dest).__name__}" + ) + if dest is not None and dest.is_absolute(): + raise ValueError(f"dest must be a relative Path") + if not src.is_absolute(): + raise ValueError(f"src must be an absolute Path") diff --git a/smartsim/_core/launcher/__init__.py b/smartsim/_core/launcher/__init__.py index c6584ee3d9..3047aaed48 100644 --- a/smartsim/_core/launcher/__init__.py +++ b/smartsim/_core/launcher/__init__.py @@ -24,13 +24,13 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from .dragon.dragonLauncher import DragonLauncher +from .dragon.dragon_launcher import DragonLauncher from .launcher import Launcher from .local.local import LocalLauncher -from .lsf.lsfLauncher import LSFLauncher -from .pbs.pbsLauncher import PBSLauncher -from .sge.sgeLauncher import SGELauncher -from .slurm.slurmLauncher import SlurmLauncher +from .lsf.lsf_launcher import LSFLauncher +from .pbs.pbs_launcher import PBSLauncher +from .sge.sge_launcher import SGELauncher +from .slurm.slurm_launcher import SlurmLauncher __all__ = [ "Launcher", diff --git a/smartsim/_core/launcher/colocated.py b/smartsim/_core/launcher/colocated.py deleted file mode 100644 index c69a9cef16..0000000000 --- a/smartsim/_core/launcher/colocated.py +++ /dev/null @@ -1,244 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024 Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import sys -import typing as t - -from ...entity.dbobject import DBModel, DBScript -from ...error import SSInternalError -from ..config import CONFIG -from ..utils.helpers import create_lockfile_name - - -def write_colocated_launch_script( - file_name: str, db_log: str, colocated_settings: t.Dict[str, t.Any] -) -> None: - """Write the colocated launch script - - This file will be written into the cwd of the step that - is created for this entity. - - :param file_name: name of the script to write - :param db_log: log file for the db - :param colocated_settings: db settings from entity run_settings - """ - - colocated_cmd = _build_colocated_wrapper_cmd(db_log, **colocated_settings) - - with open(file_name, "w", encoding="utf-8") as script_file: - script_file.write("#!/bin/bash\n") - script_file.write("set -e\n\n") - - script_file.write("Cleanup () {\n") - script_file.write("if ps -p $DBPID > /dev/null; then\n") - script_file.write("\tkill -15 $DBPID\n") - script_file.write("fi\n}\n\n") - - # run cleanup after all exitcodes - script_file.write("trap Cleanup exit\n\n") - - # force entrypoint to write some debug information to the - # STDOUT of the job - if colocated_settings["debug"]: - script_file.write("export SMARTSIM_LOG_LEVEL=debug\n") - script_file.write(f"db_stdout=$({colocated_cmd})\n") - # extract and set DBPID within the shell script that is - # enclosed between __PID__ and sent to stdout by the colocated - # entrypoints file - 
script_file.write( - "DBPID=$(echo $db_stdout | sed -n " - "'s/.*__PID__\\([0-9]*\\)__PID__.*/\\1/p')\n" - ) - - # Write the actual launch command for the app - script_file.write("$@\n\n") - - -def _build_colocated_wrapper_cmd( - db_log: str, - cpus: int = 1, - rai_args: t.Optional[t.Dict[str, str]] = None, - extra_db_args: t.Optional[t.Dict[str, str]] = None, - port: int = 6780, - ifname: t.Optional[t.Union[str, t.List[str]]] = None, - custom_pinning: t.Optional[str] = None, - **kwargs: t.Any, -) -> str: - """Build the command use to run a colocated DB application - - :param db_log: log file for the db - :param cpus: db cpus - :param rai_args: redisai args - :param extra_db_args: extra redis args - :param port: port to bind DB to - :param ifname: network interface(s) to bind DB to - :param db_cpu_list: The list of CPUs that the database should be limited to - :return: the command to run - """ - # pylint: disable=too-many-locals - - # create unique lockfile name to avoid symlink vulnerability - # this is the lockfile all the processes in the distributed - # application will try to acquire. since we use a local tmp - # directory on the compute node, only one process can acquire - # the lock on the file. 
- lockfile = create_lockfile_name() - - # create the command that will be used to launch the - # database with the python entrypoint for starting - # up the backgrounded db process - - cmd = [ - sys.executable, - "-m", - "smartsim._core.entrypoints.colocated", - "+lockfile", - lockfile, - "+db_cpus", - str(cpus), - ] - # Add in the interface if using TCP/IP - if ifname: - if isinstance(ifname, str): - ifname = [ifname] - cmd.extend(["+ifname", ",".join(ifname)]) - cmd.append("+command") - # collect DB binaries and libraries from the config - - db_cmd = [] - if custom_pinning: - db_cmd.extend(["taskset", "-c", custom_pinning]) - db_cmd.extend( - [CONFIG.database_exe, CONFIG.database_conf, "--loadmodule", CONFIG.redisai] - ) - - # add extra redisAI configurations - for arg, value in (rai_args or {}).items(): - if value: - # RAI wants arguments for inference in all caps - # ex. THREADS_PER_QUEUE=1 - db_cmd.append(f"{arg.upper()} {str(value)}") - - db_cmd.extend(["--port", str(port)]) - - # Add socket and permissions for UDS - unix_socket = kwargs.get("unix_socket", None) - socket_permissions = kwargs.get("socket_permissions", None) - - if unix_socket and socket_permissions: - db_cmd.extend( - [ - "--unixsocket", - str(unix_socket), - "--unixsocketperm", - str(socket_permissions), - ] - ) - elif bool(unix_socket) ^ bool(socket_permissions): - raise SSInternalError( - "`unix_socket` and `socket_permissions` must both be defined or undefined." - ) - - db_cmd.extend( - ["--logfile", db_log] - ) # usually /dev/null, unless debug was specified - if extra_db_args: - for db_arg, value in extra_db_args.items(): - # replace "_" with "-" in the db_arg because we use kwargs - # for the extra configurations and Python doesn't allow a hyphen - # in a variable name. All redis and KeyDB configuration options - # use hyphens in their names. 
- db_arg = db_arg.replace("_", "-") - db_cmd.extend([f"--{db_arg}", value]) - - db_models = kwargs.get("db_models", None) - if db_models: - db_model_cmd = _build_db_model_cmd(db_models) - db_cmd.extend(db_model_cmd) - - db_scripts = kwargs.get("db_scripts", None) - if db_scripts: - db_script_cmd = _build_db_script_cmd(db_scripts) - db_cmd.extend(db_script_cmd) - - cmd.extend(db_cmd) - - return " ".join(cmd) - - -def _build_db_model_cmd(db_models: t.List[DBModel]) -> t.List[str]: - cmd = [] - for db_model in db_models: - cmd.append("+db_model") - cmd.append(f"--name={db_model.name}") - - # Here db_model.file is guaranteed to exist - # because we don't allow the user to pass a serialized DBModel - cmd.append(f"--file={db_model.file}") - - cmd.append(f"--backend={db_model.backend}") - cmd.append(f"--device={db_model.device}") - cmd.append(f"--devices_per_node={db_model.devices_per_node}") - cmd.append(f"--first_device={db_model.first_device}") - if db_model.batch_size: - cmd.append(f"--batch_size={db_model.batch_size}") - if db_model.min_batch_size: - cmd.append(f"--min_batch_size={db_model.min_batch_size}") - if db_model.min_batch_timeout: - cmd.append(f"--min_batch_timeout={db_model.min_batch_timeout}") - if db_model.tag: - cmd.append(f"--tag={db_model.tag}") - if db_model.inputs: - cmd.append("--inputs=" + ",".join(db_model.inputs)) - if db_model.outputs: - cmd.append("--outputs=" + ",".join(db_model.outputs)) - - return cmd - - -def _build_db_script_cmd(db_scripts: t.List[DBScript]) -> t.List[str]: - cmd = [] - for db_script in db_scripts: - cmd.append("+db_script") - cmd.append(f"--name={db_script.name}") - if db_script.func: - # Notice that here db_script.func is guaranteed to be a str - # because we don't allow the user to pass a serialized function - func = db_script.func - sanitized_func = func.replace("\n", "\\n") - if not ( - sanitized_func.startswith("'") - and sanitized_func.endswith("'") - or (sanitized_func.startswith('"') and 
sanitized_func.endswith('"')) - ): - sanitized_func = '"' + sanitized_func + '"' - cmd.append(f"--func={sanitized_func}") - elif db_script.file: - cmd.append(f"--file={db_script.file}") - cmd.append(f"--device={db_script.device}") - cmd.append(f"--devices_per_node={db_script.devices_per_node}") - cmd.append(f"--first_device={db_script.first_device}") - return cmd diff --git a/smartsim/_core/launcher/dragon/dragonBackend.py b/smartsim/_core/launcher/dragon/dragon_backend.py similarity index 96% rename from smartsim/_core/launcher/dragon/dragonBackend.py rename to smartsim/_core/launcher/dragon/dragon_backend.py index 5e01299141..82863d73b5 100644 --- a/smartsim/_core/launcher/dragon/dragonBackend.py +++ b/smartsim/_core/launcher/dragon/dragon_backend.py @@ -59,8 +59,10 @@ # pylint: enable=import-error # isort: on -from ...._core.config import get_config -from ...._core.schemas import ( +from ....log import get_logger +from ....status import TERMINAL_STATUSES, JobStatus +from ...config import get_config +from ...schemas import ( DragonHandshakeRequest, DragonHandshakeResponse, DragonRequest, @@ -74,9 +76,7 @@ DragonUpdateStatusRequest, DragonUpdateStatusResponse, ) -from ...._core.utils.helpers import create_short_id_str -from ....log import get_logger -from ....status import TERMINAL_STATUSES, SmartSimStatus +from ...utils.helpers import create_short_id_str logger = get_logger(__name__) @@ -91,7 +91,7 @@ def __str__(self) -> str: @dataclass class ProcessGroupInfo: - status: SmartSimStatus + status: JobStatus """Status of step""" process_group: t.Optional[dragon_process_group.ProcessGroup] = None """Internal Process Group object, None for finished or not started steps""" @@ -105,7 +105,7 @@ class ProcessGroupInfo: """Workers used to redirect stdout and stderr to file""" @property - def smartsim_info(self) -> t.Tuple[SmartSimStatus, t.Optional[t.List[int]]]: + def smartsim_info(self) -> t.Tuple[JobStatus, t.Optional[t.List[int]]]: """Information needed by SmartSim 
Launcher and Job Manager""" return (self.status, self.return_codes) @@ -546,7 +546,7 @@ def _stop_steps(self) -> None: except Exception as e: logger.error(e) - self._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED + self._group_infos[step_id].status = JobStatus.CANCELLED self._group_infos[step_id].return_codes = [-9] def _create_backbone(self) -> BackboneFeatureStore: @@ -708,10 +708,10 @@ def _start_steps(self) -> None: try: grp.init() grp.start() - grp_status = SmartSimStatus.STATUS_RUNNING + grp_status = JobStatus.RUNNING except Exception as e: logger.error(e) - grp_status = SmartSimStatus.STATUS_FAILED + grp_status = JobStatus.FAILED puids = None try: @@ -733,7 +733,7 @@ def _start_steps(self) -> None: if ( puids is not None and len(puids) == len(policies) - and grp_status == SmartSimStatus.STATUS_RUNNING + and grp_status == JobStatus.RUNNING ): redir_grp = DragonBackend._create_redirect_workers( global_policy, @@ -750,7 +750,7 @@ def _start_steps(self) -> None: f"Could not redirect stdout and stderr for PUIDS {puids}" ) from e self._group_infos[step_id].redir_workers = redir_grp - elif puids is not None and grp_status == SmartSimStatus.STATUS_RUNNING: + elif puids is not None and grp_status == JobStatus.RUNNING: logger.error("Cannot redirect workers: some PUIDS are missing") if started: @@ -776,11 +776,11 @@ def _refresh_statuses(self) -> None: group_info = self._group_infos[step_id] grp = group_info.process_group if grp is None: - group_info.status = SmartSimStatus.STATUS_FAILED + group_info.status = JobStatus.FAILED group_info.return_codes = [-1] elif group_info.status not in TERMINAL_STATUSES: if grp.status == str(DragonStatus.RUNNING): - group_info.status = SmartSimStatus.STATUS_RUNNING + group_info.status = JobStatus.RUNNING else: puids = group_info.puids if puids is not None and all( @@ -796,12 +796,12 @@ def _refresh_statuses(self) -> None: group_info.return_codes = [-1 for _ in puids] else: group_info.return_codes = [0] - if not 
group_info.status == SmartSimStatus.STATUS_CANCELLED: + if not group_info.status == JobStatus.CANCELLED: group_info.status = ( - SmartSimStatus.STATUS_FAILED + JobStatus.FAILED if any(group_info.return_codes) or grp.status == DragonStatus.ERROR - else SmartSimStatus.STATUS_COMPLETED + else JobStatus.COMPLETED ) if group_info.status in TERMINAL_STATUSES: @@ -905,13 +905,11 @@ def _(self, request: DragonRunRequest) -> DragonRunResponse: honorable, err = self._can_honor(request) if not honorable: self._group_infos[step_id] = ProcessGroupInfo( - status=SmartSimStatus.STATUS_FAILED, return_codes=[-1] + status=JobStatus.FAILED, return_codes=[-1] ) else: self._queued_steps[step_id] = request - self._group_infos[step_id] = ProcessGroupInfo( - status=SmartSimStatus.STATUS_NEVER_STARTED - ) + self._group_infos[step_id] = ProcessGroupInfo(status=JobStatus.NEW) return DragonRunResponse(step_id=step_id, error_message=err) @process_request.register diff --git a/smartsim/_core/launcher/dragon/dragonConnector.py b/smartsim/_core/launcher/dragon/dragon_connector.py similarity index 92% rename from smartsim/_core/launcher/dragon/dragonConnector.py rename to smartsim/_core/launcher/dragon/dragon_connector.py index 1144b7764e..9c96592776 100644 --- a/smartsim/_core/launcher/dragon/dragonConnector.py +++ b/smartsim/_core/launcher/dragon/dragon_connector.py @@ -42,7 +42,6 @@ import zmq import zmq.auth.thread -from ...._core.launcher.dragon import dragonSockets from ....error.errors import SmartSimError from ....log import get_logger from ...config import get_config @@ -56,6 +55,12 @@ DragonShutdownRequest, ) from ...utils.network import find_free_port, get_best_interface_and_address +from . import dragon_sockets + +if t.TYPE_CHECKING: + from typing_extensions import Self + + from smartsim.experiment import Experiment logger = get_logger(__name__) @@ -69,7 +74,7 @@ class DragonConnector: to start a Dragon server and communicate with it. 
""" - def __init__(self) -> None: + def __init__(self, path: str | os.PathLike[str]) -> None: self._context: zmq.Context[t.Any] = zmq.Context.instance() """ZeroMQ context used to share configuration across requests""" self._context.setsockopt(zmq.REQ_CORRELATE, 1) @@ -78,6 +83,12 @@ def __init__(self) -> None: """ZeroMQ authenticator used to secure queue access""" config = get_config() self._reset_timeout(config.dragon_server_timeout) + + # TODO: We should be able to make these "non-optional" + # by simply moving the impl of + # `DragonConnectior.connect_to_dragon` to this method. This is + # fine as we expect the that method should only be called once + # without hitting a guard clause. self._dragon_head_socket: t.Optional[zmq.Socket[t.Any]] = None """ZeroMQ socket exposing the connection to the DragonBackend""" self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] = None @@ -86,18 +97,10 @@ def __init__(self) -> None: # but process was started by another connector self._dragon_head_pid: t.Optional[int] = None """Process ID of the process executing the DragonBackend""" - self._dragon_server_path = config.dragon_server_path + self._dragon_server_path = _resolve_dragon_path(path) """Path to a dragon installation""" logger.debug(f"Dragon Server path was set to {self._dragon_server_path}") self._env_vars: t.Dict[str, str] = {} - if self._dragon_server_path is None: - raise SmartSimError( - "DragonConnector could not find the dragon server path. " - "This should not happen if the Connector was started by an " - "experiment.\nIf the DragonConnector was started manually, " - "then the environment variable SMARTSIM_DRAGON_SERVER_PATH " - "should be set to an existing directory." 
- ) @property def is_connected(self) -> bool: @@ -122,7 +125,7 @@ def _handshake(self, address: str) -> None: :param address: The address of the head node socket to initiate a handhake with """ - self._dragon_head_socket = dragonSockets.get_secure_socket( + self._dragon_head_socket = dragon_sockets.get_secure_socket( self._context, zmq.REQ, False ) self._dragon_head_socket.connect(address) @@ -190,7 +193,7 @@ def _get_new_authenticator( except zmq.Again: logger.debug("Could not stop authenticator") try: - self._authenticator = dragonSockets.get_authenticator( + self._authenticator = dragon_sockets.get_authenticator( self._context, timeout ) return @@ -251,7 +254,9 @@ def _start_connector_socket(self, socket_addr: str) -> zmq.Socket[t.Any]: connector_socket: t.Optional[zmq.Socket[t.Any]] = None self._reset_timeout(config.dragon_server_startup_timeout) self._get_new_authenticator(-1) - connector_socket = dragonSockets.get_secure_socket(self._context, zmq.REP, True) + connector_socket = dragon_sockets.get_secure_socket( + self._context, zmq.REP, True + ) logger.debug(f"Binding connector to {socket_addr}") connector_socket.bind(socket_addr) if connector_socket is None: @@ -328,8 +333,7 @@ def connect_to_dragon(self) -> None: "Establishing connection with Dragon server or starting a new one..." ) - path = _resolve_dragon_path(self._dragon_server_path) - + path = self._dragon_server_path self._connect_to_existing_server(path) if self.is_connected: return @@ -386,7 +390,7 @@ def connect_to_dragon(self) -> None: start_new_session=True, ) - server = dragonSockets.as_server(connector_socket) + server = dragon_sockets.as_server(connector_socket) logger.debug(f"Listening to {socket_addr}") request = _assert_schema_type(server.recv(), DragonBootstrapRequest) server.send( @@ -520,7 +524,7 @@ def _send_req_with_socket( allow the receiver to immediately respond to the sent request. 
:returns: The response from the target """ - client = dragonSockets.as_client(socket) + client = dragon_sockets.as_client(socket) with DRG_LOCK: logger.debug(f"Sending {type(request).__name__}: {request}") client.send(request, send_flags) @@ -589,14 +593,25 @@ def _dragon_cleanup( def _resolve_dragon_path(fallback: t.Union[str, "os.PathLike[str]"]) -> Path: - """Determine the applicable dragon server path for the connector + """Return the path at which a user should set up a dragon server. + + The order of path resolution is: + 1) If the the user has set a global dragon path via + `Config.dragon_server_path` use that without alteration. + 2) Use the `fallback` path which should be the path to an existing + directory. Append the default dragon server subdirectory defined by + `Config.dragon_default_subdir` + + Currently this function will raise if a user attempts to specify multiple + dragon server paths via `:` seperation. - :param fallback: A default dragon server path to use if one is not - found in the runtime configuration - :returns: The path to the dragon libraries + :param fallback: The path to an existing directory on the file system to + use if the global dragon directory is not set. + :returns: The path to directory in which the dragon server should run. 
""" - dragon_server_path = get_config().dragon_server_path or os.path.join( - fallback, ".smartsim", "dragon" + config = get_config() + dragon_server_path = config.dragon_server_path or os.path.join( + fallback, config.dragon_default_subdir ) dragon_server_paths = dragon_server_path.split(":") if len(dragon_server_paths) > 1: diff --git a/smartsim/_core/launcher/dragon/dragonLauncher.py b/smartsim/_core/launcher/dragon/dragon_launcher.py similarity index 71% rename from smartsim/_core/launcher/dragon/dragonLauncher.py rename to smartsim/_core/launcher/dragon/dragon_launcher.py index 75ca675225..752b6c2495 100644 --- a/smartsim/_core/launcher/dragon/dragonLauncher.py +++ b/smartsim/_core/launcher/dragon/dragon_launcher.py @@ -27,11 +27,13 @@ from __future__ import annotations import os +import pathlib import typing as t -from smartsim._core.schemas.dragonRequests import DragonRunPolicy +from smartsim._core.schemas.dragon_requests import DragonRunPolicy +from smartsim.error import errors +from smartsim.types import LaunchedJobID -from ...._core.launcher.stepMapping import StepMap from ....error import LauncherError, SmartSimError from ....log import get_logger from ....settings import ( @@ -41,9 +43,10 @@ SbatchSettings, SettingsBase, ) -from ....status import SmartSimStatus +from ....status import JobStatus from ...schemas import ( DragonRunRequest, + DragonRunRequestView, DragonRunResponse, DragonStopRequest, DragonStopResponse, @@ -51,11 +54,18 @@ DragonUpdateStatusResponse, ) from ..launcher import WLMLauncher -from ..pbs.pbsLauncher import PBSLauncher -from ..slurm.slurmLauncher import SlurmLauncher +from ..pbs.pbs_launcher import PBSLauncher +from ..slurm.slurm_launcher import SlurmLauncher from ..step import DragonBatchStep, DragonStep, LocalStep, Step -from ..stepInfo import StepInfo -from .dragonConnector import DragonConnector, _SchemaT +from ..step_info import StepInfo +from ..step_mapping import StepMap +from .dragon_connector import DragonConnector, 
_SchemaT + +if t.TYPE_CHECKING: + from typing_extensions import Self + + from smartsim.experiment import Experiment + logger = get_logger(__name__) @@ -74,9 +84,9 @@ class DragonLauncher(WLMLauncher): the Job Manager to interact with it. """ - def __init__(self) -> None: + def __init__(self, server_path: str | os.PathLike[str]) -> None: super().__init__() - self._connector = DragonConnector() + self._connector = DragonConnector(server_path) """Connector used to start and interact with the Dragon server""" self._slurm_launcher = SlurmLauncher() """Slurm sub-launcher, used only for batch jobs""" @@ -121,6 +131,28 @@ def add_step_to_mapping_table(self, name: str, step_map: StepMap) -> None: ) sublauncher.add_step_to_mapping_table(name, sublauncher_step_map) + @classmethod + def create(cls, exp: Experiment) -> Self: + self = cls(exp.exp_path) + self._connector.connect_to_dragon() # pylint: disable=protected-access + return self + + def start( + self, args_and_policy: tuple[DragonRunRequestView, DragonRunPolicy] + ) -> LaunchedJobID: + req_args, policy = args_and_policy + self._connector.load_persisted_env() + merged_env = self._connector.merge_persisted_env(os.environ.copy()) + req = DragonRunRequest(**dict(req_args), current_env=merged_env, policy=policy) + res = _assert_schema_type(self._connector.send_request(req), DragonRunResponse) + return LaunchedJobID(res.step_id) + + def get_status( + self, *launched_ids: LaunchedJobID + ) -> t.Mapping[LaunchedJobID, JobStatus]: + infos = self._get_managed_step_update(list(launched_ids)) + return {id_: info.status for id_, info in zip(launched_ids, infos)} + def run(self, step: Step) -> t.Optional[str]: """Run a job step through Slurm @@ -167,16 +199,14 @@ def run(self, step: Step) -> t.Optional[str]: run_args = step.run_settings.run_args req_env = step.run_settings.env_vars self._connector.load_persisted_env() - merged_env = self._connector.merge_persisted_env(os.environ.copy()) nodes = int(run_args.get("nodes", None) or 1) 
tasks_per_node = int(run_args.get("tasks-per-node", None) or 1) hosts = run_args.get("host-list", None) policy = DragonRunPolicy.from_run_args(run_args) - - response = _assert_schema_type( - self._connector.send_request( - DragonRunRequest( + step_id = self.start( + ( + DragonRunRequestView( exe=cmd[0], exe_args=cmd[1:], path=step.cwd, @@ -184,16 +214,13 @@ def run(self, step: Step) -> t.Optional[str]: nodes=nodes, tasks_per_node=tasks_per_node, env=req_env, - current_env=merged_env, output_file=out, error_file=err, - policy=policy, hostlist=hosts, - ) - ), - DragonRunResponse, + ), + policy, + ) ) - step_id = str(response.step_id) else: # pylint: disable-next=consider-using-with out_strm = open(out, "w+", encoding="utf-8") @@ -234,11 +261,23 @@ def stop(self, step_name: str) -> StepInfo: raise LauncherError(f"Could not get step_info for job step {step_name}") step_info.status = ( - SmartSimStatus.STATUS_CANCELLED # set status to cancelled instead of failed + JobStatus.CANCELLED # set status to cancelled instead of failed ) - step_info.launcher_status = str(SmartSimStatus.STATUS_CANCELLED) + step_info.launcher_status = str(JobStatus.CANCELLED) return step_info + def stop_jobs( + self, *launched_ids: LaunchedJobID + ) -> t.Mapping[LaunchedJobID, JobStatus]: + """Take a collection of job ids and issue stop requests to the dragon + backend for each. + + :param launched_ids: The ids of the launched jobs to stop. + :returns: A mapping of ids for jobs to stop to their reported status + after attempting to stop them. 
+ """ + return {id_: self.stop(id_).status for id_ in launched_ids} + @staticmethod def _unprefix_step_id(step_id: str) -> str: return step_id.split("-", maxsplit=1)[1] @@ -296,8 +335,8 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: msg += response.error_message logger.error(msg) info = StepInfo( - SmartSimStatus.STATUS_FAILED, - SmartSimStatus.STATUS_FAILED.value, + JobStatus.FAILED, + JobStatus.FAILED.value, -1, ) else: @@ -316,8 +355,12 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: step_id_updates[step_id] = info - # Order matters as we return an ordered list of StepInfo objects - return [step_id_updates[step_id] for step_id in step_ids] + try: + # Order matters as we return an ordered list of StepInfo objects + return [step_id_updates[step_id] for step_id in step_ids] + except KeyError: + msg = "A step info could not be found for one or more of the requested ids" + raise errors.LauncherJobNotFound(msg) from None def __str__(self) -> str: return "Dragon" @@ -327,3 +370,51 @@ def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT: if not isinstance(obj, typ): raise TypeError(f"Expected schema of type `{typ}`, but got {type(obj)}") return obj + + +from smartsim._core.dispatch import dispatch + +# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> +# TODO: Remove this registry and move back to builder file after fixing +# circular import caused by `DragonLauncher.supported_rs` +# ----------------------------------------------------------------------------- +from smartsim.settings.arguments.launch.dragon import DragonLaunchArguments + + +def _as_run_request_args_and_policy( + run_req_args: DragonLaunchArguments, + exe: t.Sequence[str], + path: str | os.PathLike[str], + env: t.Mapping[str, str | None], + stdout_path: pathlib.Path, + stderr_path: pathlib.Path, +) -> tuple[DragonRunRequestView, DragonRunPolicy]: + # 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + # FIXME: This type is 100% unacceptable, but I don't want to spend too much + # time on fixing the dragon launcher API. Something that we need to + # revisit in the future though. + exe_, *args = exe + run_args = dict[str, "int | str | float | None"](run_req_args._launch_args) + policy = DragonRunPolicy.from_run_args(run_args) + return ( + DragonRunRequestView( + exe=exe_, + exe_args=args, + path=path, + env=env, + # TODO: Not sure how this info is injected + name=None, + output_file=stdout_path, + error_file=stderr_path, + **run_args, + ), + policy, + ) + + +dispatch( + DragonLaunchArguments, + with_format=_as_run_request_args_and_policy, + to_launcher=DragonLauncher, +) +# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< diff --git a/smartsim/_core/launcher/dragon/dragonSockets.py b/smartsim/_core/launcher/dragon/dragon_sockets.py similarity index 97% rename from smartsim/_core/launcher/dragon/dragonSockets.py rename to smartsim/_core/launcher/dragon/dragon_sockets.py index 80acd61a2a..0ffe857e6d 100644 --- a/smartsim/_core/launcher/dragon/dragonSockets.py +++ b/smartsim/_core/launcher/dragon/dragon_sockets.py @@ -30,8 +30,8 @@ import zmq.auth.thread from smartsim._core.config.config import get_config -from smartsim._core.schemas import dragonRequests as _dragonRequests -from smartsim._core.schemas import dragonResponses as _dragonResponses +from smartsim._core.schemas import dragon_requests as _dragonRequests +from smartsim._core.schemas import dragon_responses as _dragonResponses from smartsim._core.schemas import utils as _utils from smartsim._core.utils.security import KeyManager from smartsim.log import get_logger diff --git a/smartsim/_core/launcher/launcher.py b/smartsim/_core/launcher/launcher.py index 1bf768065c..5b2894cf35 100644 --- a/smartsim/_core/launcher/launcher.py +++ b/smartsim/_core/launcher/launcher.py @@ -27,13 +27,13 @@ import abc import typing as t -from 
..._core.launcher.stepMapping import StepMap +from ...entity import SmartSimEntity from ...error import AllocationError, LauncherError, SSUnsupportedError from ...settings import SettingsBase from .step import Step -from .stepInfo import StepInfo, UnmanagedStepInfo -from .stepMapping import StepMapping -from .taskManager import TaskManager +from .step_info import StepInfo, UnmanagedStepInfo +from .step_mapping import StepMap, StepMapping +from .task_manager import TaskManager class Launcher(abc.ABC): # pragma: no cover @@ -49,7 +49,7 @@ class Launcher(abc.ABC): # pragma: no cover task_manager: TaskManager @abc.abstractmethod - def create_step(self, name: str, cwd: str, step_settings: SettingsBase) -> Step: + def create_step(self, entity: SmartSimEntity, step_settings: SettingsBase) -> Step: raise NotImplementedError @abc.abstractmethod @@ -99,7 +99,7 @@ def supported_rs(self) -> t.Dict[t.Type[SettingsBase], t.Type[Step]]: # every launcher utilizing this interface must have a map # of supported RunSettings types (see slurmLauncher.py for ex) def create_step( - self, name: str, cwd: str, step_settings: SettingsBase + self, entity: SmartSimEntity, step_settings: SettingsBase ) -> Step: # cov-wlm """Create a WLM job step @@ -117,7 +117,7 @@ def create_step( f"RunSettings type {type(step_settings)} not supported by this launcher" ) from None try: - return step_class(name, cwd, step_settings) + return step_class(entity, step_settings) except AllocationError as e: raise LauncherError("Step creation failed") from e diff --git a/smartsim/_core/launcher/local/local.py b/smartsim/_core/launcher/local/local.py index ffcb84f231..9a902f003d 100644 --- a/smartsim/_core/launcher/local/local.py +++ b/smartsim/_core/launcher/local/local.py @@ -26,12 +26,13 @@ import typing as t +from ....entity import SmartSimEntity from ....settings import RunSettings, SettingsBase from ..launcher import Launcher from ..step import LocalStep, Step -from ..stepInfo import StepInfo, UnmanagedStepInfo 
-from ..stepMapping import StepMapping -from ..taskManager import TaskManager +from ..step_info import StepInfo, UnmanagedStepInfo +from ..step_mapping import StepMapping +from ..task_manager import TaskManager class LocalLauncher(Launcher): @@ -41,17 +42,18 @@ def __init__(self) -> None: self.task_manager = TaskManager() self.step_mapping = StepMapping() - def create_step(self, name: str, cwd: str, step_settings: SettingsBase) -> Step: + def create_step(self, entity: SmartSimEntity, step_settings: SettingsBase) -> Step: """Create a job step to launch an entity locally :return: Step object """ + # probably need to instead change this to exe and exe_args if not isinstance(step_settings, RunSettings): raise TypeError( "Local Launcher only supports entities with RunSettings, " f"not {type(step_settings)}" ) - return LocalStep(name, cwd, step_settings) + return LocalStep(entity, step_settings) def get_step_update( self, step_names: t.List[str] diff --git a/smartsim/_core/launcher/lsf/lsfCommands.py b/smartsim/_core/launcher/lsf/lsf_commands.py similarity index 100% rename from smartsim/_core/launcher/lsf/lsfCommands.py rename to smartsim/_core/launcher/lsf/lsf_commands.py diff --git a/smartsim/_core/launcher/lsf/lsfLauncher.py b/smartsim/_core/launcher/lsf/lsf_launcher.py similarity index 96% rename from smartsim/_core/launcher/lsf/lsfLauncher.py rename to smartsim/_core/launcher/lsf/lsf_launcher.py index e0ad808ed8..472d66b89b 100644 --- a/smartsim/_core/launcher/lsf/lsfLauncher.py +++ b/smartsim/_core/launcher/lsf/lsf_launcher.py @@ -38,7 +38,7 @@ RunSettings, SettingsBase, ) -from ....status import SmartSimStatus +from ....status import JobStatus from ...config import CONFIG from ..launcher import WLMLauncher from ..step import ( @@ -50,9 +50,9 @@ OrterunStep, Step, ) -from ..stepInfo import LSFBatchStepInfo, LSFJsrunStepInfo, StepInfo -from .lsfCommands import bjobs, bkill, jskill, jslist -from .lsfParser import ( +from ..step_info import LSFBatchStepInfo, 
LSFJsrunStepInfo, StepInfo +from .lsf_commands import bjobs, bkill, jskill, jslist +from .lsf_parser import ( parse_bjobs_jobid, parse_bsub, parse_jslist_stepid, @@ -152,7 +152,7 @@ def stop(self, step_name: str) -> StepInfo: raise LauncherError(f"Could not get step_info for job step {step_name}") step_info.status = ( - SmartSimStatus.STATUS_CANCELLED + JobStatus.CANCELLED ) # set status to cancelled instead of failed return step_info @@ -203,7 +203,7 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: # create LSFBatchStepInfo objects to return batch_info = LSFBatchStepInfo(stat, None) # account for case where job history is not logged by LSF - if batch_info.status == SmartSimStatus.STATUS_COMPLETED: + if batch_info.status == JobStatus.COMPLETED: batch_info.returncode = 0 updates.append(batch_info) return updates diff --git a/smartsim/_core/launcher/lsf/lsfParser.py b/smartsim/_core/launcher/lsf/lsf_parser.py similarity index 100% rename from smartsim/_core/launcher/lsf/lsfParser.py rename to smartsim/_core/launcher/lsf/lsf_parser.py diff --git a/smartsim/_core/launcher/pbs/pbsCommands.py b/smartsim/_core/launcher/pbs/pbs_commands.py similarity index 100% rename from smartsim/_core/launcher/pbs/pbsCommands.py rename to smartsim/_core/launcher/pbs/pbs_commands.py diff --git a/smartsim/_core/launcher/pbs/pbsLauncher.py b/smartsim/_core/launcher/pbs/pbs_launcher.py similarity index 96% rename from smartsim/_core/launcher/pbs/pbsLauncher.py rename to smartsim/_core/launcher/pbs/pbs_launcher.py index 8c2099a8bc..fe8a9538b9 100644 --- a/smartsim/_core/launcher/pbs/pbsLauncher.py +++ b/smartsim/_core/launcher/pbs/pbs_launcher.py @@ -39,7 +39,7 @@ RunSettings, SettingsBase, ) -from ....status import SmartSimStatus +from ....status import JobStatus from ...config import CONFIG from ..launcher import WLMLauncher from ..step import ( @@ -51,9 +51,9 @@ QsubBatchStep, Step, ) -from ..stepInfo import PBSStepInfo, StepInfo -from .pbsCommands import 
qdel, qstat -from .pbsParser import ( +from ..step_info import PBSStepInfo, StepInfo +from .pbs_commands import qdel, qstat +from .pbs_parser import ( parse_qstat_jobid, parse_qstat_jobid_json, parse_step_id_from_qstat, @@ -150,7 +150,7 @@ def stop(self, step_name: str) -> StepInfo: raise LauncherError(f"Could not get step_info for job step {step_name}") step_info.status = ( - SmartSimStatus.STATUS_CANCELLED + JobStatus.CANCELLED ) # set status to cancelled instead of failed return step_info @@ -202,7 +202,7 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: for stat, _ in zip(stats, step_ids): info = PBSStepInfo(stat or "NOTFOUND", None) # account for case where job history is not logged by PBS - if info.status == SmartSimStatus.STATUS_COMPLETED: + if info.status == JobStatus.COMPLETED: info.returncode = 0 updates.append(info) diff --git a/smartsim/_core/launcher/pbs/pbsParser.py b/smartsim/_core/launcher/pbs/pbs_parser.py similarity index 100% rename from smartsim/_core/launcher/pbs/pbsParser.py rename to smartsim/_core/launcher/pbs/pbs_parser.py diff --git a/smartsim/_core/launcher/sge/sgeCommands.py b/smartsim/_core/launcher/sge/sge_commands.py similarity index 100% rename from smartsim/_core/launcher/sge/sgeCommands.py rename to smartsim/_core/launcher/sge/sge_commands.py diff --git a/smartsim/_core/launcher/sge/sgeLauncher.py b/smartsim/_core/launcher/sge/sge_launcher.py similarity index 93% rename from smartsim/_core/launcher/sge/sgeLauncher.py rename to smartsim/_core/launcher/sge/sge_launcher.py index af600cf1d2..82c1f8fe94 100644 --- a/smartsim/_core/launcher/sge/sgeLauncher.py +++ b/smartsim/_core/launcher/sge/sge_launcher.py @@ -37,7 +37,7 @@ SettingsBase, SgeQsubBatchSettings, ) -from ....status import SmartSimStatus +from ....status import JobStatus from ...config import CONFIG from ..launcher import WLMLauncher from ..step import ( @@ -48,9 +48,9 @@ SgeQsubBatchStep, Step, ) -from ..stepInfo import SGEStepInfo, StepInfo 
-from .sgeCommands import qacct, qdel, qstat -from .sgeParser import parse_qacct_job_output, parse_qstat_jobid_xml +from ..step_info import SGEStepInfo, StepInfo +from .sge_commands import qacct, qdel, qstat +from .sge_parser import parse_qacct_job_output, parse_qstat_jobid_xml logger = get_logger(__name__) @@ -137,7 +137,7 @@ def stop(self, step_name: str) -> StepInfo: raise LauncherError(f"Could not get step_info for job step {step_name}") step_info.status = ( - SmartSimStatus.STATUS_CANCELLED + JobStatus.CANCELLED ) # set status to cancelled instead of failed return step_info @@ -166,13 +166,13 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: if qacct_output: failed = bool(int(parse_qacct_job_output(qacct_output, "failed"))) if failed: - info.status = SmartSimStatus.STATUS_FAILED + info.status = JobStatus.FAILED info.returncode = 0 else: - info.status = SmartSimStatus.STATUS_COMPLETED + info.status = JobStatus.COMPLETED info.returncode = 0 else: # Assume if qacct did not find it, that the job completed - info.status = SmartSimStatus.STATUS_COMPLETED + info.status = JobStatus.COMPLETED info.returncode = 0 else: info = SGEStepInfo(stat) diff --git a/smartsim/_core/launcher/sge/sgeParser.py b/smartsim/_core/launcher/sge/sge_parser.py similarity index 100% rename from smartsim/_core/launcher/sge/sgeParser.py rename to smartsim/_core/launcher/sge/sge_parser.py diff --git a/smartsim/_core/launcher/slurm/slurmCommands.py b/smartsim/_core/launcher/slurm/slurm_commands.py similarity index 100% rename from smartsim/_core/launcher/slurm/slurmCommands.py rename to smartsim/_core/launcher/slurm/slurm_commands.py diff --git a/smartsim/_core/launcher/slurm/slurmLauncher.py b/smartsim/_core/launcher/slurm/slurm_launcher.py similarity index 97% rename from smartsim/_core/launcher/slurm/slurmLauncher.py rename to smartsim/_core/launcher/slurm/slurm_launcher.py index 2e41023919..038176d937 100644 --- a/smartsim/_core/launcher/slurm/slurmLauncher.py 
+++ b/smartsim/_core/launcher/slurm/slurm_launcher.py @@ -40,7 +40,7 @@ SettingsBase, SrunSettings, ) -from ....status import SmartSimStatus +from ....status import JobStatus from ...config import CONFIG from ..launcher import WLMLauncher from ..step import ( @@ -52,9 +52,9 @@ SrunStep, Step, ) -from ..stepInfo import SlurmStepInfo, StepInfo -from .slurmCommands import sacct, scancel, sstat -from .slurmParser import parse_sacct, parse_sstat_nodes, parse_step_id_from_sacct +from ..step_info import SlurmStepInfo, StepInfo +from .slurm_commands import sacct, scancel, sstat +from .slurm_parser import parse_sacct, parse_sstat_nodes, parse_step_id_from_sacct logger = get_logger(__name__) @@ -213,7 +213,7 @@ def stop(self, step_name: str) -> StepInfo: raise LauncherError(f"Could not get step_info for job step {step_name}") step_info.status = ( - SmartSimStatus.STATUS_CANCELLED + JobStatus.CANCELLED ) # set status to cancelled instead of failed return step_info diff --git a/smartsim/_core/launcher/slurm/slurmParser.py b/smartsim/_core/launcher/slurm/slurm_parser.py similarity index 100% rename from smartsim/_core/launcher/slurm/slurmParser.py rename to smartsim/_core/launcher/slurm/slurm_parser.py diff --git a/smartsim/_core/launcher/step/__init__.py b/smartsim/_core/launcher/step/__init__.py index 8331a18bf8..b11e54a50d 100644 --- a/smartsim/_core/launcher/step/__init__.py +++ b/smartsim/_core/launcher/step/__init__.py @@ -24,12 +24,12 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from .alpsStep import AprunStep -from .dragonStep import DragonBatchStep, DragonStep -from .localStep import LocalStep -from .lsfStep import BsubBatchStep, JsrunStep -from .mpiStep import MpiexecStep, MpirunStep, OrterunStep -from .pbsStep import QsubBatchStep -from .sgeStep import SgeQsubBatchStep -from .slurmStep import SbatchStep, SrunStep +from .alps_step import AprunStep +from .dragon_step import DragonBatchStep, DragonStep +from .local_step import LocalStep +from .lsf_step import BsubBatchStep, JsrunStep +from .mpi_step import MpiexecStep, MpirunStep, OrterunStep +from .pbs_step import QsubBatchStep +from .sge_step import SgeQsubBatchStep +from .slurm_step import SbatchStep, SrunStep from .step import Step diff --git a/smartsim/_core/launcher/step/alpsStep.py b/smartsim/_core/launcher/step/alps_step.py similarity index 91% rename from smartsim/_core/launcher/step/alpsStep.py rename to smartsim/_core/launcher/step/alps_step.py index eb7903af98..dc9f3bff61 100644 --- a/smartsim/_core/launcher/step/alpsStep.py +++ b/smartsim/_core/launcher/step/alps_step.py @@ -29,6 +29,7 @@ import typing as t from shlex import split as sh_split +from ....entity import Application, FSNode from ....error import AllocationError from ....log import get_logger from ....settings import AprunSettings, RunSettings, Singularity @@ -38,14 +39,16 @@ class AprunStep(Step): - def __init__(self, name: str, cwd: str, run_settings: AprunSettings) -> None: + def __init__( + self, entity: t.Union[Application, FSNode], run_settings: AprunSettings + ) -> None: """Initialize a ALPS aprun job step :param name: name of the entity to be launched :param cwd: path to launch dir :param run_settings: run settings for entity """ - super().__init__(name, cwd, run_settings) + super().__init__(entity, run_settings) self.alloc: t.Optional[str] = None if not run_settings.in_batch: self._set_alloc() @@ -74,9 +77,9 @@ def get_launch_cmd(self) -> t.List[str]: aprun_cmd.extend(self.run_settings.format_env_vars()) 
aprun_cmd.extend(self.run_settings.format_run_args()) - if self.run_settings.colocated_db_settings: + if self.run_settings.colocated_fs_settings: # disable cpu binding as the entrypoint will set that - # for the application and database process now + # for the application and feature store process now aprun_cmd.extend(["--cc", "none"]) # Replace the command with the entrypoint wrapper script @@ -122,15 +125,15 @@ def _build_exe(self) -> t.List[str]: if self._get_mpmd(): return self._make_mpmd() - exe = self.run_settings.exe - args = self.run_settings._exe_args # pylint: disable=protected-access + exe = self.entity.exe + args = self.entity.exe_args return exe + args def _make_mpmd(self) -> t.List[str]: """Build Aprun (MPMD) executable""" - exe = self.run_settings.exe - exe_args = self.run_settings._exe_args # pylint: disable=protected-access + exe = self.entity.exe + exe_args = self.entity.exe_args cmd = exe + exe_args for mpmd in self._get_mpmd(): diff --git a/smartsim/_core/launcher/step/dragonStep.py b/smartsim/_core/launcher/step/dragon_step.py similarity index 98% rename from smartsim/_core/launcher/step/dragonStep.py rename to smartsim/_core/launcher/step/dragon_step.py index 8583ceeb1b..63bc1e6c4b 100644 --- a/smartsim/_core/launcher/step/dragonStep.py +++ b/smartsim/_core/launcher/step/dragon_step.py @@ -30,11 +30,6 @@ import sys import typing as t -from ...._core.schemas.dragonRequests import ( - DragonRunPolicy, - DragonRunRequest, - request_registry, -) from ....error.errors import SSUnsupportedError from ....log import get_logger from ....settings import ( @@ -43,6 +38,11 @@ SbatchSettings, Singularity, ) +from ...schemas.dragon_requests import ( + DragonRunPolicy, + DragonRunRequest, + request_registry, +) from .step import Step logger = get_logger(__name__) @@ -72,7 +72,7 @@ def get_launch_cmd(self) -> t.List[str]: run_settings = self.run_settings exe_cmd = [] - if run_settings.colocated_db_settings: + if run_settings.colocated_fs_settings: # Replace 
the command with the entrypoint wrapper script bash = shutil.which("bash") if not bash: diff --git a/smartsim/_core/launcher/step/localStep.py b/smartsim/_core/launcher/step/local_step.py similarity index 86% rename from smartsim/_core/launcher/step/localStep.py rename to smartsim/_core/launcher/step/local_step.py index 968152a412..49666a2059 100644 --- a/smartsim/_core/launcher/step/localStep.py +++ b/smartsim/_core/launcher/step/local_step.py @@ -28,15 +28,15 @@ import shutil import typing as t -from ....settings import Singularity -from ....settings.base import RunSettings +from ....entity import Application, FSNode +from ....settings import RunSettings, Singularity from .step import Step, proxyable_launch_cmd class LocalStep(Step): - def __init__(self, name: str, cwd: str, run_settings: RunSettings): - super().__init__(name, cwd, run_settings) - self.run_settings = run_settings + def __init__(self, entity: t.Union[Application, FSNode], run_settings: RunSettings): + super().__init__(entity, run_settings) + self.run_settings = entity.run_settings self._env = self._set_env() @property @@ -54,7 +54,7 @@ def get_launch_cmd(self) -> t.List[str]: run_args = self.run_settings.format_run_args() cmd.extend(run_args) - if self.run_settings.colocated_db_settings: + if self.run_settings.colocated_fs_settings: # Replace the command with the entrypoint wrapper script if not (bash := shutil.which("bash")): raise RuntimeError("Unable to locate bash interpreter") @@ -68,9 +68,9 @@ def get_launch_cmd(self) -> t.List[str]: cmd += container._container_cmds(self.cwd) # build executable - cmd.extend(self.run_settings.exe) - if self.run_settings.exe_args: - cmd.extend(self.run_settings.exe_args) + cmd.extend(self.entity.exe) + if self.entity.exe_args: + cmd.extend(self.entity.exe_args) return cmd def _set_env(self) -> t.Dict[str, str]: diff --git a/smartsim/_core/launcher/step/lsfStep.py b/smartsim/_core/launcher/step/lsf_step.py similarity index 95% rename from 
smartsim/_core/launcher/step/lsfStep.py rename to smartsim/_core/launcher/step/lsf_step.py index 0cb921e19a..80583129c1 100644 --- a/smartsim/_core/launcher/step/lsfStep.py +++ b/smartsim/_core/launcher/step/lsf_step.py @@ -28,24 +28,26 @@ import shutil import typing as t +from ....entity import Application, FSNode from ....error import AllocationError from ....log import get_logger -from ....settings import BsubBatchSettings, JsrunSettings -from ....settings.base import RunSettings +from ....settings import BsubBatchSettings, JsrunSettings, RunSettings from .step import Step logger = get_logger(__name__) class BsubBatchStep(Step): - def __init__(self, name: str, cwd: str, batch_settings: BsubBatchSettings) -> None: + def __init__( + self, entity: t.Union[Application, FSNode], batch_settings: BsubBatchSettings + ) -> None: """Initialize a LSF bsub step :param name: name of the entity to launch :param cwd: path to launch dir :param batch_settings: batch settings for entity """ - super().__init__(name, cwd, batch_settings) + super().__init__(entity, batch_settings) self.step_cmds: t.List[t.List[str]] = [] self.managed = True self.batch_settings = batch_settings @@ -103,14 +105,14 @@ def _write_script(self) -> str: class JsrunStep(Step): - def __init__(self, name: str, cwd: str, run_settings: RunSettings): + def __init__(self, entity: t.Union[Application, FSNode], run_settings: RunSettings): """Initialize a LSF jsrun job step :param name: name of the entity to be launched :param cwd: path to launch dir :param run_settings: run settings for entity """ - super().__init__(name, cwd, run_settings) + super().__init__(entity, run_settings) self.alloc: t.Optional[str] = None self.managed = True self.run_settings = run_settings @@ -170,9 +172,9 @@ def get_launch_cmd(self) -> t.List[str]: jsrun_cmd.extend(self.run_settings.format_run_args()) - if self.run_settings.colocated_db_settings: + if self.run_settings.colocated_fs_settings: # disable cpu binding as the entrypoint will 
set that - # for the application and database process now + # for the application and feature store process now jsrun_cmd.extend(["--bind", "none"]) # Replace the command with the entrypoint wrapper script @@ -214,8 +216,8 @@ def _build_exe(self) -> t.List[str]: :return: executable list """ - exe = self.run_settings.exe - args = self.run_settings._exe_args # pylint: disable=protected-access + exe = self.entity.exe + args = self.entity.exe_args if self._get_mpmd(): erf_file = self.get_step_file(ending=".mpmd") diff --git a/smartsim/_core/launcher/step/mpiStep.py b/smartsim/_core/launcher/step/mpi_step.py similarity index 86% rename from smartsim/_core/launcher/step/mpiStep.py rename to smartsim/_core/launcher/step/mpi_step.py index 9ae3af2fcd..0eb2f34fdb 100644 --- a/smartsim/_core/launcher/step/mpiStep.py +++ b/smartsim/_core/launcher/step/mpi_step.py @@ -29,17 +29,19 @@ import typing as t from shlex import split as sh_split +from ....entity import Application, FSNode from ....error import AllocationError, SmartSimError from ....log import get_logger -from ....settings import MpiexecSettings, MpirunSettings, OrterunSettings -from ....settings.base import RunSettings +from ....settings import MpiexecSettings, MpirunSettings, OrterunSettings, RunSettings from .step import Step, proxyable_launch_cmd logger = get_logger(__name__) class _BaseMPIStep(Step): - def __init__(self, name: str, cwd: str, run_settings: RunSettings) -> None: + def __init__( + self, entity: t.Union[Application, FSNode], run_settings: RunSettings + ) -> None: """Initialize a job step conforming to the MPI standard :param name: name of the entity to be launched @@ -47,7 +49,7 @@ def __init__(self, name: str, cwd: str, run_settings: RunSettings) -> None: :param run_settings: run settings for entity """ - super().__init__(name, cwd, run_settings) + super().__init__(entity, run_settings) self.alloc: t.Optional[str] = None if not run_settings.in_batch: @@ -73,9 +75,9 @@ def get_launch_cmd(self) -> 
t.List[str]: # add mpi settings to command mpi_cmd.extend(self.run_settings.format_run_args()) - if self.run_settings.colocated_db_settings: + if self.run_settings.colocated_fs_settings: # disable cpu binding as the entrypoint will set that - # for the application and database process now + # for the application and feature store process now # mpi_cmd.extend(["--cpu-bind", "none"]) # Replace the command with the entrypoint wrapper script @@ -133,14 +135,14 @@ def _build_exe(self) -> t.List[str]: if self._get_mpmd(): return self._make_mpmd() - exe = self.run_settings.exe - args = self.run_settings._exe_args # pylint: disable=protected-access + exe = self.entity.exe + args = self.entity.exe_args return exe + args def _make_mpmd(self) -> t.List[str]: """Build mpiexec (MPMD) executable""" - exe = self.run_settings.exe - args = self.run_settings._exe_args # pylint: disable=protected-access + exe = self.entity.exe + args = self.entity.exe_args cmd = exe + args for mpmd in self._get_mpmd(): @@ -148,14 +150,16 @@ def _make_mpmd(self) -> t.List[str]: cmd += mpmd.format_run_args() cmd += mpmd.format_env_vars() cmd += mpmd.exe - cmd += mpmd._exe_args # pylint: disable=protected-access + cmd += mpmd.exe_args cmd = sh_split(" ".join(cmd)) return cmd class MpiexecStep(_BaseMPIStep): - def __init__(self, name: str, cwd: str, run_settings: MpiexecSettings) -> None: + def __init__( + self, entity: t.Union[Application, FSNode], run_settings: MpiexecSettings + ) -> None: """Initialize an mpiexec job step :param name: name of the entity to be launched @@ -165,11 +169,13 @@ def __init__(self, name: str, cwd: str, run_settings: MpiexecSettings) -> None: application """ - super().__init__(name, cwd, run_settings) + super().__init__(entity, run_settings) class MpirunStep(_BaseMPIStep): - def __init__(self, name: str, cwd: str, run_settings: MpirunSettings) -> None: + def __init__( + self, entity: t.Union[Application, FSNode], run_settings: MpirunSettings + ) -> None: """Initialize an 
mpirun job step :param name: name of the entity to be launched @@ -179,11 +185,13 @@ def __init__(self, name: str, cwd: str, run_settings: MpirunSettings) -> None: application """ - super().__init__(name, cwd, run_settings) + super().__init__(entity, run_settings) class OrterunStep(_BaseMPIStep): - def __init__(self, name: str, cwd: str, run_settings: OrterunSettings) -> None: + def __init__( + self, entity: t.Union[Application, FSNode], run_settings: OrterunSettings + ) -> None: """Initialize an orterun job step :param name: name of the entity to be launched @@ -193,4 +201,4 @@ def __init__(self, name: str, cwd: str, run_settings: OrterunSettings) -> None: application """ - super().__init__(name, cwd, run_settings) + super().__init__(entity, run_settings) diff --git a/smartsim/_core/launcher/step/pbsStep.py b/smartsim/_core/launcher/step/pbs_step.py similarity index 94% rename from smartsim/_core/launcher/step/pbsStep.py rename to smartsim/_core/launcher/step/pbs_step.py index 82a91aaa43..b9e3b3f0c4 100644 --- a/smartsim/_core/launcher/step/pbsStep.py +++ b/smartsim/_core/launcher/step/pbs_step.py @@ -26,6 +26,7 @@ import typing as t +from ....entity import Application, FSNode from ....log import get_logger from ....settings import QsubBatchSettings from .step import Step @@ -34,14 +35,16 @@ class QsubBatchStep(Step): - def __init__(self, name: str, cwd: str, batch_settings: QsubBatchSettings) -> None: + def __init__( + self, entity: t.Union[Application, FSNode], batch_settings: QsubBatchSettings + ) -> None: """Initialize a PBSpro qsub step :param name: name of the entity to launch :param cwd: path to launch dir :param batch_settings: batch settings for entity """ - super().__init__(name, cwd, batch_settings) + super().__init__(entity, batch_settings) self.step_cmds: t.List[t.List[str]] = [] self.managed = True self.batch_settings = batch_settings diff --git a/smartsim/_core/launcher/step/sgeStep.py b/smartsim/_core/launcher/step/sge_step.py similarity index 100% 
rename from smartsim/_core/launcher/step/sgeStep.py rename to smartsim/_core/launcher/step/sge_step.py diff --git a/smartsim/_core/launcher/step/slurmStep.py b/smartsim/_core/launcher/step/slurm_step.py similarity index 89% rename from smartsim/_core/launcher/step/slurmStep.py rename to smartsim/_core/launcher/step/slurm_step.py index 83f39cf093..90d457f1b3 100644 --- a/smartsim/_core/launcher/step/slurmStep.py +++ b/smartsim/_core/launcher/step/slurm_step.py @@ -29,6 +29,8 @@ import typing as t from shlex import split as sh_split +from ....builders import Ensemble +from ....entity import Application, FSNode from ....error import AllocationError from ....log import get_logger from ....settings import RunSettings, SbatchSettings, Singularity, SrunSettings @@ -38,14 +40,16 @@ class SbatchStep(Step): - def __init__(self, name: str, cwd: str, batch_settings: SbatchSettings) -> None: + def __init__( + self, entity: t.Union[Application, FSNode], batch_settings: SbatchSettings + ) -> None: """Initialize a Slurm Sbatch step :param name: name of the entity to launch :param cwd: path to launch dir :param batch_settings: batch settings for entity """ - super().__init__(name, cwd, batch_settings) + super().__init__(entity, batch_settings) self.step_cmds: t.List[t.List[str]] = [] self.managed = True self.batch_settings = batch_settings @@ -98,16 +102,19 @@ def _write_script(self) -> str: class SrunStep(Step): - def __init__(self, name: str, cwd: str, run_settings: SrunSettings) -> None: + def __init__( + self, entity: t.Union[Application, FSNode], run_settings: SrunSettings + ) -> None: """Initialize a srun job step :param name: name of the entity to be launched :param cwd: path to launch dir :param run_settings: run settings for entity """ - super().__init__(name, cwd, run_settings) + super().__init__(entity, run_settings) self.alloc: t.Optional[str] = None self.managed = True + self.entity = entity self.run_settings = run_settings if not self.run_settings.in_batch: 
self._set_alloc() @@ -140,7 +147,7 @@ def get_launch_cmd(self) -> t.List[str]: srun_cmd += self.run_settings.format_run_args() - if self.run_settings.colocated_db_settings: + if self.run_settings.colocated_fs_settings: # Replace the command with the entrypoint wrapper script bash = shutil.which("bash") if not bash: @@ -184,11 +191,11 @@ def _get_mpmd(self) -> t.List[RunSettings]: return self.run_settings.mpmd @staticmethod - def _get_exe_args_list(run_setting: RunSettings) -> t.List[str]: + def _get_exe_args_list(entity: t.Union[Application, FSNode]) -> t.List[str]: """Convenience function to encapsulate checking the runsettings.exe_args type to always return a list """ - exe_args = run_setting.exe_args + exe_args = entity.exe_args args: t.List[str] = exe_args if isinstance(exe_args, list) else [exe_args] return args @@ -200,14 +207,17 @@ def _build_exe(self) -> t.List[str]: if self._get_mpmd(): return self._make_mpmd() - exe = self.run_settings.exe - args = self._get_exe_args_list(self.run_settings) + exe = self.entity.exe + args = self._get_exe_args_list(self.entity) return exe + args + # There is an issue here, exe and exe_args are no longer attached to the runsettings + # This functions is looping through the list of run_settings.mpmd and build the variable + # cmd def _make_mpmd(self) -> t.List[str]: """Build Slurm multi-prog (MPMD) executable""" - exe = self.run_settings.exe - args = self._get_exe_args_list(self.run_settings) + exe = self.entity.exe + args = self._get_exe_args_list(self.entity) cmd = exe + args compound_env_vars = [] diff --git a/smartsim/_core/launcher/step/step.py b/smartsim/_core/launcher/step/step.py index 171254e32a..b5e79a3638 100644 --- a/smartsim/_core/launcher/step/step.py +++ b/smartsim/_core/launcher/step/step.py @@ -38,19 +38,27 @@ from smartsim._core.config import CONFIG from smartsim.error.errors import SmartSimError, UnproxyableStepError +from ....builders import Ensemble +from ....entity import Application, FSNode from ....log 
import get_logger -from ....settings.base import RunSettings, SettingsBase +from ....settings import RunSettings, SettingsBase from ...utils.helpers import encode_cmd, get_base_36_repr -from ..colocated import write_colocated_launch_script logger = get_logger(__name__) +def write_colocated_launch_script(): + pass + + class Step: - def __init__(self, name: str, cwd: str, step_settings: SettingsBase) -> None: - self.name = self._create_unique_name(name) - self.entity_name = name - self.cwd = cwd + def __init__( + self, entity: t.Union[Application, FSNode], step_settings: SettingsBase + ) -> None: + self.name = self._create_unique_name(entity.name) + self.entity = entity + self.entity_name = entity.name + self.cwd = entity.path self.managed = False self.step_settings = copy.deepcopy(step_settings) self.meta: t.Dict[str, str] = {} @@ -106,20 +114,20 @@ def get_colocated_launch_script(self) -> str: ) makedirs(osp.dirname(script_path), exist_ok=True) - db_settings = {} + fs_settings = {} if isinstance(self.step_settings, RunSettings): - db_settings = self.step_settings.colocated_db_settings or {} + fs_settings = self.step_settings.colocated_fs_settings or {} - # db log file causes write contention and kills performance so by + # fs log file causes write contention and kills performance so by # default we turn off logging unless user specified debug=True - if db_settings.get("debug", False): - db_log_file = self.get_step_file(ending="-db.log") + if fs_settings.get("debug", False): + fs_log_file = self.get_step_file(ending="-fs.log") else: - db_log_file = "/dev/null" + fs_log_file = "/dev/null" # write the colocated wrapper shell script to the directory for this # entity currently being prepped to launch - write_colocated_launch_script(script_path, db_log_file, db_settings) + write_colocated_launch_script(script_path, fs_log_file, fs_settings) return script_path # pylint: disable=no-self-use diff --git a/smartsim/_core/launcher/stepInfo.py 
b/smartsim/_core/launcher/step_info.py similarity index 52% rename from smartsim/_core/launcher/stepInfo.py rename to smartsim/_core/launcher/step_info.py index b68527cb30..4fa307a8f9 100644 --- a/smartsim/_core/launcher/stepInfo.py +++ b/smartsim/_core/launcher/step_info.py @@ -28,13 +28,13 @@ import psutil -from ...status import SmartSimStatus +from ...status import JobStatus class StepInfo: def __init__( self, - status: SmartSimStatus, + status: JobStatus, launcher_status: str = "", returncode: t.Optional[int] = None, output: t.Optional[str] = None, @@ -53,44 +53,42 @@ def __str__(self) -> str: return info_str @property - def mapping(self) -> t.Dict[str, SmartSimStatus]: + def mapping(self) -> t.Dict[str, JobStatus]: raise NotImplementedError def _get_smartsim_status( self, status: str, returncode: t.Optional[int] = None - ) -> SmartSimStatus: + ) -> JobStatus: """ Map the status of the WLM step to a smartsim-specific status """ - if any(ss_status.value == status for ss_status in SmartSimStatus): - return SmartSimStatus(status) + if any(ss_status.value == status for ss_status in JobStatus): + return JobStatus(status) if status in self.mapping and returncode in [None, 0]: return self.mapping[status] - return SmartSimStatus.STATUS_FAILED + return JobStatus.FAILED class UnmanagedStepInfo(StepInfo): @property - def mapping(self) -> t.Dict[str, SmartSimStatus]: + def mapping(self) -> t.Dict[str, JobStatus]: # see https://github.com/giampaolo/psutil/blob/master/psutil/_pslinux.py # see https://github.com/giampaolo/psutil/blob/master/psutil/_common.py return { - psutil.STATUS_RUNNING: SmartSimStatus.STATUS_RUNNING, - psutil.STATUS_SLEEPING: ( - SmartSimStatus.STATUS_RUNNING - ), # sleeping thread is still alive - psutil.STATUS_WAKING: SmartSimStatus.STATUS_RUNNING, - psutil.STATUS_DISK_SLEEP: SmartSimStatus.STATUS_RUNNING, - psutil.STATUS_DEAD: SmartSimStatus.STATUS_FAILED, - psutil.STATUS_TRACING_STOP: SmartSimStatus.STATUS_PAUSED, - psutil.STATUS_WAITING: 
SmartSimStatus.STATUS_PAUSED, - psutil.STATUS_STOPPED: SmartSimStatus.STATUS_PAUSED, - psutil.STATUS_LOCKED: SmartSimStatus.STATUS_PAUSED, - psutil.STATUS_PARKED: SmartSimStatus.STATUS_PAUSED, - psutil.STATUS_IDLE: SmartSimStatus.STATUS_PAUSED, - psutil.STATUS_ZOMBIE: SmartSimStatus.STATUS_COMPLETED, + psutil.STATUS_RUNNING: JobStatus.RUNNING, + psutil.STATUS_SLEEPING: JobStatus.RUNNING, # sleeping thread is still alive + psutil.STATUS_WAKING: JobStatus.RUNNING, + psutil.STATUS_DISK_SLEEP: JobStatus.RUNNING, + psutil.STATUS_DEAD: JobStatus.FAILED, + psutil.STATUS_TRACING_STOP: JobStatus.PAUSED, + psutil.STATUS_WAITING: JobStatus.PAUSED, + psutil.STATUS_STOPPED: JobStatus.PAUSED, + psutil.STATUS_LOCKED: JobStatus.PAUSED, + psutil.STATUS_PARKED: JobStatus.PAUSED, + psutil.STATUS_IDLE: JobStatus.PAUSED, + psutil.STATUS_ZOMBIE: JobStatus.COMPLETED, } def __init__( @@ -109,30 +107,30 @@ def __init__( class SlurmStepInfo(StepInfo): # cov-slurm # see https://slurm.schedmd.com/squeue.html#lbAG mapping = { - "RUNNING": SmartSimStatus.STATUS_RUNNING, - "CONFIGURING": SmartSimStatus.STATUS_RUNNING, - "STAGE_OUT": SmartSimStatus.STATUS_RUNNING, - "COMPLETED": SmartSimStatus.STATUS_COMPLETED, - "DEADLINE": SmartSimStatus.STATUS_COMPLETED, - "TIMEOUT": SmartSimStatus.STATUS_COMPLETED, - "BOOT_FAIL": SmartSimStatus.STATUS_FAILED, - "FAILED": SmartSimStatus.STATUS_FAILED, - "NODE_FAIL": SmartSimStatus.STATUS_FAILED, - "OUT_OF_MEMORY": SmartSimStatus.STATUS_FAILED, - "CANCELLED": SmartSimStatus.STATUS_CANCELLED, - "CANCELLED+": SmartSimStatus.STATUS_CANCELLED, - "REVOKED": SmartSimStatus.STATUS_CANCELLED, - "PENDING": SmartSimStatus.STATUS_PAUSED, - "PREEMPTED": SmartSimStatus.STATUS_PAUSED, - "RESV_DEL_HOLD": SmartSimStatus.STATUS_PAUSED, - "REQUEUE_FED": SmartSimStatus.STATUS_PAUSED, - "REQUEUE_HOLD": SmartSimStatus.STATUS_PAUSED, - "REQUEUED": SmartSimStatus.STATUS_PAUSED, - "RESIZING": SmartSimStatus.STATUS_PAUSED, - "SIGNALING": SmartSimStatus.STATUS_PAUSED, - "SPECIAL_EXIT": 
SmartSimStatus.STATUS_PAUSED, - "STOPPED": SmartSimStatus.STATUS_PAUSED, - "SUSPENDED": SmartSimStatus.STATUS_PAUSED, + "RUNNING": JobStatus.RUNNING, + "CONFIGURING": JobStatus.RUNNING, + "STAGE_OUT": JobStatus.RUNNING, + "COMPLETED": JobStatus.COMPLETED, + "DEADLINE": JobStatus.COMPLETED, + "TIMEOUT": JobStatus.COMPLETED, + "BOOT_FAIL": JobStatus.FAILED, + "FAILED": JobStatus.FAILED, + "NODE_FAIL": JobStatus.FAILED, + "OUT_OF_MEMORY": JobStatus.FAILED, + "CANCELLED": JobStatus.CANCELLED, + "CANCELLED+": JobStatus.CANCELLED, + "REVOKED": JobStatus.CANCELLED, + "PENDING": JobStatus.PAUSED, + "PREEMPTED": JobStatus.PAUSED, + "RESV_DEL_HOLD": JobStatus.PAUSED, + "REQUEUE_FED": JobStatus.PAUSED, + "REQUEUE_HOLD": JobStatus.PAUSED, + "REQUEUED": JobStatus.PAUSED, + "RESIZING": JobStatus.PAUSED, + "SIGNALING": JobStatus.PAUSED, + "SPECIAL_EXIT": JobStatus.PAUSED, + "STOPPED": JobStatus.PAUSED, + "SUSPENDED": JobStatus.PAUSED, } def __init__( @@ -150,27 +148,25 @@ def __init__( class PBSStepInfo(StepInfo): # cov-pbs @property - def mapping(self) -> t.Dict[str, SmartSimStatus]: + def mapping(self) -> t.Dict[str, JobStatus]: # pylint: disable-next=line-too-long # see http://nusc.nsu.ru/wiki/lib/exe/fetch.php/doc/pbs/PBSReferenceGuide19.2.1.pdf#M11.9.90788.PBSHeading1.81.Job.States return { - "R": SmartSimStatus.STATUS_RUNNING, - "B": SmartSimStatus.STATUS_RUNNING, - "H": SmartSimStatus.STATUS_PAUSED, + "R": JobStatus.RUNNING, + "B": JobStatus.RUNNING, + "H": JobStatus.PAUSED, "M": ( - SmartSimStatus.STATUS_PAUSED + JobStatus.PAUSED ), # Actually means that it was moved to another server, # TODO: understand what this implies - "Q": SmartSimStatus.STATUS_PAUSED, - "S": SmartSimStatus.STATUS_PAUSED, - "T": ( - SmartSimStatus.STATUS_PAUSED - ), # This means in transition, see above for comment - "U": SmartSimStatus.STATUS_PAUSED, - "W": SmartSimStatus.STATUS_PAUSED, - "E": SmartSimStatus.STATUS_COMPLETED, - "F": SmartSimStatus.STATUS_COMPLETED, - "X": 
SmartSimStatus.STATUS_COMPLETED, + "Q": JobStatus.PAUSED, + "S": JobStatus.PAUSED, + "T": JobStatus.PAUSED, # This means in transition, see above for comment + "U": JobStatus.PAUSED, + "W": JobStatus.PAUSED, + "E": JobStatus.COMPLETED, + "F": JobStatus.COMPLETED, + "X": JobStatus.COMPLETED, } def __init__( @@ -183,13 +179,11 @@ def __init__( if status == "NOTFOUND": if returncode is not None: smartsim_status = ( - SmartSimStatus.STATUS_COMPLETED - if returncode == 0 - else SmartSimStatus.STATUS_FAILED + JobStatus.COMPLETED if returncode == 0 else JobStatus.FAILED ) else: # if PBS job history isnt available, and job isnt in queue - smartsim_status = SmartSimStatus.STATUS_COMPLETED + smartsim_status = JobStatus.COMPLETED returncode = 0 else: smartsim_status = self._get_smartsim_status(status) @@ -200,16 +194,16 @@ def __init__( class LSFBatchStepInfo(StepInfo): # cov-lsf @property - def mapping(self) -> t.Dict[str, SmartSimStatus]: + def mapping(self) -> t.Dict[str, JobStatus]: # pylint: disable-next=line-too-long # see https://www.ibm.com/docs/en/spectrum-lsf/10.1.0?topic=execution-about-job-states return { - "RUN": SmartSimStatus.STATUS_RUNNING, - "PSUSP": SmartSimStatus.STATUS_PAUSED, - "USUSP": SmartSimStatus.STATUS_PAUSED, - "SSUSP": SmartSimStatus.STATUS_PAUSED, - "PEND": SmartSimStatus.STATUS_PAUSED, - "DONE": SmartSimStatus.STATUS_COMPLETED, + "RUN": JobStatus.RUNNING, + "PSUSP": JobStatus.PAUSED, + "USUSP": JobStatus.PAUSED, + "SSUSP": JobStatus.PAUSED, + "PEND": JobStatus.PAUSED, + "DONE": JobStatus.COMPLETED, } def __init__( @@ -222,12 +216,10 @@ def __init__( if status == "NOTFOUND": if returncode is not None: smartsim_status = ( - SmartSimStatus.STATUS_COMPLETED - if returncode == 0 - else SmartSimStatus.STATUS_FAILED + JobStatus.COMPLETED if returncode == 0 else JobStatus.FAILED ) else: - smartsim_status = SmartSimStatus.STATUS_COMPLETED + smartsim_status = JobStatus.COMPLETED returncode = 0 else: smartsim_status = self._get_smartsim_status(status) @@ 
-238,14 +230,14 @@ def __init__( class LSFJsrunStepInfo(StepInfo): # cov-lsf @property - def mapping(self) -> t.Dict[str, SmartSimStatus]: + def mapping(self) -> t.Dict[str, JobStatus]: # pylint: disable-next=line-too-long # see https://www.ibm.com/docs/en/spectrum-lsf/10.1.0?topic=execution-about-job-states return { - "Killed": SmartSimStatus.STATUS_COMPLETED, - "Running": SmartSimStatus.STATUS_RUNNING, - "Queued": SmartSimStatus.STATUS_PAUSED, - "Complete": SmartSimStatus.STATUS_COMPLETED, + "Killed": JobStatus.COMPLETED, + "Running": JobStatus.RUNNING, + "Queued": JobStatus.PAUSED, + "Complete": JobStatus.COMPLETED, } def __init__( @@ -258,12 +250,10 @@ def __init__( if status == "NOTFOUND": if returncode is not None: smartsim_status = ( - SmartSimStatus.STATUS_COMPLETED - if returncode == 0 - else SmartSimStatus.STATUS_FAILED + JobStatus.COMPLETED if returncode == 0 else JobStatus.FAILED ) else: - smartsim_status = SmartSimStatus.STATUS_COMPLETED + smartsim_status = JobStatus.COMPLETED returncode = 0 else: smartsim_status = self._get_smartsim_status(status, returncode) @@ -274,51 +264,51 @@ def __init__( class SGEStepInfo(StepInfo): # cov-pbs @property - def mapping(self) -> t.Dict[str, SmartSimStatus]: + def mapping(self) -> t.Dict[str, JobStatus]: # pylint: disable-next=line-too-long # see https://manpages.ubuntu.com/manpages/jammy/man5/sge_status.5.html return { # Running states - "r": SmartSimStatus.STATUS_RUNNING, - "hr": SmartSimStatus.STATUS_RUNNING, - "t": SmartSimStatus.STATUS_RUNNING, - "Rr": SmartSimStatus.STATUS_RUNNING, - "Rt": SmartSimStatus.STATUS_RUNNING, + "r": JobStatus.RUNNING, + "hr": JobStatus.RUNNING, + "t": JobStatus.RUNNING, + "Rr": JobStatus.RUNNING, + "Rt": JobStatus.RUNNING, # Queued states - "qw": SmartSimStatus.STATUS_QUEUED, - "Rq": SmartSimStatus.STATUS_QUEUED, - "hqw": SmartSimStatus.STATUS_QUEUED, - "hRwq": SmartSimStatus.STATUS_QUEUED, + "qw": JobStatus.QUEUED, + "Rq": JobStatus.QUEUED, + "hqw": JobStatus.QUEUED, + "hRwq": 
JobStatus.QUEUED, # Paused states - "s": SmartSimStatus.STATUS_PAUSED, - "ts": SmartSimStatus.STATUS_PAUSED, - "S": SmartSimStatus.STATUS_PAUSED, - "tS": SmartSimStatus.STATUS_PAUSED, - "T": SmartSimStatus.STATUS_PAUSED, - "tT": SmartSimStatus.STATUS_PAUSED, - "Rs": SmartSimStatus.STATUS_PAUSED, - "Rts": SmartSimStatus.STATUS_PAUSED, - "RS": SmartSimStatus.STATUS_PAUSED, - "RtS": SmartSimStatus.STATUS_PAUSED, - "RT": SmartSimStatus.STATUS_PAUSED, - "RtT": SmartSimStatus.STATUS_PAUSED, + "s": JobStatus.PAUSED, + "ts": JobStatus.PAUSED, + "S": JobStatus.PAUSED, + "tS": JobStatus.PAUSED, + "T": JobStatus.PAUSED, + "tT": JobStatus.PAUSED, + "Rs": JobStatus.PAUSED, + "Rts": JobStatus.PAUSED, + "RS": JobStatus.PAUSED, + "RtS": JobStatus.PAUSED, + "RT": JobStatus.PAUSED, + "RtT": JobStatus.PAUSED, # Failed states - "Eqw": SmartSimStatus.STATUS_FAILED, - "Ehqw": SmartSimStatus.STATUS_FAILED, - "EhRqw": SmartSimStatus.STATUS_FAILED, + "Eqw": JobStatus.FAILED, + "Ehqw": JobStatus.FAILED, + "EhRqw": JobStatus.FAILED, # Finished states - "z": SmartSimStatus.STATUS_COMPLETED, + "z": JobStatus.COMPLETED, # Cancelled - "dr": SmartSimStatus.STATUS_CANCELLED, - "dt": SmartSimStatus.STATUS_CANCELLED, - "dRr": SmartSimStatus.STATUS_CANCELLED, - "dRt": SmartSimStatus.STATUS_CANCELLED, - "ds": SmartSimStatus.STATUS_CANCELLED, - "dS": SmartSimStatus.STATUS_CANCELLED, - "dT": SmartSimStatus.STATUS_CANCELLED, - "dRs": SmartSimStatus.STATUS_CANCELLED, - "dRS": SmartSimStatus.STATUS_CANCELLED, - "dRT": SmartSimStatus.STATUS_CANCELLED, + "dr": JobStatus.CANCELLED, + "dt": JobStatus.CANCELLED, + "dRr": JobStatus.CANCELLED, + "dRt": JobStatus.CANCELLED, + "ds": JobStatus.CANCELLED, + "dS": JobStatus.CANCELLED, + "dT": JobStatus.CANCELLED, + "dRs": JobStatus.CANCELLED, + "dRS": JobStatus.CANCELLED, + "dRT": JobStatus.CANCELLED, } def __init__( @@ -331,13 +321,11 @@ def __init__( if status == "NOTFOUND": if returncode is not None: smartsim_status = ( - SmartSimStatus.STATUS_COMPLETED - if 
returncode == 0 - else SmartSimStatus.STATUS_FAILED + JobStatus.COMPLETED if returncode == 0 else JobStatus.FAILED ) else: # if PBS job history is not available, and job is not in queue - smartsim_status = SmartSimStatus.STATUS_COMPLETED + smartsim_status = JobStatus.COMPLETED returncode = 0 else: smartsim_status = self._get_smartsim_status(status) diff --git a/smartsim/_core/launcher/stepMapping.py b/smartsim/_core/launcher/step_mapping.py similarity index 100% rename from smartsim/_core/launcher/stepMapping.py rename to smartsim/_core/launcher/step_mapping.py diff --git a/smartsim/_core/launcher/taskManager.py b/smartsim/_core/launcher/task_manager.py similarity index 100% rename from smartsim/_core/launcher/taskManager.py rename to smartsim/_core/launcher/task_manager.py diff --git a/smartsim/_core/launcher/util/launcherUtil.py b/smartsim/_core/launcher/util/launcher_util.py similarity index 100% rename from smartsim/_core/launcher/util/launcherUtil.py rename to smartsim/_core/launcher/util/launcher_util.py diff --git a/smartsim/_core/mli/comm/channel/dragon_fli.py b/smartsim/_core/mli/comm/channel/dragon_fli.py index 5fb0790a84..01849247cd 100644 --- a/smartsim/_core/mli/comm/channel/dragon_fli.py +++ b/smartsim/_core/mli/comm/channel/dragon_fli.py @@ -25,7 +25,9 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # isort: off -from dragon import fli + +import dragon +import dragon.fli as fli from dragon.channels import Channel # isort: on diff --git a/smartsim/_core/mli/infrastructure/control/dragon_util.py b/smartsim/_core/mli/infrastructure/control/dragon_util.py new file mode 100644 index 0000000000..93bae64e69 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/control/dragon_util.py @@ -0,0 +1,79 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import os +import socket +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +# isort: off + +import dragon.infrastructure.policy as dragon_policy +import dragon.infrastructure.process_desc as dragon_process_desc +import dragon.native.process as dragon_process + +# isort: on + +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +def function_as_dragon_proc( + entrypoint_fn: t.Callable[[t.Any], None], + args: t.List[t.Any], + cpu_affinity: t.List[int], + gpu_affinity: t.List[int], +) -> dragon_process.Process: + """Execute a function as an independent dragon process. 
+ + :param entrypoint_fn: The function to execute + :param args: The arguments for the entrypoint function + :param cpu_affinity: The cpu affinity for the process + :param gpu_affinity: The gpu affinity for the process + :returns: The dragon process handle + """ + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + return dragon_process.Process( + target=entrypoint_fn, + args=args, + cwd=os.getcwd(), + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) diff --git a/smartsim/_core/schemas/__init__.py b/smartsim/_core/schemas/__init__.py index d7ee9d83d8..54ae3947de 100644 --- a/smartsim/_core/schemas/__init__.py +++ b/smartsim/_core/schemas/__init__.py @@ -24,8 +24,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from .dragonRequests import * -from .dragonResponses import * +from .dragon_requests import * +from .dragon_responses import * __all__ = [ "DragonRequest", diff --git a/smartsim/_core/schemas/dragonRequests.py b/smartsim/_core/schemas/dragon_requests.py similarity index 100% rename from smartsim/_core/schemas/dragonRequests.py rename to smartsim/_core/schemas/dragon_requests.py diff --git a/smartsim/_core/schemas/dragonResponses.py b/smartsim/_core/schemas/dragon_responses.py similarity index 96% rename from smartsim/_core/schemas/dragonResponses.py rename to smartsim/_core/schemas/dragon_responses.py index 3c5c30a103..1a6507db41 100644 --- a/smartsim/_core/schemas/dragonResponses.py +++ b/smartsim/_core/schemas/dragon_responses.py @@ -29,7 +29,7 @@ from pydantic import BaseModel, Field import smartsim._core.schemas.utils as _utils -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # Black and Pylint disagree about where to put the `...` # pylint: disable=multiple-statements @@ -51,7 +51,7 @@ class DragonUpdateStatusResponse(DragonResponse): # status is a dict: {step_id: (is_alive, returncode)} statuses: t.Mapping[ t.Annotated[str, Field(min_length=1)], - t.Tuple[SmartSimStatus, t.Optional[t.List[int]]], + t.Tuple[JobStatus, t.Optional[t.List[int]]], ] = {} diff --git a/tests/__init__.py b/smartsim/_core/shell/__init__.py similarity index 100% rename from tests/__init__.py rename to smartsim/_core/shell/__init__.py diff --git a/smartsim/_core/shell/shell_launcher.py b/smartsim/_core/shell/shell_launcher.py new file mode 100644 index 0000000000..9f88d0545c --- /dev/null +++ b/smartsim/_core/shell/shell_launcher.py @@ -0,0 +1,268 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +from __future__ import annotations + +import io +import pathlib +import subprocess as sp +import typing as t + +import psutil + +from smartsim._core.arguments.shell import ShellLaunchArguments +from smartsim._core.dispatch import EnvironMappingType, FormatterType, WorkingDirectory +from smartsim._core.utils import helpers +from smartsim._core.utils.launcher import create_job_id +from smartsim.error import errors +from smartsim.log import get_logger +from smartsim.settings.arguments.launch_arguments import LaunchArguments +from smartsim.status import JobStatus +from smartsim.types import LaunchedJobID + +if t.TYPE_CHECKING: + from typing_extensions import Self + + from smartsim.experiment import Experiment + +logger = get_logger(__name__) + + +class ShellLauncherCommand(t.NamedTuple): + env: EnvironMappingType + path: pathlib.Path + stdout: io.TextIOWrapper | int + stderr: io.TextIOWrapper | int + command_tuple: t.Sequence[str] + + +def make_shell_format_fn( + run_command: str | None, +) -> FormatterType[ShellLaunchArguments, ShellLauncherCommand]: + """A function that builds a function that formats a `LaunchArguments` as a + shell executable sequence of strings for a given launching utility. + + Example usage: + + .. highlight:: python + .. code-block:: python + + echo_hello_world: ExecutableProtocol = ... + env = {} + slurm_args: SlurmLaunchArguments = ... + slurm_args.set_nodes(3) + + as_srun_command = make_shell_format_fn("srun") + fmt_cmd = as_srun_command(slurm_args, echo_hello_world, env) + print(list(fmt_cmd)) + # prints: "['srun', '--nodes=3', '--', 'echo', 'Hello World!']" + + .. note:: + This function was/is a kind of slap-dash implementation, and is likely + to change or be removed entierely as more functionality is added to the + shell launcher. Use with caution and at your own risk! + + :param run_command: Name or path of the launching utility to invoke with + the arguments. 
+ :returns: A function to format an arguments, an executable, and an + environment as a shell launchable sequence for strings. + """ + + def impl( + args: ShellLaunchArguments, + exe: t.Sequence[str], + path: WorkingDirectory, + env: EnvironMappingType, + stdout_path: pathlib.Path, + stderr_path: pathlib.Path, + ) -> ShellLauncherCommand: + command_tuple = ( + ( + run_command, + *(args.format_launch_args() or ()), + "--", + *exe, + ) + if run_command is not None + else exe + ) + # pylint: disable-next=consider-using-with + return ShellLauncherCommand( + env, pathlib.Path(path), open(stdout_path), open(stderr_path), command_tuple + ) + + return impl + + +class ShellLauncher: + """A launcher for launching/tracking local shell commands""" + + def __init__(self) -> None: + """Initialize a new shell launcher.""" + self._launched: dict[LaunchedJobID, sp.Popen[bytes]] = {} + + def check_popen_inputs(self, shell_command: ShellLauncherCommand) -> None: + """Validate that the contents of a shell command are valid. + + :param shell_command: The command to validate + :raises ValueError: If the command is not valid + """ + if not shell_command.path.exists(): + raise ValueError("Please provide a valid path to ShellLauncherCommand.") + + def start(self, shell_command: ShellLauncherCommand) -> LaunchedJobID: + """Have the shell launcher start and track the progress of a new + subprocess. + + :param shell_command: The template of a subprocess to start. + :returns: An id to reference the process for status. 
+ """ + self.check_popen_inputs(shell_command) + id_ = create_job_id() + exe, *rest = shell_command.command_tuple + expanded_exe = helpers.expand_exe_path(exe) + # pylint: disable-next=consider-using-with + self._launched[id_] = sp.Popen( + (expanded_exe, *rest), + cwd=shell_command.path, + env={k: v for k, v in shell_command.env.items() if v is not None}, + stdout=shell_command.stdout, + stderr=shell_command.stderr, + ) + return id_ + + def _get_proc_from_job_id(self, id_: LaunchedJobID, /) -> sp.Popen[bytes]: + """Given an issued job id, return the process represented by that id. + + :param id_: The launched job id of the process + :raises: errors.LauncherJobNotFound: The id could not be mapped to a + process. This usually means that the provided id was not issued by + this launcher instance. + :returns: The process that the shell launcher started and represented + by the issued id. + """ + if (proc := self._launched.get(id_)) is None: + msg = f"Launcher `{self}` has not launched a job with id `{id_}`" + raise errors.LauncherJobNotFound(msg) + return proc + + def get_status( + self, *launched_ids: LaunchedJobID + ) -> t.Mapping[LaunchedJobID, JobStatus]: + """Take a collection of job ids and return the status of the + corresponding processes started by the shell launcher. + + :param launched_ids: A collection of ids of the launched jobs to get + the statuses of. + :returns: A mapping of ids for jobs to stop to their reported status. + """ + return {id_: self._get_status(id_) for id_ in launched_ids} + + def _get_status(self, id_: LaunchedJobID, /) -> JobStatus: + """Given an issued job id, return the process represented by that id + + :param id_: The launched job id of the process to get the status of. + :returns: The status of that process represented by the given id. 
+ """ + proc = self._get_proc_from_job_id(id_) + ret_code = proc.poll() + if ret_code is None: + status = psutil.Process(proc.pid).status() + return { + psutil.STATUS_RUNNING: JobStatus.RUNNING, + psutil.STATUS_SLEEPING: JobStatus.RUNNING, + psutil.STATUS_WAKING: JobStatus.RUNNING, + psutil.STATUS_DISK_SLEEP: JobStatus.RUNNING, + psutil.STATUS_DEAD: JobStatus.FAILED, + psutil.STATUS_TRACING_STOP: JobStatus.PAUSED, + psutil.STATUS_WAITING: JobStatus.PAUSED, + psutil.STATUS_STOPPED: JobStatus.PAUSED, + psutil.STATUS_LOCKED: JobStatus.PAUSED, + psutil.STATUS_PARKED: JobStatus.PAUSED, + psutil.STATUS_IDLE: JobStatus.PAUSED, + psutil.STATUS_ZOMBIE: JobStatus.COMPLETED, + }.get(status, JobStatus.UNKNOWN) + if ret_code == 0: + return JobStatus.COMPLETED + return JobStatus.FAILED + + def stop_jobs( + self, *launched_ids: LaunchedJobID + ) -> t.Mapping[LaunchedJobID, JobStatus]: + """Take a collection of job ids and kill the corresponding processes + started by the shell launcher. + + :param launched_ids: The ids of the launched jobs to stop. + :returns: A mapping of ids for jobs to stop to their reported status + after attempting to stop them. + """ + return {id_: self._stop(id_) for id_ in launched_ids} + + def _stop(self, id_: LaunchedJobID, /, wait_time: float = 5.0) -> JobStatus: + """Stop a job represented by an id + + The launcher will first start by attempting to kill the process using + by sending a SIGTERM signal and then waiting for an amount of time. If + the process is not killed by the timeout time, a SIGKILL signal will be + sent and another waiting period will be started. If the period also + ends, the message will be logged and the process will be left to + continue running. The method will then get and return the status of the + job. + + :param id_: The id of a launched job to stop. + :param wait: The maximum amount of time, in seconds, to wait for a + signal to stop a process. 
+ :returns: The status of the job after sending signals to terminate the + started process. + """ + proc = self._get_proc_from_job_id(id_) + if proc.poll() is None: + msg = f"Attempting to terminate local process {proc.pid}" + logger.debug(msg) + proc.terminate() + + try: + proc.wait(wait_time) + except sp.TimeoutExpired: + msg = f"Failed to terminate process {proc.pid}. Attempting to kill." + logger.warning(msg) + proc.kill() + + try: + proc.wait(wait_time) + except sp.TimeoutExpired: + logger.error(f"Failed to kill process {proc.pid}") + return self._get_status(id_) + + @classmethod + def create(cls, _: Experiment) -> Self: + """Create a new launcher instance from an experiment instance. + + :param _: An experiment instance. + :returns: A new launcher instance. + """ + return cls() diff --git a/smartsim/_core/utils/__init__.py b/smartsim/_core/utils/__init__.py index cddbc4ce98..4159c90424 100644 --- a/smartsim/_core/utils/__init__.py +++ b/smartsim/_core/utils/__init__.py @@ -30,7 +30,5 @@ delete_elements, execute_platform_cmd, expand_exe_path, - installed_redisai_backends, is_crayex_platform, ) -from .redis import check_cluster_status, create_cluster, db_is_active diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index bf5838928e..265205bef4 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -27,47 +27,97 @@ """ A file of helper functions for SmartSim """ +from __future__ import annotations + import base64 import collections.abc +import functools +import itertools import os import signal import subprocess +import sys import typing as t import uuid +import warnings from datetime import datetime -from functools import lru_cache -from pathlib import Path from shutil import which +from typing_extensions import TypeAlias + if t.TYPE_CHECKING: from types import FrameType + from typing_extensions import TypeVarTuple, Unpack + + from smartsim.launchable.job import Job + + _Ts = TypeVarTuple("_Ts") + 
_TRedisAIBackendStr = t.Literal["tensorflow", "torch", "onnxruntime"] +_T = t.TypeVar("_T") +_HashableT = t.TypeVar("_HashableT", bound=t.Hashable) _TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object] +_NestedJobSequenceType: TypeAlias = "t.Sequence[Job | _NestedJobSequenceType]" + + +def unpack(value: _NestedJobSequenceType) -> t.Generator[Job, None, None]: + """Unpack any iterable input in order to obtain a + single sequence of values + + :param value: Sequence containing elements of type Job or other + sequences that are also of type _NestedJobSequenceType + :return: flattened list of Jobs""" + from smartsim.launchable.job import Job + + for item in value: + + if isinstance(item, t.Iterable): + # string are iterable of string. Avoid infinite recursion + if isinstance(item, str): + raise TypeError("jobs argument was not of type Job") + yield from unpack(item) + else: + if not isinstance(item, Job): + raise TypeError("jobs argument was not of type Job") + yield item + + +def check_name(name: str) -> None: + """ + Checks if the input name is valid. + + :param name: The name to be checked. -def unpack_db_identifier(db_id: str, token: str) -> t.Tuple[str, str]: - """Unpack the unformatted database identifier + :raises ValueError: If the name contains the path separator (os.path.sep). + """ + if os.path.sep in name: + raise ValueError("Invalid input: String contains the path separator.") + + +def unpack_fs_identifier(fs_id: str, token: str) -> t.Tuple[str, str]: + """Unpack the unformatted feature store identifier and format for env variable suffix using the token - :param db_id: the unformatted database identifier eg. identifier_1 - :param token: character to use to construct the db suffix - :return: db id suffix and formatted db_id e.g. ("_identifier_1", "identifier_1") + :param fs_id: the unformatted feature store identifier eg. 
identifier_1 + :param token: character to use to construct the fs suffix + :return: fs id suffix and formatted fs_id e.g. ("_identifier_1", "identifier_1") """ - if db_id == "orchestrator": + if fs_id == "featurestore": return "", "" - db_name_suffix = token + db_id - return db_name_suffix, db_id + fs_name_suffix = token + fs_id + return fs_name_suffix, fs_id -def unpack_colo_db_identifier(db_id: str) -> str: - """Create database identifier suffix for colocated database +def unpack_colo_fs_identifier(fs_id: str) -> str: + """Create feature store identifier suffix for colocated feature store - :param db_id: the unformatted database identifier - :return: db suffix + :param fs_id: the unformatted feature store identifier + :return: fs suffix """ - return "_" + db_id if db_id else "" + return "_" + fs_id if fs_id else "" def create_short_id_str() -> str: @@ -80,13 +130,13 @@ def create_lockfile_name() -> str: return f"smartsim-{lock_suffix}.lock" -@lru_cache(maxsize=20, typed=False) +@functools.lru_cache(maxsize=20, typed=False) def check_dev_log_level() -> bool: lvl = os.environ.get("SMARTSIM_LOG_LEVEL", "") return lvl == "developer" -def fmt_dict(value: t.Dict[str, t.Any]) -> str: +def fmt_dict(value: t.Mapping[str, t.Any]) -> str: fmt_str = "" for k, v in value.items(): fmt_str += "\t" + str(k) + " = " + str(v) @@ -115,10 +165,13 @@ def expand_exe_path(exe: str) -> str: """Takes an executable and returns the full path to that executable :param exe: executable or file + :raises ValueError: if no executable is provided :raises TypeError: if file is not an executable :raises FileNotFoundError: if executable cannot be found """ + if not exe: + raise ValueError("No executable provided") # which returns none if not found in_path = which(exe) if not in_path: @@ -214,54 +267,6 @@ def cat_arg_and_value(arg_name: str, value: str) -> str: return f"--{arg_name}={value}" -def _installed(base_path: Path, backend: str) -> bool: - """ - Check if a backend is available for the 
RedisAI module. - """ - backend_key = f"redisai_{backend}" - backend_path = base_path / backend_key / f"{backend_key}.so" - backend_so = Path(os.environ.get("SMARTSIM_RAI_LIB", backend_path)).resolve() - - return backend_so.is_file() - - -def redis_install_base(backends_path: t.Optional[str] = None) -> Path: - # pylint: disable-next=import-outside-toplevel - from ..._core.config import CONFIG - - base_path: Path = ( - Path(backends_path) if backends_path else CONFIG.lib_path / "backends" - ) - return base_path - - -def installed_redisai_backends( - backends_path: t.Optional[str] = None, -) -> t.Set[_TRedisAIBackendStr]: - """Check which ML backends are available for the RedisAI module. - - The optional argument ``backends_path`` is needed if the backends - have not been built as part of the SmartSim building process (i.e. - they have not been built by invoking `smart build`). In that case - ``backends_path`` should point to the directory containing e.g. - the backend directories (`redisai_tensorflow`, `redisai_torch`, - `redisai_onnxruntime`, or `redisai_tflite`). 
- - :param backends_path: path containing backends - :return: list of installed RedisAI backends - """ - # import here to avoid circular import - base_path = redis_install_base(backends_path) - backends: t.Set[_TRedisAIBackendStr] = { - "tensorflow", - "torch", - "onnxruntime", - } - - installed = {backend for backend in backends if _installed(base_path, backend)} - return installed - - def get_ts_ms() -> int: """Return the current timestamp (accurate to milliseconds) cast to an integer""" return int(datetime.now().timestamp() * 1000) @@ -318,6 +323,20 @@ def execute_platform_cmd(cmd: str) -> t.Tuple[str, int]: return process.stdout.decode("utf-8"), process.returncode +def _stringify_id(_id: int) -> str: + """Return the CPU id as a string if an int, otherwise raise a ValueError + + :params _id: the CPU id as an int + :returns: the CPU as a string + """ + if isinstance(_id, int): + if _id < 0: + raise ValueError("CPU id must be a nonnegative number") + return str(_id) + + raise TypeError(f"Argument is of type '{type(_id)}' not 'int'") + + class CrayExPlatformResult: locate_msg = "Unable to locate `{0}`." @@ -412,6 +431,102 @@ def is_crayex_platform() -> bool: return result.is_cray +def first(predicate: t.Callable[[_T], bool], iterable: t.Iterable[_T]) -> _T | None: + """Return the first instance of an iterable that meets some precondition. + Any elements of the iterable that do not meet the precondition will be + forgotten. If no item in the iterable is found that meets the predicate, + `None` is returned. This is roughly equivalent to + + .. highlight:: python + .. code-block:: python + + next(filter(predicate, iterable), None) + + but does not require the predicate to be a type guard to type check. 
+ + :param predicate: A function that returns `True` or `False` given a element + of the iterable + :param iterable: An iterable that yields elements to evealuate + :returns: The first element of the iterable to make the the `predicate` + return `True` + """ + return next((item for item in iterable if predicate(item)), None) + + +def unique(iterable: t.Iterable[_HashableT]) -> t.Iterable[_HashableT]: + """Iterate over an iterable, yielding only unique values. + + This helper function will maintain a set of seen values in memory and yield + any values not previously seen during iteration. This is nice if you know + you will be iterating over the iterable exactly once, but if you need to + iterate over the iterable multiple times, it would likely use less memory + to cast the iterable to a set first. + + :param iterable: An iterable of possibly not unique values. + :returns: An iterable of unique values with order unchanged from the + original iterable. + """ + seen = set() + for item in filter(lambda x: x not in seen, iterable): + seen.add(item) + yield item + + +def group_by( + fn: t.Callable[[_T], _HashableT], items: t.Iterable[_T] +) -> t.Mapping[_HashableT, t.Collection[_T]]: + """Iterate over an iterable and group the items based on the return of some + mapping function. Works similar to SQL's "GROUP BY" statement, but works + over an arbitrary mapping function. + + :param fn: A function mapping the iterable values to some hashable values + :items: An iterable yielding items to group by mapping function return. + :returns: A mapping of mapping function return values to collection of + items that returned that value when fed to the mapping function. 
+ """ + groups = collections.defaultdict[_HashableT, list[_T]](list) + for item in items: + groups[fn(item)].append(item) + return dict(groups) + + +def pack_params( + fn: t.Callable[[Unpack[_Ts]], _T] +) -> t.Callable[[tuple[Unpack[_Ts]]], _T]: + r"""Take a function that takes an unspecified number of positional arguments + and turn it into a function that takes one argument of type `tuple` of + unspecified length. The main use case is largely just for iterating over an + iterable where arguments are "pre-zipped" into tuples. E.g. + + .. highlight:: python + .. code-block:: python + + def pretty_print_dict(d): + fmt_pair = lambda key, value: f"{repr(key)}: {repr(value)}," + body = "\n".join(map(pack_params(fmt_pair), d.items())) + # ^^^^^^^^^^^^^^^^^^^^^ + print(f"{{\n{textwrap.indent(body, ' ')}\n}}") + + pretty_print_dict({"spam": "eggs", "foo": "bar", "hello": "world"}) + # prints: + # { + # 'spam': 'eggs', + # 'foo': 'bar', + # 'hello': 'world', + # } + + :param fn: A callable that takes many positional parameters. + :returns: A callable that takes a single positional parameter of type tuple + of with the same shape as the original callable parameter list. + """ + + @functools.wraps(fn) + def packed(args: tuple[Unpack[_Ts]]) -> _T: + return fn(*args) + + return packed + + @t.final class SignalInterceptionStack(collections.abc.Collection[_TSignalHandlerFn]): """Registers a stack of callables to be called when a signal is @@ -490,3 +605,46 @@ def push_unique(self, fn: _TSignalHandlerFn) -> bool: if did_push := fn not in self: self.push(fn) return did_push + + def _create_pinning_string( + pin_ids: t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]], cpus: int + ) -> t.Optional[str]: + """Create a comma-separated string of CPU ids. By default, ``None`` + returns 0,1,...,cpus-1; an empty iterable will disable pinning + altogether, and an iterable constructs a comma separated string of + integers (e.g. 
``[0, 2, 5]`` -> ``"0,2,5"``) + + :params pin_ids: CPU ids + :params cpu: number of CPUs + :raises TypeError: if pin id is not an iterable of ints + :returns: a comma separated string of CPU ids + """ + + try: + pin_ids = tuple(pin_ids) if pin_ids is not None else None + except TypeError: + raise TypeError( + "Expected a cpu pinning specification of type iterable of ints or " + f"iterables of ints. Instead got type `{type(pin_ids)}`" + ) from None + + # Deal with MacOSX limitations first. The "None" (default) disables pinning + # and is equivalent to []. The only invalid option is a non-empty pinning + if sys.platform == "darwin": + if pin_ids: + warnings.warn( + "CPU pinning is not supported on MacOSX. Ignoring pinning " + "specification.", + RuntimeWarning, + ) + return None + + # Flatten the iterable into a list and check to make sure that the resulting + # elements are all ints + if pin_ids is None: + return ",".join(_stringify_id(i) for i in range(cpus)) + if not pin_ids: + return None + pin_ids = ((x,) if isinstance(x, int) else x for x in pin_ids) + to_fmt = itertools.chain.from_iterable(pin_ids) + return ",".join(sorted({_stringify_id(x) for x in to_fmt})) diff --git a/smartsim/_core/utils/launcher.py b/smartsim/_core/utils/launcher.py new file mode 100644 index 0000000000..7cb0a440b9 --- /dev/null +++ b/smartsim/_core/utils/launcher.py @@ -0,0 +1,99 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import abc +import collections.abc +import typing as t +import uuid + +from typing_extensions import Self + +from smartsim.status import JobStatus +from smartsim.types import LaunchedJobID + +if t.TYPE_CHECKING: + from smartsim.experiment import Experiment + +_T_contra = t.TypeVar("_T_contra", contravariant=True) + + +def create_job_id() -> LaunchedJobID: + return LaunchedJobID(str(uuid.uuid4())) + + +class LauncherProtocol(collections.abc.Hashable, t.Protocol[_T_contra]): + """The protocol defining a launcher that can be used by a SmartSim + experiment + """ + + @classmethod + @abc.abstractmethod + def create(cls, exp: Experiment, /) -> Self: + """Create an new launcher instance from and to be used by the passed in + experiment instance + + :param: An experiment to use the newly created launcher instance + :returns: The newly constructed launcher instance + """ + + @abc.abstractmethod + def start(self, launchable: _T_contra, /) -> LaunchedJobID: + """Given input that this launcher understands, create a new process and + issue a launched job id to query the status of the job in future. 
+ + :param launchable: The input to start a new process + :returns: The id to query the status of the process in future + """ + + @abc.abstractmethod + def get_status( + self, *launched_ids: LaunchedJobID + ) -> t.Mapping[LaunchedJobID, JobStatus]: + """Given a collection of launched job ids, return a mapping of id to + current status of the launched job. If a job id is not recognized by the + launcher, a `smartsim.error.errors.LauncherJobNotFound` error should be + raised. + + :param launched_ids: The collection of ids of launched jobs to query + for current status + :raises smartsim.error.errors.LauncherJobNotFound: If at least one of + the ids of the `launched_ids` collection is not recognized. + :returns: A mapping of launched id to current status + """ + + @abc.abstractmethod + def stop_jobs( + self, *launched_ids: LaunchedJobID + ) -> t.Mapping[LaunchedJobID, JobStatus]: + """Given a collection of launched job ids, cancel the launched jobs + + :param launched_ids: The ids of the jobs to stop + :raises smartsim.error.errors.LauncherJobNotFound: If at least one of + the ids of the `launched_ids` collection is not recognized. + :returns: A mapping of launched id to status upon cancellation + """ diff --git a/smartsim/_core/utils/redis.py b/smartsim/_core/utils/redis.py deleted file mode 100644 index 76ff45cd5a..0000000000 --- a/smartsim/_core/utils/redis.py +++ /dev/null @@ -1,238 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. 
Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import logging -import time -import typing as t -from itertools import product - -import redis -from redis.cluster import ClusterNode, RedisCluster -from redis.exceptions import ClusterDownError, RedisClusterException -from smartredis import Client -from smartredis.error import RedisReplyError - -from ...entity import DBModel, DBScript -from ...error import SSInternalError -from ...log import get_logger -from ..config import CONFIG -from .network import get_ip_from_host -from .shell import execute_cmd - -logging.getLogger("rediscluster").setLevel(logging.WARNING) -logger = get_logger(__name__) - - -def create_cluster(hosts: t.List[str], ports: t.List[int]) -> None: # cov-wlm - """Connect launched cluster instances. - - Should only be used in the case where cluster initialization - needs to occur manually which is not often. 
- - :param hosts: List of hostnames to connect to - :param ports: List of ports for each hostname - :raises SmartSimError: if cluster creation fails - """ - ip_list = [] - for host in hosts: - ip_address = get_ip_from_host(host) - for port in ports: - address = ":".join((ip_address, str(port) + " ")) - ip_list.append(address) - - # call cluster command - redis_cli = CONFIG.database_cli - cmd = [redis_cli, "--cluster", "create"] - cmd += ip_list - cmd += ["--cluster-replicas", "0", "--cluster-yes"] - returncode, out, err = execute_cmd(cmd, proc_input="yes", shell=False) - - if returncode != 0: - logger.error(out) - logger.error(err) - raise SSInternalError("Database '--cluster create' command failed") - logger.debug(out) - - -def check_cluster_status( - hosts: t.List[str], ports: t.List[int], trials: int = 10 -) -> None: # cov-wlm - """Check that a Redis/KeyDB cluster is up and running - - :param hosts: List of hostnames to connect to - :param ports: List of ports for each hostname - :param trials: number of attempts to verify cluster status - - :raises SmartSimError: If cluster status cannot be verified - """ - cluster_nodes = [ - ClusterNode(get_ip_from_host(host), port) - for host, port in product(hosts, ports) - ] - - if not cluster_nodes: - raise SSInternalError( - "No cluster nodes have been set for database status check." 
- ) - - logger.debug("Beginning database cluster status check...") - while trials > 0: - # wait for cluster to spin up - time.sleep(5) - try: - redis_tester: "RedisCluster[t.Any]" = RedisCluster( - startup_nodes=cluster_nodes - ) - redis_tester.set("__test__", "__test__") - redis_tester.delete("__test__") # type: ignore - logger.debug("Cluster status verified") - return - except (ClusterDownError, RedisClusterException, redis.RedisError): - logger.debug("Cluster still spinning up...") - trials -= 1 - if trials == 0: - raise SSInternalError("Cluster setup could not be verified") - - -def db_is_active(hosts: t.List[str], ports: t.List[int], num_shards: int) -> bool: - """Check if a DB is running - - if the DB is clustered, check cluster status, otherwise - just ping DB. - - :param hosts: list of hosts - :param ports: list of ports - :param num_shards: Number of DB shards - :return: Whether DB is running - """ - # if single shard - if num_shards < 2: - host = hosts[0] - port = ports[0] - try: - client = redis.Redis(host=host, port=port, db=0) - if client.ping(): - return True - return False - except redis.RedisError: - return False - # if a cluster - else: - try: - check_cluster_status(hosts, ports, trials=1) - return True - # we expect this to fail if the cluster is not active - except SSInternalError: - return False - - -def set_ml_model(db_model: DBModel, client: Client) -> None: - logger.debug(f"Adding DBModel named {db_model.name}") - - for device in db_model.devices: - try: - if db_model.is_file: - client.set_model_from_file( - name=db_model.name, - model_file=str(db_model.file), - backend=db_model.backend, - device=device, - batch_size=db_model.batch_size, - min_batch_size=db_model.min_batch_size, - min_batch_timeout=db_model.min_batch_timeout, - tag=db_model.tag, - inputs=db_model.inputs, - outputs=db_model.outputs, - ) - else: - if db_model.model is None: - raise ValueError(f"No model attacted to {db_model.name}") - client.set_model( - name=db_model.name, - 
model=db_model.model, - backend=db_model.backend, - device=device, - batch_size=db_model.batch_size, - min_batch_size=db_model.min_batch_size, - min_batch_timeout=db_model.min_batch_timeout, - tag=db_model.tag, - inputs=db_model.inputs, - outputs=db_model.outputs, - ) - except RedisReplyError as error: # pragma: no cover - logger.error("Error while setting model on orchestrator.") - raise error - - -def set_script(db_script: DBScript, client: Client) -> None: - logger.debug(f"Adding DBScript named {db_script.name}") - - for device in db_script.devices: - try: - if db_script.is_file: - client.set_script_from_file( - name=db_script.name, file=str(db_script.file), device=device - ) - elif db_script.script: - if isinstance(db_script.script, str): - client.set_script( - name=db_script.name, script=db_script.script, device=device - ) - elif callable(db_script.script): - client.set_function( - name=db_script.name, function=db_script.script, device=device - ) - else: - raise ValueError(f"No script or file attached to {db_script.name}") - except RedisReplyError as error: # pragma: no cover - logger.error("Error while setting model on orchestrator.") - raise error - - -def shutdown_db_node(host_ip: str, port: int) -> t.Tuple[int, str, str]: # cov-wlm - """Send shutdown signal to DB node. - - Should only be used in the case where cluster deallocation - needs to occur manually. Usually, the SmartSim job manager - will take care of this automatically. - - :param host_ip: IP of host to connect to - :param ports: Port to which node is listening - :return: returncode, output, and error of the process - """ - redis_cli = CONFIG.database_cli - cmd = [redis_cli, "-h", host_ip, "-p", str(port), "shutdown"] - returncode, out, err = execute_cmd(cmd, proc_input="yes", shell=False, timeout=10) - - if returncode != 0: - logger.error(out) - err_msg = "Error while shutting down DB node. 
" - err_msg += f"Return code: {returncode}, err: {err}" - logger.error(err_msg) - elif out: - logger.debug(out) - - return returncode, out, err diff --git a/smartsim/_core/utils/serialize.py b/smartsim/_core/utils/serialize.py index d4ec66eaf5..46c0a2c1da 100644 --- a/smartsim/_core/utils/serialize.py +++ b/smartsim/_core/utils/serialize.py @@ -36,9 +36,10 @@ if t.TYPE_CHECKING: from smartsim._core.control.manifest import LaunchedManifest as _Manifest - from smartsim.database.orchestrator import Orchestrator - from smartsim.entity import DBNode, Ensemble, Model - from smartsim.entity.dbobject import DBModel, DBScript + from smartsim.builders import Ensemble + from smartsim.database.orchestrator import FeatureStore + from smartsim.entity import Application, FSNode + from smartsim.entity.dbobject import FSModel, FSScript from smartsim.settings.base import BatchSettings, RunSettings @@ -58,12 +59,12 @@ def save_launch_manifest(manifest: _Manifest[TStepLaunchMetaData]) -> None: new_run = { "run_id": manifest.metadata.run_id, "timestamp": int(time.time_ns()), - "model": [ - _dictify_model(model, *telemetry_metadata) - for model, telemetry_metadata in manifest.models + "application": [ + _dictify_application(application, *telemetry_metadata) + for application, telemetry_metadata in manifest.applications ], - "orchestrator": [ - _dictify_db(db, nodes_info) for db, nodes_info in manifest.databases + "featurestore": [ + _dictify_fs(fs, nodes_info) for fs, nodes_info in manifest.featurestores ], "ensemble": [ _dictify_ensemble(ens, member_info) @@ -95,8 +96,8 @@ def save_launch_manifest(manifest: _Manifest[TStepLaunchMetaData]) -> None: json.dump(manifest_dict, file, indent=2) -def _dictify_model( - model: Model, +def _dictify_application( + application: Application, step_id: t.Optional[str], task_id: t.Optional[str], managed: t.Optional[bool], @@ -104,34 +105,38 @@ def _dictify_model( err_file: str, telemetry_data_path: Path, ) -> t.Dict[str, t.Any]: - colo_settings = 
(model.run_settings.colocated_db_settings or {}).copy() - db_scripts = t.cast("t.List[DBScript]", colo_settings.pop("db_scripts", [])) - db_models = t.cast("t.List[DBModel]", colo_settings.pop("db_models", [])) + if application.run_settings is not None: + colo_settings = (application.run_settings.colocated_fs_settings or {}).copy() + else: + colo_settings = ({}).copy() + fs_scripts = t.cast("t.List[FSScript]", colo_settings.pop("fs_scripts", [])) + fs_models = t.cast("t.List[FSModel]", colo_settings.pop("fs_models", [])) return { - "name": model.name, - "path": model.path, - "exe_args": model.run_settings.exe_args, - "run_settings": _dictify_run_settings(model.run_settings), + "name": application.name, + "path": application.path, + "exe_args": application.exe_args, + "exe": application.exe, + "run_settings": _dictify_run_settings(application.run_settings), "batch_settings": ( - _dictify_batch_settings(model.batch_settings) - if model.batch_settings + _dictify_batch_settings(application.batch_settings) + if application.batch_settings else {} ), - "params": model.params, + "params": application.params, "files": ( { - "Symlink": model.files.link, - "Configure": model.files.tagged, - "Copy": model.files.copy, + "Symlink": application.files.link, + "Configure": application.files.tagged, + "Copy": application.files.copy, } - if model.files + if application.files else { "Symlink": [], "Configure": [], "Copy": [], } ), - "colocated_db": ( + "colocated_fs": ( { "settings": colo_settings, "scripts": [ @@ -141,7 +146,7 @@ def _dictify_model( "device": script.device, } } - for script in db_scripts + for script in fs_scripts ], "models": [ { @@ -150,7 +155,7 @@ def _dictify_model( "device": model.device, } } - for model in db_models + for model in fs_models ], } if colo_settings @@ -169,7 +174,7 @@ def _dictify_model( def _dictify_ensemble( ens: Ensemble, - members: t.Sequence[t.Tuple[Model, TStepLaunchMetaData]], + members: t.Sequence[t.Tuple[Application, 
TStepLaunchMetaData]], ) -> t.Dict[str, t.Any]: return { "name": ens.name, @@ -181,9 +186,9 @@ def _dictify_ensemble( if ens.batch_settings else {} ), - "models": [ - _dictify_model(model, *launching_metadata) - for model, launching_metadata in members + "applications": [ + _dictify_application(application, *launching_metadata) + for application, launching_metadata in members ], } @@ -196,11 +201,10 @@ def _dictify_run_settings(run_settings: RunSettings) -> t.Dict[str, t.Any]: "MPMD run settings" ) return { - "exe": run_settings.exe, # TODO: We should try to move this back # "exe_args": run_settings.exe_args, - "run_command": run_settings.run_command, - "run_args": run_settings.run_args, + "run_command": run_settings.run_command if run_settings else "", + "run_args": run_settings.run_args if run_settings else None, # TODO: We currently do not have a way to represent MPMD commands! # Maybe add a ``"mpmd"`` key here that is a # ``list[TDictifiedRunSettings]``? @@ -214,20 +218,20 @@ def _dictify_batch_settings(batch_settings: BatchSettings) -> t.Dict[str, t.Any] } -def _dictify_db( - db: Orchestrator, - nodes: t.Sequence[t.Tuple[DBNode, TStepLaunchMetaData]], +def _dictify_fs( + fs: FeatureStore, + nodes: t.Sequence[t.Tuple[FSNode, TStepLaunchMetaData]], ) -> t.Dict[str, t.Any]: - db_path = _utils.get_db_path() - if db_path: - db_type, _ = db_path.name.split("-", 1) + fs_path = _utils.get_fs_path() + if fs_path: + fs_type, _ = fs_path.name.split("-", 1) else: - db_type = "Unknown" + fs_type = "Unknown" return { - "name": db.name, - "type": db_type, - "interface": db._interfaces, # pylint: disable=protected-access + "name": fs.name, + "type": fs_type, + "interface": fs._interfaces, # pylint: disable=protected-access "shards": [ { **shard.to_dict(), @@ -235,14 +239,14 @@ def _dictify_db( "out_file": out_file, "err_file": err_file, "memory_file": ( - str(status_dir / "memory.csv") if db.telemetry.is_enabled else "" + str(status_dir / "memory.csv") if 
fs.telemetry.is_enabled else "" ), "client_file": ( - str(status_dir / "client.csv") if db.telemetry.is_enabled else "" + str(status_dir / "client.csv") if fs.telemetry.is_enabled else "" ), "client_count_file": ( str(status_dir / "client_count.csv") - if db.telemetry.is_enabled + if fs.telemetry.is_enabled else "" ), "telemetry_metadata": { @@ -252,7 +256,7 @@ def _dictify_db( "managed": managed, }, } - for dbnode, ( + for fsnode, ( step_id, task_id, managed, @@ -260,6 +264,6 @@ def _dictify_db( err_file, status_dir, ) in nodes - for shard in dbnode.get_launched_shard_info() + for shard in fsnode.get_launched_shard_info() ], } diff --git a/smartsim/_core/utils/telemetry/collector.py b/smartsim/_core/utils/telemetry/collector.py index 178126dec9..02f5ed9f1f 100644 --- a/smartsim/_core/utils/telemetry/collector.py +++ b/smartsim/_core/utils/telemetry/collector.py @@ -30,16 +30,18 @@ import logging import typing as t -import redis.asyncio as redisa -import redis.exceptions as redisex - from smartsim._core.control.job import JobEntity from smartsim._core.utils.helpers import get_ts_ms from smartsim._core.utils.telemetry.sink import FileSink, Sink +from smartsim.entity._mock import Mock logger = logging.getLogger("TelemetryMonitor") +class Client(Mock): + """Mock Client""" + + class Collector(abc.ABC): """Base class for telemetry collectors. 
@@ -95,8 +97,8 @@ class _DBAddress: def __init__(self, host: str, port: int) -> None: """Initialize the instance - :param host: host address for database connections - :param port: port number for database connections + :param host: host address for feature store connections + :param port: port number for feature store connections """ self.host = host.strip() if host else "" self.port = port @@ -114,8 +116,9 @@ def __str__(self) -> str: return f"{self.host}:{self.port}" +# TODO add a new Client class DBCollector(Collector): - """A base class for collectors that retrieve statistics from an orchestrator""" + """A base class for collectors that retrieve statistics from a feature store""" def __init__(self, entity: JobEntity, sink: Sink) -> None: """Initialize the `DBCollector` @@ -124,19 +127,17 @@ def __init__(self, entity: JobEntity, sink: Sink) -> None: :param sink: destination to write collected information """ super().__init__(entity, sink) - self._client: t.Optional[redisa.Redis[bytes]] = None + self._client: Client self._address = _DBAddress( self._entity.config.get("host", ""), int(self._entity.config.get("port", 0)), ) async def _configure_client(self) -> None: - """Configure the client connection to the target database""" + """Configure the client connection to the target feature store""" try: if not self._client: - self._client = redisa.Redis( - host=self._address.host, port=self._address.port - ) + self._client = None except Exception as e: logger.exception(e) finally: @@ -146,7 +147,7 @@ async def _configure_client(self) -> None: ) async def prepare(self) -> None: - """Initialization logic for the DB collector. Creates a database + """Initialization logic for the FS collector. 
Creates a feature store connection then executes the `post_prepare` callback function.""" if self._client: return @@ -157,7 +158,7 @@ async def prepare(self) -> None: @abc.abstractmethod async def _post_prepare(self) -> None: """Hook function to enable subclasses to perform actions - after a db client is ready""" + after a fss client is ready""" @abc.abstractmethod async def _perform_collection( @@ -171,7 +172,7 @@ async def _perform_collection( """ async def collect(self) -> None: - """Execute database metric collection if the collector is enabled. Writes + """Execute feature store metric collection if the collector is enabled. Writes the resulting metrics to the associated output sink. Calling `collect` when `self.enabled` is `False` performs no actions.""" if not self.enabled: @@ -186,8 +187,8 @@ async def collect(self) -> None: return try: - # if we can't communicate w/the db, exit - if not await self._check_db(): + # if we can't communicate w/the fs, exit + if not await self._check_fs(): return all_metrics = await self._perform_collection() @@ -197,7 +198,7 @@ async def collect(self) -> None: logger.warning(f"Collect failed for {type(self).__name__}", exc_info=ex) async def shutdown(self) -> None: - """Execute cleanup of database client connections""" + """Execute cleanup of feature store client connections""" try: if self._client: logger.info( @@ -210,16 +211,16 @@ async def shutdown(self) -> None: f"An error occurred during {type(self).__name__} shutdown", exc_info=ex ) - async def _check_db(self) -> bool: - """Check if the target database is reachable. + async def _check_fs(self) -> bool: + """Check if the target feature store is reachable. :return: `True` if connection succeeds, `False` otherwise. 
""" try: if self._client: return await self._client.ping() - except redisex.ConnectionError: - logger.warning(f"Cannot ping db {self._address}") + except Exception: + logger.warning(f"Cannot ping fs {self._address}") return False @@ -233,7 +234,7 @@ def __init__(self, entity: JobEntity, sink: Sink) -> None: async def _post_prepare(self) -> None: """Write column headers for a CSV formatted output sink after - the database connection is established""" + the feature store connection is established""" await self._sink.save("timestamp", *self._columns) async def _perform_collection( @@ -247,11 +248,11 @@ async def _perform_collection( if self._client is None: return [] - db_info = await self._client.info("memory") + fs_info = await self._client.info("memory") - used = float(db_info["used_memory"]) - peak = float(db_info["used_memory_peak"]) - total = float(db_info["total_system_memory"]) + used = float(fs_info["used_memory"]) + peak = float(fs_info["used_memory_peak"]) + total = float(fs_info["total_system_memory"]) value = (get_ts_ms(), used, peak, total) @@ -261,7 +262,7 @@ async def _perform_collection( class DBConnectionCollector(DBCollector): - """A `DBCollector` that collects database client-connection metrics""" + """A `DBCollector` that collects feature store client-connection metrics""" def __init__(self, entity: JobEntity, sink: Sink) -> None: super().__init__(entity, sink) @@ -269,7 +270,7 @@ def __init__(self, entity: JobEntity, sink: Sink) -> None: async def _post_prepare(self) -> None: """Write column headers for a CSV formatted output sink after - the database connection is established""" + the feature store connection is established""" await self._sink.save("timestamp", *self._columns) async def _perform_collection( @@ -306,7 +307,7 @@ def __init__(self, entity: JobEntity, sink: Sink) -> None: async def _post_prepare(self) -> None: """Write column headers for a CSV formatted output sink after - the database connection is established""" + the feature 
store connection is established""" await self._sink.save("timestamp", *self._columns) async def _perform_collection( @@ -457,9 +458,9 @@ def register_collectors(self, entity: JobEntity) -> None: """ collectors: t.List[Collector] = [] - # ONLY db telemetry is implemented at this time. This resolver must - # be updated when non-database or always-on collectors are introduced - if entity.is_db and entity.telemetry_on: + # ONLY fs telemetry is implemented at this time. This resolver must + # be updated when non-feature store or always-on collectors are introduced + if entity.is_fs and entity.telemetry_on: if mem_out := entity.collectors.get("memory", None): collectors.append(DBMemoryCollector(entity, FileSink(mem_out))) @@ -469,7 +470,7 @@ def register_collectors(self, entity: JobEntity) -> None: if num_out := entity.collectors.get("client_count", None): collectors.append(DBConnectionCountCollector(entity, FileSink(num_out))) else: - logger.debug(f"Collectors disabled for db {entity.name}") + logger.debug(f"Collectors disabled for fs {entity.name}") self.add_all(collectors) diff --git a/smartsim/_core/utils/telemetry/manifest.py b/smartsim/_core/utils/telemetry/manifest.py index 942fa4ae87..4cf067f08e 100644 --- a/smartsim/_core/utils/telemetry/manifest.py +++ b/smartsim/_core/utils/telemetry/manifest.py @@ -43,10 +43,10 @@ class Run: timestamp: int """the timestamp at the time the `Experiment.start` is called""" - models: t.List[JobEntity] - """models started in this run""" - orchestrators: t.List[JobEntity] - """orchestrators started in this run""" + applications: t.List[JobEntity] + """applications started in this run""" + featurestores: t.List[JobEntity] + """featurestores started in this run""" ensembles: t.List[JobEntity] """ensembles started in this run""" @@ -58,7 +58,7 @@ def flatten( :param filter_fn: optional boolean filter that returns True for entities to include in the result """ - entities = self.models + self.orchestrators + self.ensembles + entities = 
self.applications + self.featurestores + self.ensembles if filter_fn: entities = [entity for entity in entities if filter_fn(entity)] return entities @@ -82,11 +82,11 @@ def load_entity( # an entity w/parent keys must create entities for the items that it # comprises. traverse the children and create each entity - parent_keys = {"shards", "models"} + parent_keys = {"shards", "applications"} parent_keys = parent_keys.intersection(entity_dict.keys()) if parent_keys: - container = "shards" if "shards" in parent_keys else "models" - child_type = "orchestrator" if container == "shards" else "model" + container = "shards" if "shards" in parent_keys else "applications" + child_type = "featurestore" if container == "shards" else "application" for child_entity in entity_dict[container]: entity = JobEntity.from_manifest( child_type, child_entity, str(exp_dir), raw_experiment @@ -118,8 +118,8 @@ def load_entities( :return: list of loaded `JobEntity` instances """ persisted: t.Dict[str, t.List[JobEntity]] = { - "model": [], - "orchestrator": [], + "application": [], + "featurestore": [], } for item in run[entity_type]: entities = Run.load_entity(entity_type, item, exp_dir, raw_experiment) @@ -144,8 +144,8 @@ def load_run( # create an output mapping to hold the deserialized entities run_entities: t.Dict[str, t.List[JobEntity]] = { - "model": [], - "orchestrator": [], + "application": [], + "featurestore": [], "ensemble": [], } @@ -164,8 +164,8 @@ def load_run( loaded_run = Run( raw_run["timestamp"], - run_entities["model"], - run_entities["orchestrator"], + run_entities["application"], + run_entities["featurestore"], run_entities["ensemble"], ) return loaded_run diff --git a/smartsim/_core/utils/telemetry/telemetry.py b/smartsim/_core/utils/telemetry/telemetry.py index e9e4c46bc4..c8ff3bf25e 100644 --- a/smartsim/_core/utils/telemetry/telemetry.py +++ b/smartsim/_core/utils/telemetry/telemetry.py @@ -41,14 +41,13 @@ from smartsim._core.config import CONFIG from 
smartsim._core.control.job import JobEntity, _JobKey -from smartsim._core.control.jobmanager import JobManager -from smartsim._core.launcher.dragon.dragonLauncher import DragonLauncher +from smartsim._core.launcher.dragon.dragon_launcher import DragonLauncher from smartsim._core.launcher.launcher import Launcher from smartsim._core.launcher.local.local import LocalLauncher -from smartsim._core.launcher.lsf.lsfLauncher import LSFLauncher -from smartsim._core.launcher.pbs.pbsLauncher import PBSLauncher -from smartsim._core.launcher.slurm.slurmLauncher import SlurmLauncher -from smartsim._core.launcher.stepInfo import StepInfo +from smartsim._core.launcher.lsf.lsf_launcher import LSFLauncher +from smartsim._core.launcher.pbs.pbs_launcher import PBSLauncher +from smartsim._core.launcher.slurm.slurm_launcher import SlurmLauncher +from smartsim._core.launcher.step_info import StepInfo from smartsim._core.utils.helpers import get_ts_ms from smartsim._core.utils.serialize import MANIFEST_FILENAME from smartsim._core.utils.telemetry.collector import CollectorManager @@ -95,7 +94,6 @@ def __init__( self._tracked_jobs: t.Dict[_JobKey, JobEntity] = {} self._completed_jobs: t.Dict[_JobKey, JobEntity] = {} self._launcher: t.Optional[Launcher] = None - self.job_manager: JobManager = JobManager(threading.RLock()) self._launcher_map: t.Dict[str, t.Type[Launcher]] = { "slurm": SlurmLauncher, "pbs": PBSLauncher, @@ -132,14 +130,6 @@ def init_launcher(self, launcher: str) -> None: raise ValueError("Launcher type not supported: " + launcher) - def init_job_manager(self) -> None: - """Initialize the job manager instance""" - if not self._launcher: - raise TypeError("self._launcher must be initialized") - - self.job_manager.set_launcher(self._launcher) - self.job_manager.start() - def set_launcher(self, launcher_type: str) -> None: """Set the launcher for the experiment :param launcher_type: the name of the workload manager used by the experiment @@ -149,9 +139,6 @@ def 
set_launcher(self, launcher_type: str) -> None: if self._launcher is None: raise SmartSimError("Launcher init failed") - self.job_manager.set_launcher(self._launcher) - self.job_manager.start() - def process_manifest(self, manifest_path: str) -> None: """Read the manifest for the experiment. Process the `RuntimeManifest` by updating the set of tracked jobs @@ -210,14 +197,6 @@ def process_manifest(self, manifest_path: str) -> None: ) if entity.is_managed: - # Tell JobManager the task is unmanaged. This collects - # status updates but does not try to start a new copy - self.job_manager.add_job( - entity.name, - entity.step_id, - entity, - False, - ) # Tell the launcher it's managed so it doesn't attempt # to look for a PID that may no longer exist self._launcher.step_mapping.add( @@ -264,9 +243,6 @@ async def _to_completed( # remove all the registered collectors for the completed entity await self._collector_mgr.remove(entity) - job = self.job_manager[entity.name] - self.job_manager.move_to_completed(job) - status_clause = f"status: {step_info.status}" error_clause = f", error: {step_info.error}" if step_info.error else "" @@ -432,8 +408,7 @@ class TelemetryMonitor: """The telemetry monitor is a standalone process managed by SmartSim to perform long-term retrieval of experiment status updates and resource usage metrics. Note that a non-blocking driver script is likely to complete before - the SmartSim entities complete. Also, the JobManager performs status updates - only as long as the driver is running. This telemetry monitor entrypoint is + the SmartSim entities complete. This telemetry monitor entrypoint is started automatically when a SmartSim experiment calls the `start` method on resources. The entrypoint runs until it has no resources to monitor.""" @@ -458,33 +433,29 @@ def __init__(self, telemetry_monitor_args: TelemetryMonitorArgs): def _can_shutdown(self) -> bool: """Determines if the telemetry monitor can perform shutdown. 
An automatic shutdown will occur if there are no active jobs being monitored. - Managed jobs and databases are considered separately due to the way they + Managed jobs and feature stores are considered separately due to the way they are stored in the job manager :return: return True if capable of automatically shutting down """ - managed_jobs = ( - list(self._action_handler.job_manager.jobs.values()) - if self._action_handler - else [] - ) + managed_jobs = [] unmanaged_jobs = ( list(self._action_handler.tracked_jobs) if self._action_handler else [] ) - # get an individual count of databases for logging - n_dbs: int = len( + # get an individual count of feature stores for logging + n_fss: int = len( [ job for job in managed_jobs + unmanaged_jobs - if isinstance(job, JobEntity) and job.is_db + if isinstance(job, JobEntity) and job.is_fs ] ) # if we have no jobs currently being monitored we can shutdown - n_jobs = len(managed_jobs) + len(unmanaged_jobs) - n_dbs - shutdown_ok = n_jobs + n_dbs == 0 + n_jobs = len(managed_jobs) + len(unmanaged_jobs) - n_fss + shutdown_ok = n_jobs + n_fss == 0 - logger.debug(f"{n_jobs} active job(s), {n_dbs} active db(s)") + logger.debug(f"{n_jobs} active job(s), {n_fss} active fs(s)") return shutdown_ok async def monitor(self) -> None: diff --git a/smartsim/_core/utils/telemetry/util.py b/smartsim/_core/utils/telemetry/util.py index 2c51d96000..5a1c94d5cb 100644 --- a/smartsim/_core/utils/telemetry/util.py +++ b/smartsim/_core/utils/telemetry/util.py @@ -30,8 +30,8 @@ import pathlib import typing as t -from smartsim._core.launcher.stepInfo import StepInfo -from smartsim.status import TERMINAL_STATUSES, SmartSimStatus +from smartsim._core.launcher.step_info import StepInfo +from smartsim.status import TERMINAL_STATUSES, JobStatus _EventClass = t.Literal["start", "stop", "timestep"] @@ -55,7 +55,7 @@ def write_event( :param task_id: the task_id of a managed task :param step_id: the step_id of an unmanaged task :param entity_type: the 
SmartSimEntity subtype - (e.g. `orchestrator`, `ensemble`, `model`, `dbnode`, ...) + (e.g. `featurestore`, `ensemble`, `application`, `fsnode`, ...) :param event_type: the event subtype :param status_dir: path where the SmartSimEntity outputs are written :param detail: (optional) additional information to write with the event @@ -106,8 +106,6 @@ def map_return_code(step_info: StepInfo) -> t.Optional[int]: :return: a return code if the step is finished, otherwise None """ rc_map = {s: 1 for s in TERMINAL_STATUSES} # return `1` for all terminal statuses - rc_map.update( - {SmartSimStatus.STATUS_COMPLETED: os.EX_OK} - ) # return `0` for full success + rc_map.update({JobStatus.COMPLETED: os.EX_OK}) # return `0` for full success return rc_map.get(step_info.status, None) # return `None` when in-progress diff --git a/smartsim/builders/__init__.py b/smartsim/builders/__init__.py new file mode 100644 index 0000000000..866269f201 --- /dev/null +++ b/smartsim/builders/__init__.py @@ -0,0 +1,28 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .ensemble import Ensemble +from .utils.strategies import ParamSet diff --git a/smartsim/builders/ensemble.py b/smartsim/builders/ensemble.py new file mode 100644 index 0000000000..d87ada15aa --- /dev/null +++ b/smartsim/builders/ensemble.py @@ -0,0 +1,432 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import collections +import copy +import itertools +import os +import os.path +import typing as t + +from smartsim.builders.utils import strategies +from smartsim.builders.utils.strategies import ParamSet +from smartsim.entity import entity +from smartsim.entity.application import Application +from smartsim.entity.files import EntityFiles +from smartsim.launchable.job import Job +from smartsim.settings.launch_settings import LaunchSettings + +if t.TYPE_CHECKING: + from smartsim.settings.launch_settings import LaunchSettings + + +class Ensemble(entity.CompoundEntity): + """An Ensemble is a builder class that parameterizes the creation of multiple + Applications. + """ + + def __init__( + self, + name: str, + exe: str | os.PathLike[str], + exe_args: t.Sequence[str] | None = None, + exe_arg_parameters: t.Mapping[str, t.Sequence[t.Sequence[str]]] | None = None, + files: EntityFiles | None = None, + file_parameters: t.Mapping[str, t.Sequence[str]] | None = None, + permutation_strategy: str | strategies.PermutationStrategyType = "all_perm", + max_permutations: int = -1, + replicas: int = 1, + ) -> None: + """Initialize an ``Ensemble`` of Application instances + + An Ensemble can be tailored to align with one of the following + creation strategies: parameter expansion or replicas. + + **Parameter Expansion** + + Parameter expansion allows users to assign different parameter values to + multiple Applications. 
This is done by specifying input to `Ensemble.file_parameters`, + `Ensemble.exe_arg_parameters` and `Ensemble.permutation_strategy`. The `permutation_strategy` + argument accepts three options: + + 1. "all_perm": Generates all possible parameter permutations for exhaustive exploration. + 2. "step": Collects identically indexed values across parameter lists to create parameter sets. + 3. "random": Enables random selection from predefined parameter spaces. + + The example below demonstrates creating an Ensemble via parameter expansion, resulting in + the creation of two Applications: + + .. highlight:: python + .. code-block:: python + + file_params={"SPAM": ["a", "b"], "EGGS": ["c", "d"]} + exe_arg_parameters = {"EXE": [["a"], ["b", "c"]], "ARGS": [["d"], ["e", "f"]]} + ensemble = Ensemble(name="name",exe="python",exe_arg_parameters=exe_arg_parameters, + file_parameters=file_params,permutation_strategy="step") + + This configuration will yield the following permutations: + + .. highlight:: python + .. code-block:: python + [ParamSet(params={'SPAM': 'a', 'EGGS': 'c'}, exe_args={'EXE': ['a'], 'ARGS': ['d']}), + ParamSet(params={'SPAM': 'b', 'EGGS': 'd'}, exe_args={'EXE': ['b', 'c'], 'ARGS': ['e', 'f']})] + + Each ParamSet contains the parameters assigned from file_params and the corresponding executable + arguments from exe_arg_parameters. + + **Replication** + The replication strategy involves creating identical Applications within an Ensemble. + This is achieved by specifying the `replicas` argument in the Ensemble. + + For example, by applying the `replicas` argument to the previous parameter expansion + example, we can double our Application output: + + .. highlight:: python + .. 
code-block:: python + + file_params={"SPAM": ["a", "b"], "EGGS": ["c", "d"]} + exe_arg_parameters = {"EXE": [["a"], ["b", "c"]], "ARGS": [["d"], ["e", "f"]]} + ensemble = Ensemble(name="name",exe="python",exe_arg_parameters=exe_arg_parameters, + file_parameters=file_params,permutation_strategy="step", replicas=2) + + This configuration will result in each ParamSet being replicated, effectively doubling + the number of Applications created. + + :param name: name of the ensemble + :param exe: executable to run + :param exe_args: executable arguments + :param exe_arg_parameters: parameters and values to be used when configuring entities + :param files: files to be copied, symlinked, and/or configured prior to + execution + :param file_parameters: parameters and values to be used when configuring + files + :param permutation_strategy: strategy to control how the param values are applied to the Ensemble + :param max_permutations: max parameter permutations to set for the ensemble + :param replicas: number of identical entities to create within an Ensemble + """ + self.name = name + """The name of the ensemble""" + self._exe = os.fspath(exe) + """The executable to run""" + self.exe_args = list(exe_args) if exe_args else [] + """The executable arguments""" + self._exe_arg_parameters = ( + copy.deepcopy(exe_arg_parameters) if exe_arg_parameters else {} + ) + """The parameters and values to be used when configuring entities""" + self._files = copy.deepcopy(files) if files else EntityFiles() + """The files to be copied, symlinked, and/or configured prior to execution""" + self._file_parameters = ( + copy.deepcopy(file_parameters) if file_parameters else {} + ) + """The parameters and values to be used when configuring files""" + self._permutation_strategy = permutation_strategy + """The strategy to control how the param values are applied to the Ensemble""" + self._max_permutations = max_permutations + """The maximum number of entities to come out of the permutation 
strategy""" + self._replicas = replicas + """How many identical entities to create within an Ensemble""" + + @property + def exe(self) -> str: + """Return the attached executable. + + :return: the executable + """ + return self._exe + + @exe.setter + def exe(self, value: str | os.PathLike[str]) -> None: + """Set the executable. + + :param value: the executable + :raises TypeError: if the exe argument is not str or PathLike str + """ + if not isinstance(value, (str, os.PathLike)): + raise TypeError("exe argument was not of type str or PathLike str") + + self._exe = os.fspath(value) + + @property + def exe_args(self) -> t.List[str]: + """Return attached list of executable arguments. + + :return: the executable arguments + """ + return self._exe_args + + @exe_args.setter + def exe_args(self, value: t.Sequence[str]) -> None: + """Set the executable arguments. + + :param value: the executable arguments + :raises TypeError: if exe_args is not sequence of str + """ + + if not ( + isinstance(value, collections.abc.Sequence) + and (all(isinstance(x, str) for x in value)) + ): + raise TypeError("exe_args argument was not of type sequence of str") + + self._exe_args = list(value) + + @property + def exe_arg_parameters(self) -> t.Mapping[str, t.Sequence[t.Sequence[str]]]: + """Return attached executable argument parameters. + + :return: the executable argument parameters + """ + return self._exe_arg_parameters + + @exe_arg_parameters.setter + def exe_arg_parameters( + self, value: t.Mapping[str, t.Sequence[t.Sequence[str]]] + ) -> None: + """Set the executable argument parameters. 
+ + :param value: the executable argument parameters + :raises TypeError: if exe_arg_parameters is not mapping + of str and sequences of sequences of strings + """ + + if not ( + isinstance(value, collections.abc.Mapping) + and ( + all( + isinstance(key, str) + and isinstance(val, collections.abc.Sequence) + and all( + isinstance(subval, collections.abc.Sequence) for subval in val + ) + and all( + isinstance(item, str) + for item in itertools.chain.from_iterable(val) + ) + for key, val in value.items() + ) + ) + ): + raise TypeError( + "exe_arg_parameters argument was not of type " + "mapping of str and sequences of sequences of strings" + ) + + self._exe_arg_parameters = copy.deepcopy(value) + + @property + def files(self) -> EntityFiles: + """Return attached EntityFiles object. + + :return: the EntityFiles object of files to be copied, symlinked, + and/or configured prior to execution + """ + return self._files + + @files.setter + def files(self, value: EntityFiles) -> None: + """Set the EntityFiles object. + + :param value: the EntityFiles object of files to be copied, symlinked, + and/or configured prior to execution + :raises TypeError: if files is not of type EntityFiles + """ + + if not isinstance(value, EntityFiles): + raise TypeError("files argument was not of type EntityFiles") + self._files = copy.deepcopy(value) + + @property + def file_parameters(self) -> t.Mapping[str, t.Sequence[str]]: + """Return the attached file parameters. + + :return: the file parameters + """ + return self._file_parameters + + @file_parameters.setter + def file_parameters(self, value: t.Mapping[str, t.Sequence[str]]) -> None: + """Set the file parameters. 
+ + :param value: the file parameters + :raises TypeError: if file_parameters is not a mapping of str and + sequence of str + """ + + if not ( + isinstance(value, t.Mapping) + and ( + all( + isinstance(key, str) + and isinstance(val, collections.abc.Sequence) + and all(isinstance(subval, str) for subval in val) + for key, val in value.items() + ) + ) + ): + raise TypeError( + "file_parameters argument was not of type mapping of str " + "and sequence of str" + ) + + self._file_parameters = dict(value) + + @property + def permutation_strategy(self) -> str | strategies.PermutationStrategyType: + """Return the permutation strategy + + :return: the permutation strategy + """ + return self._permutation_strategy + + @permutation_strategy.setter + def permutation_strategy( + self, value: str | strategies.PermutationStrategyType + ) -> None: + """Set the permutation strategy + + :param value: the permutation strategy + :raises TypeError: if permutation_strategy is not str or + PermutationStrategyType + """ + + if not (callable(value) or isinstance(value, str)): + raise TypeError( + "permutation_strategy argument was not of " + "type str or PermutationStrategyType" + ) + self._permutation_strategy = value + + @property + def max_permutations(self) -> int: + """Return the maximum permutations + + :return: the max permutations + """ + return self._max_permutations + + @max_permutations.setter + def max_permutations(self, value: int) -> None: + """Set the maximum permutations + + :param value: the max permutations + :raises TypeError: max_permutations argument was not of type int + """ + if not isinstance(value, int): + raise TypeError("max_permutations argument was not of type int") + + self._max_permutations = value + + @property + def replicas(self) -> int: + """Return the number of replicas. + + :return: the number of replicas + """ + return self._replicas + + @replicas.setter + def replicas(self, value: int) -> None: + """Set the number of replicas. 
+ + :param value: the number of replicas + :raises TypeError: replicas argument was not of type int + """ + if not isinstance(value, int): + raise TypeError("replicas argument was not of type int") + if value <= 0: + raise ValueError("Number of replicas must be a positive integer") + + self._replicas = value + + def _create_applications(self) -> tuple[Application, ...]: + """Generate a collection of Application instances based on the Ensemble's attributes. + + This method uses a permutation strategy to create various combinations of file + parameters and executable arguments. Each combination is then replicated according + to the specified number of replicas, resulting in a set of Application instances. + + :return: A tuple of Application instances + """ + permutation_strategy = strategies.resolve(self.permutation_strategy) + + combinations = permutation_strategy( + self.file_parameters, self.exe_arg_parameters, self.max_permutations + ) + combinations = combinations if combinations else [ParamSet({}, {})] + permutations_ = itertools.chain.from_iterable( + itertools.repeat(permutation, self.replicas) for permutation in combinations + ) + return tuple( + Application( + name=f"{self.name}-{i}", + exe=self.exe, + exe_args=self.exe_args, + file_parameters=permutation.params, + ) + for i, permutation in enumerate(permutations_) + ) + + def build_jobs(self, settings: LaunchSettings) -> tuple[Job, ...]: + """Expand an Ensemble into a list of deployable Jobs and apply + identical LaunchSettings to each Job. + + The number of Jobs returned is controlled by the Ensemble attributes: + - Ensemble.exe_arg_parameters + - Ensemble.file_parameters + - Ensemble.permutation_strategy + - Ensemble.max_permutations + - Ensemble.replicas + + Consider the example below: + + .. highlight:: python + .. code-block:: python + + # Create LaunchSettings + my_launch_settings = LaunchSettings(...)
+ + # Initialize the Ensemble + ensemble = Ensemble("my_name", "echo", "hello world", replicas=3) + # Expand Ensemble into Jobs + ensemble_as_jobs = ensemble.build_jobs(my_launch_settings) + + By calling `build_jobs` on `ensemble`, three Jobs are returned because + three replicas were specified. Each Job will have the provided LaunchSettings. + + :param settings: LaunchSettings to apply to each Job + :return: Sequence of Jobs with the provided LaunchSettings + :raises TypeError: if the settings argument is not of type LaunchSettings + :raises ValueError: if the LaunchSettings provided are empty + """ + if not isinstance(settings, LaunchSettings): + raise TypeError("ids argument was not of type LaunchSettings") + apps = self._create_applications() + if not apps: + raise ValueError("There are no members as part of this ensemble") + return tuple(Job(app, settings, app.name) for app in apps) diff --git a/smartsim/builders/utils/strategies.py b/smartsim/builders/utils/strategies.py new file mode 100644 index 0000000000..e3a2527a52 --- /dev/null +++ b/smartsim/builders/utils/strategies.py @@ -0,0 +1,262 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Generation Strategies + +from __future__ import annotations + +import functools +import itertools +import random +import typing as t +from dataclasses import dataclass, field + +from smartsim.error import errors + + +@dataclass(frozen=True) +class ParamSet: + """ + Represents a set of file parameters and execution arguments as parameters. + """ + + params: dict[str, str] = field(default_factory=dict) + exe_args: dict[str, list[str]] = field(default_factory=dict) + + +# Type alias for the shape of a permutation strategy callable +PermutationStrategyType = t.Callable[ + [t.Mapping[str, t.Sequence[str]], t.Mapping[str, t.Sequence[t.Sequence[str]]], int], + list[ParamSet], +] + +# Map of globally registered strategy names to registered strategy callables +_REGISTERED_STRATEGIES: t.Final[dict[str, PermutationStrategyType]] = {} + + +def _register(name: str) -> t.Callable[ + [PermutationStrategyType], + PermutationStrategyType, +]: + """Create a decorator to globally register a permutation strategy under a + given name. + + :param name: The name under which to register a strategy + :return: A decorator to register a permutation strategy function + """ + + def _impl(fn: PermutationStrategyType) -> PermutationStrategyType: + """Add a strategy function to the globally registered strategies under + the `name` caught in the closure. 
+ + :param fn: A permutation strategy + :returns: The original strategy, unaltered + :raises ValueError: A strategy under name caught in the closure has + already been registered + """ + if name in _REGISTERED_STRATEGIES: + msg = f"A strategy with the name '{name}' has already been registered" + raise ValueError(msg) + _REGISTERED_STRATEGIES[name] = fn + return fn + + return _impl + + +def resolve(strategy: str | PermutationStrategyType) -> PermutationStrategyType: + """Look-up or sanitize a permutation strategy: + + - When `strategy` is a `str` it will look for a globally registered + strategy function by that name. + + - When `strategy` is a `callable` it will return a sanitized + strategy function. + + :param strategy: The name of a registered strategy or a custom + permutation strategy + :return: A valid permutation strategy callable + """ + if callable(strategy): + return _make_sanitized_custom_strategy(strategy) + try: + return _REGISTERED_STRATEGIES[strategy] + except KeyError: + raise ValueError( + f"Failed to find an ensembling strategy by the name of '{strategy}'." + f"All known strategies are:\n{', '.join(_REGISTERED_STRATEGIES)}" + ) from None + + +def _make_sanitized_custom_strategy( + fn: PermutationStrategyType, +) -> PermutationStrategyType: + """Take a callable that satisfies the shape of a permutation strategy and + return a sanitized version for future callers. + + The sanitized version of the permutation strategy will intercept any + exceptions raised by the original permutation and re-raise a + `UserStrategyError`. + + The sanitized version will also check the type of the value returned from + the original callable, and if it does not conform to the expected return type, + a `UserStrategyError` will be raised.
+ + :param fn: A custom user strategy function + :return: A sanitized version of the custom strategy function + """ + + @functools.wraps(fn) + def _impl( + params: t.Mapping[str, t.Sequence[str]], + exe_args: t.Mapping[str, t.Sequence[t.Sequence[str]]], + n_permutations: int = -1, + ) -> list[ParamSet]: + try: + permutations = fn(params, exe_args, n_permutations) + except Exception as e: + raise errors.UserStrategyError(str(fn)) from e + if not isinstance(permutations, list) or not all( + isinstance(permutation, ParamSet) for permutation in permutations + ): + raise errors.UserStrategyError(str(fn)) + return permutations + + return _impl + + +@_register("all_perm") +def create_all_permutations( + params: t.Mapping[str, t.Sequence[str]], + exe_arg: t.Mapping[str, t.Sequence[t.Sequence[str]]], + n_permutations: int = -1, +) -> list[ParamSet]: + """Take two mapping parameters to possible values and return a sequence of + all possible permutations of those parameters. + For example calling: + .. highlight:: python + .. code-block:: python + create_all_permutations({"SPAM": ["a", "b"], + "EGGS": ["c", "d"]}, + {"EXE": [["a"], ["b", "c"]], + "ARGS": [["d"], ["e", "f"]]}, + 1 + ) + Would result in the following permutations (not necessarily in this order): + .. highlight:: python + .. 
code-block:: python + [ParamSet(params={'SPAM': 'a', 'EGGS': 'c'}, + exe_args={'EXE': ['a'], 'ARGS': ['d']})] + :param file_params: A mapping of file parameter names to possible values + :param exe_arg_params: A mapping of exe arg parameter names to possible values + :param n_permutations: The maximum number of permutations to sample from + the sequence of all permutations + :return: A sequence of ParamSets of all possible permutations + """ + file_params_permutations = itertools.product(*params.values()) + param_zip = ( + dict(zip(params, permutation)) for permutation in file_params_permutations + ) + + exe_arg_params_permutations = itertools.product(*exe_arg.values()) + exe_arg_params_permutations_ = ( + tuple(map(list, sequence)) for sequence in exe_arg_params_permutations + ) + exe_arg_zip = ( + dict(zip(exe_arg, permutation)) for permutation in exe_arg_params_permutations_ + ) + + combinations = itertools.product(param_zip, exe_arg_zip) + param_set: t.Iterable[ParamSet] = ( + ParamSet(file_param, exe_arg) for file_param, exe_arg in combinations + ) + if n_permutations >= 0: + param_set = itertools.islice(param_set, n_permutations) + return list(param_set) + + +@_register("step") +def step_values( + params: t.Mapping[str, t.Sequence[str]], + exe_args: t.Mapping[str, t.Sequence[t.Sequence[str]]], + n_permutations: int = -1, +) -> list[ParamSet]: + """Take two mapping parameters to possible values and return a sequence of + stepped values until a possible values sequence runs out of possible + values. + For example calling: + .. highlight:: python + .. code-block:: python + step_values({"SPAM": ["a", "b"], + "EGGS": ["c", "d"]}, + {"EXE": [["a"], ["b", "c"]], + "ARGS": [["d"], ["e", "f"]]}, + 1 + ) + Would result in the following permutations: + .. highlight:: python + .. 
code-block:: python + [ParamSet(params={'SPAM': 'a', 'EGGS': 'c'}, + exe_args={'EXE': ['a'], 'ARGS': ['d']})] + :param file_params: A mapping of file parameter names to possible values + :param exe_arg_params: A mapping of exe arg parameter names to possible values + :param n_permutations: The maximum number of permutations to sample from + the sequence of step permutations + :return: A sequence of ParamSets of stepped values + """ + param_zip: t.Iterable[tuple[str, ...]] = zip(*params.values()) + param_zip_ = (dict(zip(params, step)) for step in param_zip) + + exe_arg_zip: t.Iterable[tuple[t.Sequence[str], ...]] = zip(*exe_args.values()) + exe_arg_zip_ = (map(list, sequence) for sequence in exe_arg_zip) + exe_arg_zip__ = (dict(zip(exe_args, step)) for step in exe_arg_zip_) + + param_set: t.Iterable[ParamSet] = ( + ParamSet(file_param, exe_arg) + for file_param, exe_arg in zip(param_zip_, exe_arg_zip__) + ) + if n_permutations >= 0: + param_set = itertools.islice(param_set, n_permutations) + return list(param_set) + + +@_register("random") +def random_permutations( + params: t.Mapping[str, t.Sequence[str]], + exe_args: t.Mapping[str, t.Sequence[t.Sequence[str]]], + n_permutations: int = -1, +) -> list[ParamSet]: + """Take two mapping parameters to possible values and return a sequence of + length `n_permutations` sampled randomly from all possible permutations + :param file_params: A mapping of file parameter names to possible values + :param exe_arg_params: A mapping of exe arg parameter names to possible values + :param n_permutations: The maximum number of permutations to sample from + the sequence of all permutations + :return: A sequence of ParamSets of sampled permutations + """ + permutations = create_all_permutations(params, exe_args, -1) + if 0 <= n_permutations < len(permutations): + permutations = random.sample(permutations, n_permutations) + return permutations diff --git a/smartsim/database/__init__.py b/smartsim/database/__init__.py index 
106f8e1e24..0801c682bd 100644 --- a/smartsim/database/__init__.py +++ b/smartsim/database/__init__.py @@ -24,4 +24,4 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from .orchestrator import Orchestrator +from .orchestrator import FeatureStore diff --git a/smartsim/database/orchestrator.py b/smartsim/database/orchestrator.py index e5e99c8932..c29c781a17 100644 --- a/smartsim/database/orchestrator.py +++ b/smartsim/database/orchestrator.py @@ -35,25 +35,19 @@ from shlex import split as sh_split import psutil -from smartredis import Client, ConfigOptions -from smartredis.error import RedisReplyError -from .._core.config import CONFIG -from .._core.utils import db_is_active -from .._core.utils.helpers import is_valid_cmd, unpack_db_identifier +from smartsim.entity._mock import Mock + +from .._core.utils.helpers import is_valid_cmd, unpack_fs_identifier from .._core.utils.network import get_ip_from_host from .._core.utils.shell import execute_cmd -from ..entity import DBNode, EntityList, TelemetryConfiguration -from ..error import ( - SmartSimError, - SSConfigError, - SSDBFilesNotParseable, - SSUnsupportedError, -) +from ..entity import FSNode, TelemetryConfiguration +from ..error import SmartSimError, SSDBFilesNotParseable, SSUnsupportedError from ..log import get_logger from ..servertype import CLUSTERED, STANDALONE from ..settings import ( AprunSettings, + BatchSettings, BsubBatchSettings, JsrunSettings, MpiexecSettings, @@ -61,15 +55,29 @@ OrterunSettings, PalsMpiexecSettings, QsubBatchSettings, + RunSettings, SbatchSettings, SrunSettings, + create_batch_settings, + create_run_settings, ) -from ..settings.base import BatchSettings, RunSettings -from ..settings.settings import create_batch_settings, create_run_settings from ..wlm import detect_launcher logger = get_logger(__name__) + +class Client(Mock): + """Mock Client""" + + +class ConfigOptions(Mock): + 
"""Mock ConfigOptions""" + + +def fs_is_active(): + return False + + by_launcher: t.Dict[str, t.List[str]] = { "dragon": [""], "slurm": ["srun", "mpirun", "mpiexec"], @@ -129,7 +137,7 @@ def _get_single_command( if run_command == "srun" and getenv("SLURM_HET_SIZE") is not None: msg = ( - "srun can not launch an orchestrator with single_cmd=True in " + "srun can not launch an FeatureStore with single_cmd=True in " + "a hetereogeneous job. Automatically switching to single_cmd=False." ) logger.info(msg) @@ -140,7 +148,7 @@ def _get_single_command( if run_command == "aprun": msg = ( - "aprun can not launch an orchestrator with batch=True and " + "aprun can not launch an FeatureStore with batch=True and " + "single_cmd=True. Automatically switching to single_cmd=False." ) logger.info(msg) @@ -152,13 +160,13 @@ def _get_single_command( def _check_local_constraints(launcher: str, batch: bool) -> None: """Check that the local launcher is not launched with invalid batch config""" if launcher == "local" and batch: - msg = "Local orchestrator can not be launched with batch=True" + msg = "Local FeatureStore can not be launched with batch=True" raise SmartSimError(msg) # pylint: disable-next=too-many-public-methods -class Orchestrator(EntityList[DBNode]): - """The Orchestrator is an in-memory database that can be launched +class FeatureStore: + """The FeatureStore is an in-memory database that can be launched alongside entities in SmartSim. Data can be transferred between entities by using one of the Python, C, C++ or Fortran clients within an entity. 
@@ -171,7 +179,7 @@ def __init__( interface: t.Union[str, t.List[str]] = "lo", launcher: str = "local", run_command: str = "auto", - db_nodes: int = 1, + fs_nodes: int = 1, batch: bool = False, hosts: t.Optional[t.Union[t.List[str], str]] = None, account: t.Optional[str] = None, @@ -182,14 +190,14 @@ def __init__( threads_per_queue: t.Optional[int] = None, inter_op_threads: t.Optional[int] = None, intra_op_threads: t.Optional[int] = None, - db_identifier: str = "orchestrator", + fs_identifier: str = "featurestore", **kwargs: t.Any, ) -> None: - """Initialize an ``Orchestrator`` reference for local launch + """Initialize an ``FeatureStore`` reference for local launch - Extra configurations for RedisAI + Extra configurations - :param path: path to location of ``Orchestrator`` directory + :param path: path to location of ``FeatureStore`` directory :param port: TCP/IP port :param interface: network interface(s) :param launcher: type of launcher being used, options are "slurm", "pbs", @@ -197,18 +205,18 @@ def __init__( an attempt will be made to find an available launcher on the system. 
:param run_command: specify launch binary or detect automatically - :param db_nodes: number of database shards + :param fs_nodes: number of feature store shards :param batch: run as a batch workload :param hosts: specify hosts to launch on :param account: account to run batch on :param time: walltime for batch 'HH:MM:SS' format - :param alloc: allocation to launch database on + :param alloc: allocation to launch feature store on :param single_cmd: run all shards with one (MPMD) command :param threads_per_queue: threads per GPU device :param inter_op_threads: threads across CPU operations :param intra_op_threads: threads per CPU operation - :param db_identifier: an identifier to distinguish this orchestrator in - multiple-database experiments + :param fs_identifier: an identifier to distinguish this FeatureStore in + multiple-feature store experiments """ self.launcher, self.run_command = _autodetect(launcher, run_command) _check_run_command(self.launcher, self.run_command) @@ -234,11 +242,11 @@ def __init__( gpus_per_shard = int(kwargs.pop("gpus_per_shard", 0)) cpus_per_shard = int(kwargs.pop("cpus_per_shard", 4)) super().__init__( - name=db_identifier, + name=fs_identifier, path=str(path), port=port, interface=interface, - db_nodes=db_nodes, + fs_nodes=fs_nodes, batch=batch, launcher=self.launcher, run_command=self.run_command, @@ -252,26 +260,9 @@ def __init__( **kwargs, ) - # detect if we can find at least the redis binaries. We - # don't want to force the user to launch with RedisAI so - # it's ok if that isn't present. 
- try: - # try to obtain redis binaries needed to launch Redis - # will raise SSConfigError if not found - self._redis_exe # pylint: disable=W0104 - self._redis_conf # pylint: disable=W0104 - CONFIG.database_cli # pylint: disable=W0104 - except SSConfigError as e: - raise SSConfigError( - "SmartSim not installed with pre-built extensions (Redis)\n" - "Use the `smart` cli tool to install needed extensions\n" - "or set SMARTSIM_REDIS_SERVER_EXE and SMARTSIM_REDIS_CLI_EXE " - "in your environment\nSee documentation for more information" - ) from e - if self.launcher != "local": self.batch_settings = self._build_batch_settings( - db_nodes, + fs_nodes, alloc or "", batch, account or "", @@ -285,10 +276,8 @@ def __init__( mpilike = run_command in ["mpirun", "mpiexec", "orterun"] if mpilike and not self._mpi_has_sge_support(): raise SmartSimError( - ( - "hosts argument required when launching ", - "Orchestrator with mpirun", - ) + "hosts argument required when launching " + f"{type(self).__name__} with mpirun" ) self._reserved_run_args: t.Dict[t.Type[RunSettings], t.List[str]] = {} self._reserved_batch_args: t.Dict[t.Type[BatchSettings], t.List[str]] = {} @@ -311,45 +300,45 @@ def _mpi_has_sge_support(self) -> bool: return False @property - def db_identifier(self) -> str: - """Return the DB identifier, which is common to a DB and all of its nodes + def fs_identifier(self) -> str: + """Return the FS identifier, which is common to a FS and all of its nodes - :return: DB identifier + :return: FS identifier """ return self.name @property def num_shards(self) -> int: - """Return the number of DB shards contained in the Orchestrator. - This might differ from the number of ``DBNode`` objects, as each - ``DBNode`` may start more than one shard (e.g. with MPMD). + """Return the number of FS shards contained in the FeatureStore. + This might differ from the number of ``FSNode`` objects, as each + ``FSNode`` may start more than one shard (e.g. with MPMD). 
- :returns: the number of DB shards contained in the Orchestrator + :returns: the number of FS shards contained in the FeatureStore """ return sum(node.num_shards for node in self.entities) @property - def db_nodes(self) -> int: - """Read only property for the number of nodes an ``Orchestrator`` is + def fs_nodes(self) -> int: + """Read only property for the number of nodes an ``FeatureStore`` is launched across. Notice that SmartSim currently assumes that each shard will be launched on its own node. Therefore this property is currently an alias to the ``num_shards`` attribute. - :returns: Number of database nodes + :returns: Number of feature store nodes """ return self.num_shards @property def hosts(self) -> t.List[str]: - """Return the hostnames of Orchestrator instance hosts + """Return the hostnames of FeatureStore instance hosts - Note that this will only be populated after the orchestrator + Note that this will only be populated after the FeatureStore has been launched by SmartSim. 
- :return: the hostnames of Orchestrator instance hosts + :return: the hostnames of FeatureStore instance hosts """ if not self._hosts: - self._hosts = self._get_db_hosts() + self._hosts = self._get_fs_hosts() return self._hosts @property @@ -370,22 +359,22 @@ def reset_hosts(self) -> None: self.set_hosts(self._user_hostlist) def remove_stale_files(self) -> None: - """Can be used to remove database files of a previous launch""" + """Can be used to remove feature store files of a previous launch""" - for db in self.entities: - db.remove_stale_dbnode_files() + for fs in self.entities: + fs.remove_stale_fsnode_files() def get_address(self) -> t.List[str]: - """Return database addresses + """Return feature store addresses :return: addresses - :raises SmartSimError: If database address cannot be found or is not active + :raises SmartSimError: If feature store address cannot be found or is not active """ if not self._hosts: - raise SmartSimError("Could not find database address") + raise SmartSimError("Could not find feature store address") if not self.is_active(): - raise SmartSimError("Database is not active") + raise SmartSimError("Feature store is not active") return self._get_address() def _get_address(self) -> t.List[str]: @@ -395,50 +384,26 @@ def _get_address(self) -> t.List[str]: ] def is_active(self) -> bool: - """Check if the database is active + """Check if the feature store is active - :return: True if database is active, False otherwise + :return: True if feature store is active, False otherwise """ try: hosts = self.hosts except SSDBFilesNotParseable: return False - return db_is_active(hosts, self.ports, self.num_shards) - - @property - def _rai_module(self) -> t.Tuple[str, ...]: - """Get the RedisAI module from third-party installations - - :return: Tuple of args to pass to the orchestrator exe - to load and configure the RedisAI - """ - module = ["--loadmodule", CONFIG.redisai] - if self.queue_threads: - module.extend(("THREADS_PER_QUEUE", 
str(self.queue_threads))) - if self.inter_threads: - module.extend(("INTER_OP_PARALLELISM", str(self.inter_threads))) - if self.intra_threads: - module.extend(("INTRA_OP_PARALLELISM", str(self.intra_threads))) - return tuple(module) - - @property - def _redis_exe(self) -> str: - return CONFIG.database_exe - - @property - def _redis_conf(self) -> str: - return CONFIG.database_conf + return fs_is_active(hosts, self.ports, self.num_shards) @property def checkpoint_file(self) -> str: - """Get the path to the checkpoint file for this Orchestrator + """Get the path to the checkpoint file for this Feature Store :return: Path to the checkpoint file if it exists, otherwise a None """ return osp.join(self.path, "smartsim_db.dat") def set_cpus(self, num_cpus: int) -> None: - """Set the number of CPUs available to each database shard + """Set the number of CPUs available to each feature store shard This effectively will determine how many cpus can be used for compute threads, background threads, and network I/O. @@ -455,19 +420,19 @@ def set_cpus(self, num_cpus: int) -> None: if hasattr(self.batch_settings, "set_cpus_per_task"): self.batch_settings.set_cpus_per_task(num_cpus) - for db in self.entities: - db.run_settings.set_cpus_per_task(num_cpus) - if db.is_mpmd and hasattr(db.run_settings, "mpmd"): - for mpmd in db.run_settings.mpmd: + for fs in self.entities: + fs.run_settings.set_cpus_per_task(num_cpus) + if fs.is_mpmd and hasattr(fs.run_settings, "mpmd"): + for mpmd in fs.run_settings.mpmd: mpmd.set_cpus_per_task(num_cpus) def set_walltime(self, walltime: str) -> None: - """Set the batch walltime of the orchestrator + """Set the batch walltime of the FeatureStore - Note: This will only effect orchestrators launched as a batch + Note: This will only effect FeatureStores launched as a batch :param walltime: amount of time e.g. 
10 hours is 10:00:00 - :raises SmartSimError: if orchestrator isn't launching as batch + :raises SmartSimError: if FeatureStore isn't launching as batch """ if not self.batch: raise SmartSimError("Not running as batch, cannot set walltime") @@ -476,7 +441,7 @@ def set_walltime(self, walltime: str) -> None: self.batch_settings.set_walltime(walltime) def set_hosts(self, host_list: t.Union[t.List[str], str]) -> None: - """Specify the hosts for the ``Orchestrator`` to launch on + """Specify the hosts for the ``FeatureStore`` to launch on :param host_list: list of host (compute node names) :raises TypeError: if wrong type @@ -493,8 +458,8 @@ def set_hosts(self, host_list: t.Union[t.List[str], str]) -> None: self.batch_settings.set_hostlist(host_list) if self.launcher == "lsf": - for db in self.entities: - db.set_hosts(host_list) + for fs in self.entities: + fs.set_hosts(host_list) elif ( self.launcher == "pals" and isinstance(self.entities[0].run_settings, PalsMpiexecSettings) @@ -503,26 +468,26 @@ def set_hosts(self, host_list: t.Union[t.List[str], str]) -> None: # In this case, --hosts is a global option, set it to first run command self.entities[0].run_settings.set_hostlist(host_list) else: - for host, db in zip(host_list, self.entities): - if isinstance(db.run_settings, AprunSettings): + for host, fs in zip(host_list, self.entities): + if isinstance(fs.run_settings, AprunSettings): if not self.batch: - db.run_settings.set_hostlist([host]) + fs.run_settings.set_hostlist([host]) else: - db.run_settings.set_hostlist([host]) + fs.run_settings.set_hostlist([host]) - if db.is_mpmd and hasattr(db.run_settings, "mpmd"): - for i, mpmd_runsettings in enumerate(db.run_settings.mpmd, 1): + if fs.is_mpmd and hasattr(fs.run_settings, "mpmd"): + for i, mpmd_runsettings in enumerate(fs.run_settings.mpmd, 1): mpmd_runsettings.set_hostlist(host_list[i]) def set_batch_arg(self, arg: str, value: t.Optional[str] = None) -> None: - """Set a batch argument the orchestrator should launch 
with + """Set a batch argument the FeatureStore should launch with Some commonly used arguments such as --job-name are used by SmartSim and will not be allowed to be set. :param arg: batch argument to set e.g. "exclusive" :param value: batch param - set to None if no param value - :raises SmartSimError: if orchestrator not launching as batch + :raises SmartSimError: if FeatureStore not launching as batch """ if not hasattr(self, "batch_settings") or not self.batch_settings: raise SmartSimError("Not running as batch, cannot set batch_arg") @@ -530,13 +495,13 @@ def set_batch_arg(self, arg: str, value: t.Optional[str] = None) -> None: if arg in self._reserved_batch_args[type(self.batch_settings)]: logger.warning( f"Can not set batch argument {arg}: " - "it is a reserved keyword in Orchestrator" + "it is a reserved keyword in FeatureStore" ) else: self.batch_settings.batch_args[arg] = value def set_run_arg(self, arg: str, value: t.Optional[str] = None) -> None: - """Set a run argument the orchestrator should launch + """Set a run argument the FeatureStore should launch each node with (it will be passed to `jrun`) Some commonly used arguments are used @@ -549,24 +514,24 @@ def set_run_arg(self, arg: str, value: t.Optional[str] = None) -> None: if arg in self._reserved_run_args[type(self.entities[0].run_settings)]: logger.warning( f"Can not set batch argument {arg}: " - "it is a reserved keyword in Orchestrator" + "it is a reserved keyword in FeatureStore" ) else: - for db in self.entities: - db.run_settings.run_args[arg] = value - if db.is_mpmd and hasattr(db.run_settings, "mpmd"): - for mpmd in db.run_settings.mpmd: + for fs in self.entities: + fs.run_settings.run_args[arg] = value + if fs.is_mpmd and hasattr(fs.run_settings, "mpmd"): + for mpmd in fs.run_settings.mpmd: mpmd.run_args[arg] = value def enable_checkpoints(self, frequency: int) -> None: - """Sets the database's save configuration to save the DB every 'frequency' - seconds given that at least one write 
operation against the DB occurred in - that time. E.g., if `frequency` is 900, then the database will save to disk + """Sets the feature store's save configuration to save the fs every 'frequency' + seconds given that at least one write operation against the fs occurred in + that time. E.g., if `frequency` is 900, then the feature store will save to disk after 900 seconds if there is at least 1 change to the dataset. - :param frequency: the given number of seconds before the DB saves + :param frequency: the given number of seconds before the FS saves """ - self.set_db_conf("save", f"{frequency} 1") + self.set_fs_conf("save", f"{frequency} 1") def set_max_memory(self, mem: str) -> None: """Sets the max memory configuration. By default there is no memory limit. @@ -583,33 +548,33 @@ def set_max_memory(self, mem: str) -> None: :param mem: the desired max memory size e.g. 3gb :raises SmartSimError: If 'mem' is an invalid memory value - :raises SmartSimError: If database is not active + :raises SmartSimError: If feature store is not active """ - self.set_db_conf("maxmemory", mem) + self.set_fs_conf("maxmemory", mem) def set_eviction_strategy(self, strategy: str) -> None: - """Sets how the database will select what to remove when + """Sets how the feature store will select what to remove when 'maxmemory' is reached. The default is noeviction. :param strategy: The max memory policy to use e.g. "volatile-lru", "allkeys-lru", etc. :raises SmartSimError: If 'strategy' is an invalid maxmemory policy - :raises SmartSimError: If database is not active + :raises SmartSimError: If feature store is not active """ - self.set_db_conf("maxmemory-policy", strategy) + self.set_fs_conf("maxmemory-policy", strategy) def set_max_clients(self, clients: int = 50_000) -> None: """Sets the max number of connected clients at the same time. 
- When the number of DB shards contained in the orchestrator is + When the number of FS shards contained in the feature store is more than two, then every node will use two connections, one incoming and another outgoing. :param clients: the maximum number of connected clients """ - self.set_db_conf("maxclients", str(clients)) + self.set_fs_conf("maxclients", str(clients)) def set_max_message_size(self, size: int = 1_073_741_824) -> None: - """Sets the database's memory size limit for bulk requests, + """Sets the feature store's memory size limit for bulk requests, which are elements representing single strings. The default is 1 gigabyte. Message size must be greater than or equal to 1mb. The specified memory size should be an integer that represents @@ -618,16 +583,16 @@ def set_max_message_size(self, size: int = 1_073_741_824) -> None: :param size: maximum message size in bytes """ - self.set_db_conf("proto-max-bulk-len", str(size)) + self.set_fs_conf("proto-max-bulk-len", str(size)) - def set_db_conf(self, key: str, value: str) -> None: + def set_fs_conf(self, key: str, value: str) -> None: """Set any valid configuration at runtime without the need - to restart the database. All configuration parameters - that are set are immediately loaded by the database and + to restart the feature store. All configuration parameters + that are set are immediately loaded by the feature store and will take effect starting with the next command executed. 
:param key: the configuration parameter - :param value: the database configuration parameter's new value + :param value: the feature store configuration parameter's new value """ if self.is_active(): addresses = [] @@ -635,12 +600,12 @@ def set_db_conf(self, key: str, value: str) -> None: for port in self.ports: addresses.append(":".join([get_ip_from_host(host), str(port)])) - db_name, name = unpack_db_identifier(self.db_identifier, "_") + fs_name, name = unpack_fs_identifier(self.fs_identifier, "_") - environ[f"SSDB{db_name}"] = addresses[0] + environ[f"SSDB{fs_name}"] = addresses[0] - db_type = CLUSTERED if self.num_shards > 2 else STANDALONE - environ[f"SR_DB_TYPE{db_name}"] = db_type + fs_type = CLUSTERED if self.num_shards > 2 else STANDALONE + environ[f"SR_DB_TYPE{fs_name}"] = fs_type options = ConfigOptions.create_from_environment(name) client = Client(options) @@ -649,24 +614,20 @@ def set_db_conf(self, key: str, value: str) -> None: for address in addresses: client.config_set(key, value, address) - except RedisReplyError: - raise SmartSimError( - f"Invalid CONFIG key-value pair ({key}: {value})" - ) from None except TypeError: raise TypeError( "Incompatible function arguments. The key and value used for " - "setting the database configurations must be strings." + "setting the feature store configurations must be strings." ) from None else: raise SmartSimError( - "The SmartSim Orchestrator must be active in order to set the " - "database's configurations." + "The SmartSim FeatureStore must be active in order to set the " + "feature store's configurations." 
) @staticmethod def _build_batch_settings( - db_nodes: int, + fs_nodes: int, alloc: str, batch: bool, account: str, @@ -684,7 +645,7 @@ def _build_batch_settings( # on or if user specified batch=False (alloc will be found through env) if not alloc and batch: batch_settings = create_batch_settings( - launcher, nodes=db_nodes, time=time, account=account, **kwargs + launcher, nodes=fs_nodes, time=time, account=account, **kwargs ) return batch_settings @@ -695,12 +656,12 @@ def _build_run_settings( exe_args: t.List[t.List[str]], *, run_args: t.Optional[t.Dict[str, t.Any]] = None, - db_nodes: int = 1, + fs_nodes: int = 1, single_cmd: bool = True, **kwargs: t.Any, ) -> RunSettings: run_args = {} if run_args is None else run_args - mpmd_nodes = single_cmd and db_nodes > 1 + mpmd_nodes = single_cmd and fs_nodes > 1 if mpmd_nodes: run_settings = create_run_settings( @@ -750,7 +711,7 @@ def _build_run_settings_lsf( if gpus_per_shard is None: raise ValueError("Expected an integer number of gpus per shard") - # We always run the DB on cpus 0:cpus_per_shard-1 + # We always run the fs on cpus 0:cpus_per_shard-1 # and gpus 0:gpus_per_shard-1 for shard_id, args in enumerate(exe_args): host = shard_id @@ -759,8 +720,8 @@ def _build_run_settings_lsf( run_settings = JsrunSettings(exe, args, run_args=run_args.copy()) run_settings.set_binding("none") - # This makes sure output is written to orchestrator_0.out, - # orchestrator_1.out, and so on + # This makes sure output is written to featurestore_0.out, + # featurestore_1.out, and so on run_settings.set_individual_output("_%t") erf_sets = { @@ -787,91 +748,93 @@ def _build_run_settings_lsf( def _initialize_entities( self, *, - db_nodes: int = 1, + fs_nodes: int = 1, single_cmd: bool = True, port: int = 6379, **kwargs: t.Any, ) -> None: - db_nodes = int(db_nodes) - if db_nodes == 2: - raise SSUnsupportedError("Orchestrator does not support clusters of size 2") + fs_nodes = int(fs_nodes) + if fs_nodes == 2: + raise 
SSUnsupportedError("FeatureStore does not support clusters of size 2") - if self.launcher == "local" and db_nodes > 1: + if self.launcher == "local" and fs_nodes > 1: raise ValueError( - "Local Orchestrator does not support multiple database shards" + "Local FeatureStore does not support multiple feature store shards" ) - mpmd_nodes = (single_cmd and db_nodes > 1) or self.launcher == "lsf" + mpmd_nodes = (single_cmd and fs_nodes > 1) or self.launcher == "lsf" if mpmd_nodes: self._initialize_entities_mpmd( - db_nodes=db_nodes, single_cmd=single_cmd, port=port, **kwargs + fs_nodes=fs_nodes, single_cmd=single_cmd, port=port, **kwargs ) else: - cluster = db_nodes >= 3 + cluster = fs_nodes >= 3 - for db_id in range(db_nodes): - db_node_name = "_".join((self.name, str(db_id))) + for fs_id in range(fs_nodes): + fs_node_name = "_".join((self.name, str(fs_id))) - # create the exe_args list for launching multiple databases - # per node. also collect port range for dbnode + # create the exe_args list for launching multiple feature stores + # per node. 
also collect port range for fsnode start_script_args = self._get_start_script_args( - db_node_name, port, cluster + fs_node_name, port, cluster ) - # if only launching 1 db per command, we don't need a + # if only launching 1 fs per command, we don't need a # list of exe args lists run_settings = self._build_run_settings( sys.executable, [start_script_args], port=port, **kwargs ) - node = DBNode( - db_node_name, + node = FSNode( + fs_node_name, self.path, - run_settings, - [port], - [db_node_name + ".out"], - self.db_identifier, + exe=sys.executable, + exe_args=[start_script_args], + run_settings=run_settings, + ports=[port], + output_files=[fs_node_name + ".out"], + fs_identifier=self.fs_identifier, ) self.entities.append(node) self.ports = [port] def _initialize_entities_mpmd( - self, *, db_nodes: int = 1, port: int = 6379, **kwargs: t.Any + self, *, fs_nodes: int = 1, port: int = 6379, **kwargs: t.Any ) -> None: - cluster = db_nodes >= 3 + cluster = fs_nodes >= 3 mpmd_node_name = self.name + "_0" exe_args_mpmd: t.List[t.List[str]] = [] - for db_id in range(db_nodes): - db_shard_name = "_".join((self.name, str(db_id))) - # create the exe_args list for launching multiple databases - # per node. also collect port range for dbnode + for fs_id in range(fs_nodes): + fs_shard_name = "_".join((self.name, str(fs_id))) + # create the exe_args list for launching multiple feature stores + # per node. 
also collect port range for fsnode start_script_args = self._get_start_script_args( - db_shard_name, port, cluster + fs_shard_name, port, cluster ) exe_args = " ".join(start_script_args) exe_args_mpmd.append(sh_split(exe_args)) run_settings: t.Optional[RunSettings] = None if self.launcher == "lsf": run_settings = self._build_run_settings_lsf( - sys.executable, exe_args_mpmd, db_nodes=db_nodes, port=port, **kwargs + sys.executable, exe_args_mpmd, fs_nodes=fs_nodes, port=port, **kwargs ) - output_files = [f"{self.name}_{db_id}.out" for db_id in range(db_nodes)] + output_files = [f"{self.name}_{fs_id}.out" for fs_id in range(fs_nodes)] else: run_settings = self._build_run_settings( - sys.executable, exe_args_mpmd, db_nodes=db_nodes, port=port, **kwargs + sys.executable, exe_args_mpmd, fs_nodes=fs_nodes, port=port, **kwargs ) output_files = [mpmd_node_name + ".out"] if not run_settings: raise ValueError(f"Could not build run settings for {self.launcher}") - node = DBNode( + node = FSNode( mpmd_node_name, self.path, run_settings, [port], output_files, - db_identifier=self.db_identifier, + fs_identifier=self.fs_identifier, ) self.entities.append(node) self.ports = [port] @@ -881,13 +844,7 @@ def _get_start_script_args( ) -> t.List[str]: cmd = [ "-m", - "smartsim._core.entrypoints.redis", # entrypoint - f"+orc-exe={self._redis_exe}", # redis-server - f"+conf-file={self._redis_conf}", # redis.conf file - "+rai-module", # load redisai.so - *self._rai_module, f"+name={name}", # name of node - f"+port={port}", # redis port f"+ifname={','.join(self._interfaces)}", # pass interface to start script ] if cluster: @@ -895,13 +852,13 @@ def _get_start_script_args( return cmd - def _get_db_hosts(self) -> t.List[str]: + def _get_fs_hosts(self) -> t.List[str]: hosts = [] - for db in self.entities: - if not db.is_mpmd: - hosts.append(db.host) + for fs in self.entities: + if not fs.is_mpmd: + hosts.append(fs.host) else: - hosts.extend(db.hosts) + hosts.extend(fs.hosts) return hosts def 
_check_network_interface(self) -> None: diff --git a/smartsim/entity/__init__.py b/smartsim/entity/__init__.py index 40f03fcddc..4f4c256289 100644 --- a/smartsim/entity/__init__.py +++ b/smartsim/entity/__init__.py @@ -24,10 +24,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from .dbnode import DBNode +from .application import Application +from .dbnode import FSNode from .dbobject import * -from .ensemble import Ensemble from .entity import SmartSimEntity, TelemetryConfiguration -from .entityList import EntityList, EntitySequence -from .files import TaggedFilesHierarchy -from .model import Model diff --git a/smartsim/entity/strategies.py b/smartsim/entity/_mock.py similarity index 53% rename from smartsim/entity/strategies.py rename to smartsim/entity/_mock.py index 2af88b58e7..8f1043ed3c 100644 --- a/smartsim/entity/strategies.py +++ b/smartsim/entity/_mock.py @@ -24,41 +24,23 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Generation Strategies -import random -import typing as t -from itertools import product - +"""This module contains stubs of functionality that is not currently +implemented. -# create permutations of all parameters -# single model if parameters only have one value -def create_all_permutations( - param_names: t.List[str], param_values: t.List[t.List[str]], _n_models: int = 0 -) -> t.List[t.Dict[str, str]]: - perms = list(product(*param_values)) - all_permutations = [] - for permutation in perms: - temp_model = dict(zip(param_names, permutation)) - all_permutations.append(temp_model) - return all_permutations +THIS WHOLE MODULE SHOULD BE REMOVED IN FUTURE!! 
+""" +from __future__ import annotations -def step_values( - param_names: t.List[str], param_values: t.List[t.List[str]], _n_models: int = 0 -) -> t.List[t.Dict[str, str]]: - permutations = [] - for param_value in zip(*param_values): - permutations.append(dict(zip(param_names, param_value))) - return permutations +import typing as t -def random_permutations( - param_names: t.List[str], param_values: t.List[t.List[str]], n_models: int = 0 -) -> t.List[t.Dict[str, str]]: - permutations = create_all_permutations(param_names, param_values) +class Mock: + """Base mock class""" - # sample from available permutations if n_models is specified - if n_models and n_models < len(permutations): - permutations = random.sample(permutations, n_models) + def __init__(self, *_: t.Any, **__: t.Any): ... + def __getattr__(self, _: str) -> Mock: + return type(self)() - return permutations + def __deepcopy__(self, _: dict[t.Any, t.Any]) -> Mock: + return type(self)() diff --git a/smartsim/entity/application.py b/smartsim/entity/application.py new file mode 100644 index 0000000000..501279c85f --- /dev/null +++ b/smartsim/entity/application.py @@ -0,0 +1,263 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import collections +import copy +import textwrap +import typing as t +from os import path as osp + +from .._core.generation.operations.operations import FileSysOperationSet +from .._core.utils.helpers import expand_exe_path +from ..log import get_logger +from .entity import SmartSimEntity + +logger = get_logger(__name__) + + +# TODO: Remove this supression when we strip fileds/functionality +# (run-settings/batch_settings/params_as_args/etc)! +# pylint: disable-next=too-many-public-methods + + +class Application(SmartSimEntity): + """The Application class enables users to execute computational tasks in an + Experiment workflow, such as launching compiled applications, running scripts, + or performing general computational operations. + + Applications are designed to be added to Jobs, where LaunchSettings are also + provided to inject launcher-specific behavior into the Job. 
+ """ + + def __init__( + self, + name: str, + exe: str, + exe_args: t.Optional[t.Union[str, t.Sequence[str]]] = None, + file_parameters: ( + t.Mapping[str, str] | None + ) = None, # TODO remove when Ensemble is addressed + ) -> None: + """Initialize an ``Application`` + + Applications require a name and an executable. Optionally, users may provide + executable arguments, files and file parameters. To create a simple Application + that echos `Hello World!`, consider the example below: + + .. highlight:: python + .. code-block:: python + + # Create an application that runs the 'echo' command + my_app = Application(name="my_app", exe="echo", exe_args="Hello World!") + + :param name: name of the application + :param exe: executable to run + :param exe_args: executable arguments + """ + super().__init__(name) + """The name of the application""" + self._exe = expand_exe_path(exe) + """The executable to run""" + self._exe_args = self._build_exe_args(exe_args) or [] + """The executable arguments""" + self.files = FileSysOperationSet([]) + """Attach files""" + self._file_parameters = ( + copy.deepcopy(file_parameters) if file_parameters else {} + ) + """TODO MOCK until Ensemble is implemented""" + """Files to be copied, symlinked, and/or configured prior to execution""" + self._incoming_entities: t.List[SmartSimEntity] = [] + """Entities for which the prefix will have to be known by other entities""" + self._key_prefixing_enabled = False + """Unique prefix to avoid key collisions""" + + @property + def exe(self) -> str: + """Return the executable. + + :return: the executable + """ + return self._exe + + @exe.setter + def exe(self, value: str) -> None: + """Set the executable. 
+ + :param value: the executable + :raises TypeError: exe argument is not int + + """ + if not isinstance(value, str): + raise TypeError("exe argument was not of type str") + + if value == "": + raise ValueError("exe cannot be an empty str") + + self._exe = value + + @property + def exe_args(self) -> t.MutableSequence[str]: + """Return the executable arguments. + + :return: the executable arguments + """ + return self._exe_args + + @exe_args.setter + def exe_args(self, value: t.Union[str, t.Sequence[str], None]) -> None: + """Set the executable arguments. + + :param value: the executable arguments + """ + self._exe_args = self._build_exe_args(value) + + def add_exe_args(self, args: t.Union[str, t.List[str], None]) -> None: + """Add executable arguments to executable + + :param args: executable arguments + """ + args = self._build_exe_args(args) + self._exe_args.extend(args) + + @property + def file_parameters(self) -> t.Mapping[str, str]: + """Return file parameters. + + :return: the file parameters + """ + return self._file_parameters + + @file_parameters.setter + def file_parameters(self, value: t.Mapping[str, str]) -> None: + """Set the file parameters. + + :param value: the file parameters + :raises TypeError: file_parameters argument is not a mapping of str and str + """ + if not ( + isinstance(value, t.Mapping) + and all( + isinstance(key, str) and isinstance(val, str) + for key, val in value.items() + ) + ): + raise TypeError( + "file_parameters argument was not of type mapping of str and str" + ) + self._file_parameters = copy.deepcopy(value) + + @property + def incoming_entities(self) -> t.List[SmartSimEntity]: + """Return incoming entities. + + :return: incoming entities + """ + return self._incoming_entities + + @incoming_entities.setter + def incoming_entities(self, value: t.List[SmartSimEntity]) -> None: + """Set the incoming entities. 
+
+        :param value: incoming entities
+        :raises TypeError: incoming_entities argument is not a list of SmartSimEntity
+        """
+        if not isinstance(value, list) or not all(
+            isinstance(x, SmartSimEntity) for x in value
+        ):
+            raise TypeError(
+                "incoming_entities argument was not of type list of SmartSimEntity"
+            )
+
+        self._incoming_entities = copy.copy(value)
+
+    @property
+    def key_prefixing_enabled(self) -> bool:
+        """Return whether key prefixing is enabled for the application.
+
+        :return: True if key prefixing is enabled, False otherwise
+        """
+        return self._key_prefixing_enabled
+
+    @key_prefixing_enabled.setter
+    def key_prefixing_enabled(self, value: bool) -> None:
+        """Set whether key prefixing is enabled for the application.
+
+        :param value: key prefixing enabled
+        :raises TypeError: key_prefixing_enabled argument was not of type bool
+        """
+        if not isinstance(value, bool):
+            raise TypeError("key_prefixing_enabled argument was not of type bool")
+
+        # Assign to the backing field, not the property, to avoid
+        # infinite setter recursion
+        self._key_prefixing_enabled = value
+
+    def as_executable_sequence(self) -> t.Sequence[str]:
+        """Converts the executable and its arguments into a sequence of program arguments.
+ + :return: a sequence of strings representing the executable and its arguments + """ + return [self.exe, *self.exe_args] + + @staticmethod + def _build_exe_args(exe_args: t.Union[str, t.Sequence[str], None]) -> t.List[str]: + """Check and convert exe_args input to a desired collection format + + :param exe_args: + :raises TypeError: if exe_args is not a list of str or str + """ + if not exe_args: + return [] + + if not ( + isinstance(exe_args, str) + or ( + isinstance(exe_args, collections.abc.Sequence) + and all(isinstance(arg, str) for arg in exe_args) + ) + ): + raise TypeError("Executable arguments were not a list of str or a str.") + + if isinstance(exe_args, str): + return exe_args.split() + + return list(exe_args) + + def __str__(self) -> str: # pragma: no cover + exe_args_str = "\n".join(self.exe_args) + entities_str = "\n".join(str(entity) for entity in self.incoming_entities) + return textwrap.dedent(f"""\ + Name: {self.name} + Type: {self.type} + Executable: + {self.exe} + Executable Arguments: + {exe_args_str} + Incoming Entities: + {entities_str} + Key Prefixing Enabled: {self.key_prefixing_enabled} + """) diff --git a/smartsim/entity/dbnode.py b/smartsim/entity/dbnode.py index d371357f85..60a69b5222 100644 --- a/smartsim/entity/dbnode.py +++ b/smartsim/entity/dbnode.py @@ -34,20 +34,21 @@ from dataclasses import dataclass from .._core.config import CONFIG +from .._core.utils.helpers import expand_exe_path from ..error import SSDBFilesNotParseable from ..log import get_logger -from ..settings.base import RunSettings +from ..settings import RunSettings from .entity import SmartSimEntity logger = get_logger(__name__) -class DBNode(SmartSimEntity): - """DBNode objects are the entities that make up the orchestrator. - Each database node can be launched in a cluster configuration - and take launch multiple databases per node. +class FSNode(SmartSimEntity): + """FSNode objects are the entities that make up the feature store. 
+ Each feature store node can be launched in a cluster configuration + and take launch multiple feature stores per node. - To configure how each instance of the database operates, look + To configure how each instance of the feature store operates, look into the smartsimdb.conf. """ @@ -55,13 +56,18 @@ def __init__( self, name: str, path: str, + exe: str, + exe_args: t.List[str], run_settings: RunSettings, ports: t.List[int], output_files: t.List[str], - db_identifier: str = "", + fs_identifier: str = "", ) -> None: - """Initialize a database node within an orchestrator.""" - super().__init__(name, path, run_settings) + """Initialize a feature store node within an feature store.""" + super().__init__(name) + self.run_settings = run_settings + self.exe = [exe] if run_settings.container else [expand_exe_path(exe)] + self.exe_args = exe_args or [] self.ports = ports self._hosts: t.Optional[t.List[str]] = None @@ -72,7 +78,7 @@ def __init__( ): raise ValueError("output_files must be of type list[str]") self._output_files = output_files - self.db_identifier = db_identifier + self.fs_identifier = fs_identifier @property def num_shards(self) -> int: @@ -88,14 +94,14 @@ def host(self) -> str: (host,) = self.hosts except ValueError: raise ValueError( - f"Multiple hosts detected for this DB Node: {', '.join(self.hosts)}" + f"Multiple hosts detected for this FS Node: {', '.join(self.hosts)}" ) from None return host @property def hosts(self) -> t.List[str]: if not self._hosts: - self._hosts = self._parse_db_hosts() + self._hosts = self._parse_fs_hosts() return self._hosts def clear_hosts(self) -> None: @@ -112,9 +118,9 @@ def is_mpmd(self) -> bool: def set_hosts(self, hosts: t.List[str]) -> None: self._hosts = [str(host) for host in hosts] - def remove_stale_dbnode_files(self) -> None: + def remove_stale_fsnode_files(self) -> None: """This function removes the .conf, .err, and .out files that - have the same names used by this dbnode that may have been + have the same names 
used by this fsnode that may have been created from a previous experiment execution. """ @@ -146,7 +152,7 @@ def _get_cluster_conf_filenames(self, port: int) -> t.List[str]: # cov-lsf This function should bu used if and only if ``_mpmd==True`` :param port: port number - :return: the dbnode configuration file name + :return: the fsnode configuration file name """ if self.num_shards == 1: return [f"nodes-{self.name}-{port}.conf"] @@ -182,7 +188,7 @@ def _parse_launched_shard_info_from_files( return cls._parse_launched_shard_info_from_iterable(ifstream, num_shards) def get_launched_shard_info(self) -> "t.List[LaunchedShardData]": - """Parse the launched database shard info from the output files + """Parse the launched feature store shard info from the output files :raises SSDBFilesNotParseable: if all shard info could not be found :return: The found launched shard info @@ -206,16 +212,16 @@ def get_launched_shard_info(self) -> "t.List[LaunchedShardData]": if len(ips) < self.num_shards: msg = ( - f"Failed to parse the launched DB shard information from file(s) " + f"Failed to parse the launched FS shard information from file(s) " f"{', '.join(output_files)}. Found the information for " - f"{len(ips)} out of {self.num_shards} DB shards." + f"{len(ips)} out of {self.num_shards} FS shards." ) logger.error(msg) raise SSDBFilesNotParseable(msg) return ips - def _parse_db_hosts(self) -> t.List[str]: - """Parse the database hosts/IPs from the output files + def _parse_fs_hosts(self) -> t.List[str]: + """Parse the feature store hosts/IPs from the output files The IP address is preferred, but if hostname is only present then a lookup to /etc/hosts is done through the socket library. 
@@ -228,7 +234,7 @@ def _parse_db_hosts(self) -> t.List[str]: @dataclass(frozen=True) class LaunchedShardData: - """Data class to write and parse data about a launched database shard""" + """Data class to write and parse data about a launched feature store shard""" name: str hostname: str diff --git a/smartsim/entity/dbobject.py b/smartsim/entity/dbobject.py index fa9983c502..477564e83d 100644 --- a/smartsim/entity/dbobject.py +++ b/smartsim/entity/dbobject.py @@ -31,28 +31,28 @@ from ..error import SSUnsupportedError -__all__ = ["DBObject", "DBModel", "DBScript"] +__all__ = ["FSObject", "FSModel", "FSScript"] -_DBObjectFuncT = t.TypeVar("_DBObjectFuncT", str, bytes) +_FSObjectFuncT = t.TypeVar("_FSObjectFuncT", str, bytes) -class DBObject(t.Generic[_DBObjectFuncT]): - """Base class for ML objects residing on DB. Should not +class FSObject(t.Generic[_FSObjectFuncT]): + """Base class for ML objects residing on FS. Should not be instantiated. """ def __init__( self, name: str, - func: t.Optional[_DBObjectFuncT], + func: t.Optional[_FSObjectFuncT], file_path: t.Optional[str], device: str, devices_per_node: int, first_device: int, ) -> None: self.name = name - self.func: t.Optional[_DBObjectFuncT] = func + self.func: t.Optional[_FSObjectFuncT] = func self.file: t.Optional[Path] = ( None # Need to have this explicitly to check on it ) @@ -108,9 +108,9 @@ def _check_device(device: str) -> str: return device def _enumerate_devices(self) -> t.List[str]: - """Enumerate devices for a DBObject + """Enumerate devices for a FSObject - :param dbobject: DBObject to enumerate + :param FSObject: FSObject to enumerate :return: list of device names """ @@ -150,7 +150,7 @@ def _check_devices( raise ValueError(msg) -class DBScript(DBObject[str]): +class FSScript(FSObject[str]): def __init__( self, name: str, @@ -205,7 +205,7 @@ def __str__(self) -> str: return desc_str -class DBModel(DBObject[bytes]): +class FSModel(FSObject[bytes]): def __init__( self, name: str, @@ -222,7 +222,7 @@ 
def __init__( inputs: t.Optional[t.List[str]] = None, outputs: t.Optional[t.List[str]] = None, ) -> None: - """A TF, TF-lite, PT, or ONNX model to load into the DB at runtime + """A TF, TF-lite, PT, or ONNX model to load into the FS at runtime One of either model (in memory representation) or model_path (file) must be provided diff --git a/smartsim/entity/ensemble.py b/smartsim/entity/ensemble.py deleted file mode 100644 index 965b10db7f..0000000000 --- a/smartsim/entity/ensemble.py +++ /dev/null @@ -1,573 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -import os.path as osp -import typing as t -from copy import deepcopy -from os import getcwd - -from tabulate import tabulate - -from smartsim._core.types import Device - -from ..error import ( - EntityExistsError, - SmartSimError, - SSUnsupportedError, - UserStrategyError, -) -from ..log import get_logger -from ..settings.base import BatchSettings, RunSettings -from .dbobject import DBModel, DBScript -from .entity import SmartSimEntity -from .entityList import EntityList -from .model import Model -from .strategies import create_all_permutations, random_permutations, step_values - -logger = get_logger(__name__) - -StrategyFunction = t.Callable[ - [t.List[str], t.List[t.List[str]], int], t.List[t.Dict[str, str]] -] - - -class Ensemble(EntityList[Model]): - """``Ensemble`` is a group of ``Model`` instances that can - be treated as a reference to a single instance. - """ - - def __init__( - self, - name: str, - params: t.Dict[str, t.Any], - path: t.Optional[str] = getcwd(), - params_as_args: t.Optional[t.List[str]] = None, - batch_settings: t.Optional[BatchSettings] = None, - run_settings: t.Optional[RunSettings] = None, - perm_strat: str = "all_perm", - **kwargs: t.Any, - ) -> None: - """Initialize an Ensemble of Model instances. - - The kwargs argument can be used to pass custom input - parameters to the permutation strategy. 
- - :param name: name of the ensemble - :param params: parameters to expand into ``Model`` members - :param params_as_args: list of params that should be used as command - line arguments to the ``Model`` member executables and not written - to generator files - :param batch_settings: describes settings for ``Ensemble`` as batch workload - :param run_settings: describes how each ``Model`` should be executed - :param replicas: number of ``Model`` replicas to create - a keyword - argument of kwargs - :param perm_strategy: strategy for expanding ``params`` into - ``Model`` instances from params argument - options are "all_perm", "step", "random" - or a callable function. - :return: ``Ensemble`` instance - """ - self.params = params or {} - self.params_as_args = params_as_args or [] - self._key_prefixing_enabled = True - self.batch_settings = batch_settings - self.run_settings = run_settings - self.replicas: str - - super().__init__(name, str(path), perm_strat=perm_strat, **kwargs) - - @property - def models(self) -> t.Collection[Model]: - """An alias for a shallow copy of the ``entities`` attribute""" - return list(self.entities) - - def _initialize_entities(self, **kwargs: t.Any) -> None: - """Initialize all the models within the ensemble based - on the parameters passed to the ensemble and the permutation - strategy given at init. 
- - :raises UserStrategyError: if user generation strategy fails - """ - strategy = self._set_strategy(kwargs.pop("perm_strat")) - replicas = kwargs.pop("replicas", None) - self.replicas = replicas - - # if a ensemble has parameters and run settings, create - # the ensemble and assign run_settings to each member - if self.params: - if self.run_settings: - param_names, params = self._read_model_parameters() - - # Compute all combinations of model parameters and arguments - n_models = kwargs.get("n_models", 0) - all_model_params = strategy(param_names, params, n_models) - if not isinstance(all_model_params, list): - raise UserStrategyError(strategy) - - for i, param_set in enumerate(all_model_params): - if not isinstance(param_set, dict): - raise UserStrategyError(strategy) - run_settings = deepcopy(self.run_settings) - model_name = "_".join((self.name, str(i))) - model = Model( - name=model_name, - params=param_set, - path=osp.join(self.path, model_name), - run_settings=run_settings, - params_as_args=self.params_as_args, - ) - model.enable_key_prefixing() - model.params_to_args() - logger.debug( - f"Created ensemble member: {model_name} in {self.name}" - ) - self.add_model(model) - # cannot generate models without run settings - else: - raise SmartSimError( - "Ensembles without 'params' or 'replicas' argument to " - "expand into members cannot be given run settings" - ) - else: - if self.run_settings: - if replicas: - for i in range(replicas): - model_name = "_".join((self.name, str(i))) - model = Model( - name=model_name, - params={}, - path=osp.join(self.path, model_name), - run_settings=deepcopy(self.run_settings), - ) - model.enable_key_prefixing() - logger.debug( - f"Created ensemble member: {model_name} in {self.name}" - ) - self.add_model(model) - else: - raise SmartSimError( - "Ensembles without 'params' or 'replicas' argument to " - "expand into members cannot be given run settings" - ) - # if no params, no run settings and no batch settings, error because 
we - # don't know how to run the ensemble - elif not self.batch_settings: - raise SmartSimError( - "Ensemble must be provided batch settings or run settings" - ) - else: - logger.info("Empty ensemble created for batch launch") - - def add_model(self, model: Model) -> None: - """Add a model to this ensemble - - :param model: model instance to be added - :raises TypeError: if model is not an instance of ``Model`` - :raises EntityExistsError: if model already exists in this ensemble - """ - if not isinstance(model, Model): - raise TypeError( - f"Argument to add_model was of type {type(model)}, not Model" - ) - # "in" operator uses model name for __eq__ - if model in self.entities: - raise EntityExistsError( - f"Model {model.name} already exists in ensemble {self.name}" - ) - - if self._db_models: - self._extend_entity_db_models(model, self._db_models) - if self._db_scripts: - self._extend_entity_db_scripts(model, self._db_scripts) - - self.entities.append(model) - - def register_incoming_entity(self, incoming_entity: SmartSimEntity) -> None: - """Register future communication between entities. - - Registers the named data sources that this entity - has access to by storing the key_prefix associated - with that entity - - Only python clients can have multiple incoming connections - - :param incoming_entity: The entity that data will be received from - """ - for model in self.models: - model.register_incoming_entity(incoming_entity) - - def enable_key_prefixing(self) -> None: - """If called, each model within this ensemble will prefix its key with its - own model name. 
- """ - for model in self.models: - model.enable_key_prefixing() - - def query_key_prefixing(self) -> bool: - """Inquire as to whether each model within the ensemble will prefix their keys - - :returns: True if all models have key prefixing enabled, False otherwise - """ - return all(model.query_key_prefixing() for model in self.models) - - def attach_generator_files( - self, - to_copy: t.Optional[t.List[str]] = None, - to_symlink: t.Optional[t.List[str]] = None, - to_configure: t.Optional[t.List[str]] = None, - ) -> None: - """Attach files to each model within the ensemble for generation - - Attach files needed for the entity that, upon generation, - will be located in the path of the entity. - - During generation, files "to_copy" are copied into - the path of the entity, and files "to_symlink" are - symlinked into the path of the entity. - - Files "to_configure" are text based model input files where - parameters for the model are set. Note that only models - support the "to_configure" field. These files must have - fields tagged that correspond to the values the user - would like to change. The tag is settable but defaults - to a semicolon e.g. THERMO = ;10; - - :param to_copy: files to copy - :param to_symlink: files to symlink - :param to_configure: input files with tagged parameters - """ - for model in self.models: - model.attach_generator_files( - to_copy=to_copy, to_symlink=to_symlink, to_configure=to_configure - ) - - @property - def attached_files_table(self) -> str: - """Return a plain-text table with information about files - attached to models belonging to this ensemble. - - :returns: A table of all files attached to all models - """ - if not self.models: - return "The ensemble is empty, no files to show." 
- - table = tabulate( - [[model.name, model.attached_files_table] for model in self.models], - headers=["Model name", "Files"], - tablefmt="grid", - ) - - return table - - def print_attached_files(self) -> None: - """Print table of attached files to std out""" - print(self.attached_files_table) - - @staticmethod - def _set_strategy(strategy: str) -> StrategyFunction: - """Set the permutation strategy for generating models within - the ensemble - - :param strategy: name of the strategy or callable function - :raises SSUnsupportedError: if str name is not supported - :return: strategy function - """ - if strategy == "all_perm": - return create_all_permutations - if strategy == "step": - return step_values - if strategy == "random": - return random_permutations - if callable(strategy): - return strategy - raise SSUnsupportedError( - f"Permutation strategy given is not supported: {strategy}" - ) - - def _read_model_parameters(self) -> t.Tuple[t.List[str], t.List[t.List[str]]]: - """Take in the parameters given to the ensemble and prepare to - create models for the ensemble - - :raises TypeError: if params are of the wrong type - :return: param names and values for permutation strategy - """ - - if not isinstance(self.params, dict): - raise TypeError( - "Ensemble initialization argument 'params' must be of type dict" - ) - - param_names: t.List[str] = [] - parameters: t.List[t.List[str]] = [] - for name, val in self.params.items(): - param_names.append(name) - - if isinstance(val, list): - val = [str(v) for v in val] - parameters.append(val) - elif isinstance(val, (int, str)): - parameters.append([str(val)]) - else: - raise TypeError( - "Incorrect type for ensemble parameters\n" - + "Must be list, int, or string." 
- ) - return param_names, parameters - - def add_ml_model( - self, - name: str, - backend: str, - model: t.Optional[bytes] = None, - model_path: t.Optional[str] = None, - device: str = Device.CPU.value.upper(), - devices_per_node: int = 1, - first_device: int = 0, - batch_size: int = 0, - min_batch_size: int = 0, - min_batch_timeout: int = 0, - tag: str = "", - inputs: t.Optional[t.List[str]] = None, - outputs: t.Optional[t.List[str]] = None, - ) -> None: - """A TF, TF-lite, PT, or ONNX model to load into the DB at runtime - - Each ML Model added will be loaded into an - orchestrator (converged or not) prior to the execution - of every entity belonging to this ensemble - - One of either model (in memory representation) or model_path (file) - must be provided - - :param name: key to store model under - :param model: model in memory - :param model_path: serialized model - :param backend: name of the backend (TORCH, TF, TFLITE, ONNX) - :param device: name of device for execution - :param devices_per_node: number of GPUs per node in multiGPU nodes - :param first_device: first device in multi-GPU nodes to use for execution, - defaults to 0; ignored if devices_per_node is 1 - :param batch_size: batch size for execution - :param min_batch_size: minimum batch size for model execution - :param min_batch_timeout: time to wait for minimum batch size - :param tag: additional tag for model information - :param inputs: model inputs (TF only) - :param outputs: model outupts (TF only) - """ - db_model = DBModel( - name=name, - backend=backend, - model=model, - model_file=model_path, - device=device, - devices_per_node=devices_per_node, - first_device=first_device, - batch_size=batch_size, - min_batch_size=min_batch_size, - min_batch_timeout=min_batch_timeout, - tag=tag, - inputs=inputs, - outputs=outputs, - ) - dupe = next( - ( - db_model.name - for ensemble_ml_model in self._db_models - if ensemble_ml_model.name == db_model.name - ), - None, - ) - if dupe: - raise 
SSUnsupportedError( - f'An ML Model with name "{db_model.name}" already exists' - ) - self._db_models.append(db_model) - for entity in self.models: - self._extend_entity_db_models(entity, [db_model]) - - def add_script( - self, - name: str, - script: t.Optional[str] = None, - script_path: t.Optional[str] = None, - device: str = Device.CPU.value.upper(), - devices_per_node: int = 1, - first_device: int = 0, - ) -> None: - """TorchScript to launch with every entity belonging to this ensemble - - Each script added to the model will be loaded into an - orchestrator (converged or not) prior to the execution - of every entity belonging to this ensemble - - Device selection is either "GPU" or "CPU". If many devices are - present, a number can be passed for specification e.g. "GPU:1". - - Setting ``devices_per_node=N``, with N greater than one will result - in the model being stored in the first N devices of type ``device``. - - One of either script (in memory string representation) or script_path (file) - must be provided - - :param name: key to store script under - :param script: TorchScript code - :param script_path: path to TorchScript code - :param device: device for script execution - :param devices_per_node: number of devices on each host - :param first_device: first device to use on each host - """ - db_script = DBScript( - name=name, - script=script, - script_path=script_path, - device=device, - devices_per_node=devices_per_node, - first_device=first_device, - ) - dupe = next( - ( - db_script.name - for ensemble_script in self._db_scripts - if ensemble_script.name == db_script.name - ), - None, - ) - if dupe: - raise SSUnsupportedError( - f'A Script with name "{db_script.name}" already exists' - ) - self._db_scripts.append(db_script) - for entity in self.models: - self._extend_entity_db_scripts(entity, [db_script]) - - def add_function( - self, - name: str, - function: t.Optional[str] = None, - device: str = Device.CPU.value.upper(), - devices_per_node: int = 1, - 
first_device: int = 0, - ) -> None: - """TorchScript function to launch with every entity belonging to this ensemble - - Each script function to the model will be loaded into a - non-converged orchestrator prior to the execution - of every entity belonging to this ensemble. - - For converged orchestrators, the :meth:`add_script` method should be used. - - Device selection is either "GPU" or "CPU". If many devices are - present, a number can be passed for specification e.g. "GPU:1". - - Setting ``devices_per_node=N``, with N greater than one will result - in the script being stored in the first N devices of type ``device``; - alternatively, setting ``first_device=M`` will result in the script - being stored on nodes M through M + N - 1. - - :param name: key to store function under - :param function: TorchScript code - :param device: device for script execution - :param devices_per_node: number of devices on each host - :param first_device: first device to use on each host - """ - db_script = DBScript( - name=name, - script=function, - device=device, - devices_per_node=devices_per_node, - first_device=first_device, - ) - dupe = next( - ( - db_script.name - for ensemble_script in self._db_scripts - if ensemble_script.name == db_script.name - ), - None, - ) - if dupe: - raise SSUnsupportedError( - f'A Script with name "{db_script.name}" already exists' - ) - self._db_scripts.append(db_script) - for entity in self.models: - self._extend_entity_db_scripts(entity, [db_script]) - - @staticmethod - def _extend_entity_db_models(model: Model, db_models: t.List[DBModel]) -> None: - """ - Ensures that the Machine Learning model names being added to the Ensemble - are unique. - - This static method checks if the provided ML model names already exist in - the Ensemble. An SSUnsupportedError is raised if any duplicate names are - found. Otherwise, it appends the given list of DBModels to the Ensemble. - - :param model: SmartSim Model object. 
- :param db_models: List of DBModels to append to the Ensemble. - """ - for add_ml_model in db_models: - dupe = next( - ( - db_model.name - for db_model in model.db_models - if db_model.name == add_ml_model.name - ), - None, - ) - if dupe: - raise SSUnsupportedError( - f'An ML Model with name "{add_ml_model.name}" already exists' - ) - model.add_ml_model_object(add_ml_model) - - @staticmethod - def _extend_entity_db_scripts(model: Model, db_scripts: t.List[DBScript]) -> None: - """ - Ensures that the script/function names being added to the Ensemble are unique. - - This static method checks if the provided script/function names already exist - in the Ensemble. An SSUnsupportedError is raised if any duplicate names - are found. Otherwise, it appends the given list of DBScripts to the - Ensemble. - - :param model: SmartSim Model object. - :param db_scripts: List of DBScripts to append to the Ensemble. - """ - for add_script in db_scripts: - dupe = next( - ( - add_script.name - for db_script in model.db_scripts - if db_script.name == add_script.name - ), - None, - ) - if dupe: - raise SSUnsupportedError( - f'A Script with name "{add_script.name}" already exists' - ) - model.add_script_object(add_script) diff --git a/smartsim/entity/entity.py b/smartsim/entity/entity.py index 012a767449..3f5a9eabd0 100644 --- a/smartsim/entity/entity.py +++ b/smartsim/entity/entity.py @@ -24,11 +24,16 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations + +import abc import typing as t +from smartsim.launchable.job_group import JobGroup + if t.TYPE_CHECKING: - # pylint: disable-next=unused-import - import smartsim.settings.base + from smartsim.launchable.job import Job + from smartsim.settings.launch_settings import LaunchSettings class TelemetryConfiguration: @@ -89,34 +94,47 @@ def _on_disable(self) -> None: to perform actions when attempts to change configuration are made""" -class SmartSimEntity: +class SmartSimEntity(abc.ABC): def __init__( - self, name: str, path: str, run_settings: "smartsim.settings.base.RunSettings" + self, + name: str, ) -> None: """Initialize a SmartSim entity. - Each entity must have a name, path, and - run_settings. All entities within SmartSim + Each entity must have a name and path. All entities within SmartSim share these attributes. :param name: Name of the entity - :param path: path to output, error, and configuration files - :param run_settings: Launcher settings specified in the experiment - entity """ self.name = name - self.run_settings = run_settings - self.path = path + """The name of the application""" + + @abc.abstractmethod + def as_executable_sequence(self) -> t.Sequence[str]: + """Converts the executable and its arguments into a sequence of program arguments. + + :return: a sequence of strings representing the executable and its arguments + """ @property def type(self) -> str: """Return the name of the class""" return type(self).__name__ - def set_path(self, path: str) -> None: - if not isinstance(path, str): - raise TypeError("path argument must be a string") - self.path = path - def __repr__(self) -> str: return self.name + + +class CompoundEntity(abc.ABC): + """An interface to create different types of collections of launchables + from a single set of launch settings. 
+ + Objects that implement this interface describe how to turn their entities + into a collection of jobs and this interface will handle coercion into + other collections for jobs with slightly different launching behavior. + """ + + @abc.abstractmethod + def build_jobs(self, settings: LaunchSettings) -> t.Collection[Job]: ... + def as_job_group(self, settings: LaunchSettings) -> JobGroup: + return JobGroup(list(self.build_jobs(settings))) diff --git a/smartsim/entity/entityList.py b/smartsim/entity/entityList.py deleted file mode 100644 index edaa886687..0000000000 --- a/smartsim/entity/entityList.py +++ /dev/null @@ -1,144 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import typing as t - -from .entity import SmartSimEntity - -if t.TYPE_CHECKING: - # pylint: disable-next=unused-import - import smartsim - -_T = t.TypeVar("_T", bound=SmartSimEntity) -# Old style pyint from TF 2.6.x does not know about pep484 style ``TypeVar`` names -# pylint: disable-next=invalid-name -_T_co = t.TypeVar("_T_co", bound=SmartSimEntity, covariant=True) - - -class EntitySequence(t.Generic[_T_co]): - """Abstract class for containers for SmartSimEntities""" - - def __init__(self, name: str, path: str, **kwargs: t.Any) -> None: - self.name: str = name - self.path: str = path - - # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - # WARNING: This class cannot be made truly covariant until the - # following properties are made read-only. It is currently - # designed for in-house type checking only!! - # - # Despite the fact that these properties are type hinted as - # ``Sequence``s, the underlying types must remain ``list``s as that is - # what subclasses are expecting when implementing their - # ``_initialize_entities`` methods. - # - # I'm leaving it "as is" for now as to not introduce a potential API - # break in case any users subclassed the invariant version of this - # class (``EntityList``), but a "proper" solution would be to turn - # ``EntitySequence``/``EntityList`` into proper ``abc.ABC``s and have - # the properties we expect to be initialized represented as abstract - # properties. 
An additional benefit of this solution is would be that - # users could actually initialize their entities in the ``__init__`` - # method, and it would remove the need for the cumbersome and - # un-type-hint-able ``_initialize_entities`` method by returning all - # object construction into the class' constructor. - # --------------------------------------------------------------------- - # - self.entities: t.Sequence[_T_co] = [] - self._db_models: t.Sequence["smartsim.entity.DBModel"] = [] - self._db_scripts: t.Sequence["smartsim.entity.DBScript"] = [] - # - # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - - self._initialize_entities(**kwargs) - - def _initialize_entities(self, **kwargs: t.Any) -> None: - """Initialize the SmartSimEntity objects in the container""" - raise NotImplementedError - - @property - def db_models(self) -> t.Iterable["smartsim.entity.DBModel"]: - """Return an immutable collection of attached models""" - return (model for model in self._db_models) - - @property - def db_scripts(self) -> t.Iterable["smartsim.entity.DBScript"]: - """Return an immutable collection of attached scripts""" - return (script for script in self._db_scripts) - - @property - def batch(self) -> bool: - """Property indicating whether or not the entity sequence should be - launched as a batch job - - :return: ``True`` if entity sequence should be launched as a batch job, - ``False`` if the members will be launched individually. 
- """ - # pylint: disable-next=no-member - return hasattr(self, "batch_settings") and self.batch_settings - - @property - def type(self) -> str: - """Return the name of the class""" - return type(self).__name__ - - def set_path(self, new_path: str) -> None: - self.path = new_path - for entity in self.entities: - entity.path = new_path - - def __getitem__(self, name: str) -> t.Optional[_T_co]: - for entity in self.entities: - if entity.name == name: - return entity - return None - - def __iter__(self) -> t.Iterator[_T_co]: - for entity in self.entities: - yield entity - - def __len__(self) -> int: - return len(self.entities) - - -class EntityList(EntitySequence[_T]): - """An invariant subclass of an ``EntitySequence`` with mutable containers""" - - def __init__(self, name: str, path: str, **kwargs: t.Any) -> None: - super().__init__(name, path, **kwargs) - # Change container types to be invariant ``list``s - self.entities: t.List[_T] = list(self.entities) - self._db_models: t.List["smartsim.entity.DBModel"] = list(self._db_models) - self._db_scripts: t.List["smartsim.entity.DBScript"] = list(self._db_scripts) - - def _initialize_entities(self, **kwargs: t.Any) -> None: - """Initialize the SmartSimEntity objects in the container""" - # Need to identically re-define this "abstract method" or pylint - # complains that we are trying to define a concrete implementation of - # an abstract class despite the fact that we want this class to also be - # abstract. All the more reason to turn both of these classes into - # ``abc.ABC``s in my opinion. 
- raise NotImplementedError diff --git a/smartsim/entity/files.py b/smartsim/entity/files.py index d00e946e2a..42586f153e 100644 --- a/smartsim/entity/files.py +++ b/smartsim/entity/files.py @@ -23,25 +23,25 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import os import typing as t from os import path from tabulate import tabulate +# TODO remove when Ensemble is addressed class EntityFiles: """EntityFiles are the files a user wishes to have available to - models and nodes within SmartSim. Each entity has a method + applications and nodes within SmartSim. Each entity has a method `entity.attach_generator_files()` that creates one of these objects such that at generation time, each file type will be - present within the generated model or node directory. + present within the generated application or node directory. - Tagged files are the configuration files for a model that - can be searched through and edited by the ModelWriter. + Tagged files are the configuration files for a application that + can be searched through and edited by the ApplicationWriter. Copy files are files that a user wants to copy into the - model or node directory without searching through and + application or node directory without searching through and editing them for tags. 
Lastly, symlink can be used for big datasets or input @@ -57,16 +57,15 @@ def __init__( ) -> None: """Initialize an EntityFiles instance - :param tagged: tagged files for model configuration - :param copy: files or directories to copy into model + :param tagged: tagged files for application configuration + :param copy: files or directories to copy into application or node directories - :param symlink: files to symlink into model or node + :param symlink: files to symlink into application or node directories """ self.tagged = tagged or [] self.copy = copy or [] self.link = symlink or [] - self.tagged_hierarchy = None self._check_files() def _check_files(self) -> None: @@ -82,10 +81,6 @@ def _check_files(self) -> None: self.copy = self._type_check_files(self.copy, "Copyable") self.link = self._type_check_files(self.link, "Symlink") - self.tagged_hierarchy = TaggedFilesHierarchy.from_list_paths( - self.tagged, dir_contents_to_base=True - ) - for i, value in enumerate(self.copy): self.copy[i] = self._check_path(value) @@ -147,142 +142,3 @@ def __str__(self) -> str: return "No file attached to this entity." return tabulate(values, headers=["Strategy", "Files"], tablefmt="grid") - - -class TaggedFilesHierarchy: - """The TaggedFilesHierarchy represents a directory - containing potentially tagged files and subdirectories. 
- - TaggedFilesHierarchy.base is the directory path from - the the root of the generated file structure - - TaggedFilesHierarchy.files is a collection of paths to - files that need to be copied to directory that the - TaggedFilesHierarchy represents - - TaggedFilesHierarchy.dirs is a collection of child - TaggedFilesHierarchy, each representing a subdirectory - that needs to generated - - By performing a depth first search over the entire - hierarchy starting at the root directory structure, the - tagged file directory structure can be replicated - """ - - def __init__(self, parent: t.Optional[t.Any] = None, subdir_name: str = "") -> None: - """Initialize a TaggedFilesHierarchy - - :param parent: The parent hierarchy of the new hierarchy, - must be None if creating a root hierarchy, - must be provided if creating a subhierachy - :param subdir_name: Name of subdirectory representd by the new hierarchy, - must be "" if creating a root hierarchy, - must be any valid dir name if subhierarchy, - invalid names are ".", ".." or contain path seperators - :raises ValueError: if given a subdir_name without a parent, - if given a parent without a subdir_name, - or if the subdir_name is invalid - """ - if parent is None and subdir_name: - raise ValueError( - "TaggedFilesHierarchies should not have a subdirectory name without a" - + " parent" - ) - if parent is not None and not subdir_name: - raise ValueError( - "Child TaggedFilesHierarchies must have a subdirectory name" - ) - if subdir_name in {".", ".."} or path.sep in subdir_name: - raise ValueError( - "Child TaggedFilesHierarchies subdirectory names must not contain" - + " path seperators or be reserved dirs '.' 
or '..'" - ) - - if parent: - parent.dirs.add(self) - - self._base: str = path.join(parent.base, subdir_name) if parent else "" - self.parent: t.Any = parent - self.files: t.Set[str] = set() - self.dirs: t.Set[TaggedFilesHierarchy] = set() - - @property - def base(self) -> str: - """Property to ensure that self.base is read-only""" - return self._base - - @classmethod - def from_list_paths( - cls, path_list: t.List[str], dir_contents_to_base: bool = False - ) -> t.Any: - """Given a list of absolute paths to files and dirs, create and return - a TaggedFilesHierarchy instance representing the file hierarchy of - tagged files. All files in the path list will be placed in the base of - the file hierarchy. - - :param path_list: list of absolute paths to tagged files or dirs - containing tagged files - :param dir_contents_to_base: When a top level dir is encountered, if - this value is truthy, files in the dir are - put into the base hierarchy level. - Otherwise, a new sub level is created for - the dir - :return: A built tagged file hierarchy for the given files - """ - tagged_file_hierarchy = cls() - if dir_contents_to_base: - new_paths = [] - for tagged_path in path_list: - if os.path.isdir(tagged_path): - new_paths += [ - os.path.join(tagged_path, file) - for file in os.listdir(tagged_path) - ] - else: - new_paths.append(tagged_path) - path_list = new_paths - tagged_file_hierarchy._add_paths(path_list) - return tagged_file_hierarchy - - def _add_file(self, file: str) -> None: - """Add a file to the current level in the file hierarchy - - :param file: absoute path to a file to add to the hierarchy - """ - self.files.add(file) - - def _add_dir(self, dir_path: str) -> None: - """Add a dir contianing tagged files by creating a new sub level in the - tagged file hierarchy. 
All paths within the directroy are added to the - the new level sub level tagged file hierarchy - - :param dir: absoute path to a dir to add to the hierarchy - """ - tagged_file_hierarchy = TaggedFilesHierarchy(self, path.basename(dir_path)) - # pylint: disable-next=protected-access - tagged_file_hierarchy._add_paths( - [path.join(dir_path, file) for file in os.listdir(dir_path)] - ) - - def _add_paths(self, paths: t.List[str]) -> None: - """Takes a list of paths and iterates over it, determining if each - path is to a file or a dir and then appropriatly adding it to the - TaggedFilesHierarchy. - - :param paths: list of paths to files or dirs to add to the hierarchy - :raises ValueError: if link to dir is found - :raises FileNotFoundError: if path does not exist - """ - for candidate in paths: - candidate = os.path.abspath(candidate) - if os.path.isdir(candidate): - if os.path.islink(candidate): - raise ValueError( - "Tagged directories and thier subdirectories cannot be links" - + " to prevent circular directory structures" - ) - self._add_dir(candidate) - elif os.path.isfile(candidate): - self._add_file(candidate) - else: - raise FileNotFoundError(f"File or Directory {candidate} not found") diff --git a/smartsim/entity/model.py b/smartsim/entity/model.py deleted file mode 100644 index 3e8baad5cc..0000000000 --- a/smartsim/entity/model.py +++ /dev/null @@ -1,701 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. 
Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import annotations - -import itertools -import numbers -import re -import sys -import typing as t -import warnings -from os import getcwd -from os import path as osp - -from smartsim._core.types import Device - -from .._core.utils.helpers import cat_arg_and_value -from ..error import EntityExistsError, SSUnsupportedError -from ..log import get_logger -from ..settings.base import BatchSettings, RunSettings -from .dbobject import DBModel, DBScript -from .entity import SmartSimEntity -from .files import EntityFiles - -logger = get_logger(__name__) - - -def _parse_model_parameters(params_dict: t.Dict[str, t.Any]) -> t.Dict[str, str]: - """Convert the values in a params dict to strings - :raises TypeError: if params are of the wrong type - :return: param dictionary with values and keys cast as strings - """ - param_names: t.List[str] = [] - parameters: t.List[str] = [] - for name, val in params_dict.items(): - param_names.append(name) - if isinstance(val, (str, 
numbers.Number)): - parameters.append(str(val)) - else: - raise TypeError( - "Incorrect type for model parameters\n" - + "Must be numeric value or string." - ) - return dict(zip(param_names, parameters)) - - -class Model(SmartSimEntity): - def __init__( - self, - name: str, - params: t.Dict[str, str], - run_settings: RunSettings, - path: t.Optional[str] = getcwd(), - params_as_args: t.Optional[t.List[str]] = None, - batch_settings: t.Optional[BatchSettings] = None, - ): - """Initialize a ``Model`` - - :param name: name of the model - :param params: model parameters for writing into configuration files or - to be passed as command line arguments to executable. - :param path: path to output, error, and configuration files - :param run_settings: launcher settings specified in the experiment - :param params_as_args: list of parameters which have to be - interpreted as command line arguments to - be added to run_settings - :param batch_settings: Launcher settings for running the individual - model as a batch job - """ - super().__init__(name, str(path), run_settings) - self.params = _parse_model_parameters(params) - self.params_as_args = params_as_args - self.incoming_entities: t.List[SmartSimEntity] = [] - self._key_prefixing_enabled = False - self.batch_settings = batch_settings - self._db_models: t.List[DBModel] = [] - self._db_scripts: t.List[DBScript] = [] - self.files: t.Optional[EntityFiles] = None - - @property - def db_models(self) -> t.Iterable[DBModel]: - """Retrieve an immutable collection of attached models - - :return: Return an immutable collection of attached models - """ - return (model for model in self._db_models) - - @property - def db_scripts(self) -> t.Iterable[DBScript]: - """Retrieve an immutable collection attached of scripts - - :return: Return an immutable collection of attached scripts - """ - return (script for script in self._db_scripts) - - @property - def colocated(self) -> bool: - """Return True if this Model will run with a colocated 
Orchestrator - - :return: Return True of the Model will run with a colocated Orchestrator - """ - return bool(self.run_settings.colocated_db_settings) - - def register_incoming_entity(self, incoming_entity: SmartSimEntity) -> None: - """Register future communication between entities. - - Registers the named data sources that this entity - has access to by storing the key_prefix associated - with that entity - - :param incoming_entity: The entity that data will be received from - :raises SmartSimError: if incoming entity has already been registered - """ - if incoming_entity.name in [ - in_entity.name for in_entity in self.incoming_entities - ]: - raise EntityExistsError( - f"'{incoming_entity.name}' has already " - + "been registered as an incoming entity" - ) - - self.incoming_entities.append(incoming_entity) - - def enable_key_prefixing(self) -> None: - """If called, the entity will prefix its keys with its own model name""" - self._key_prefixing_enabled = True - - def disable_key_prefixing(self) -> None: - """If called, the entity will not prefix its keys with its own model name""" - self._key_prefixing_enabled = False - - def query_key_prefixing(self) -> bool: - """Inquire as to whether this entity will prefix its keys with its name - - :return: Return True if entity will prefix its keys with its name - """ - return self._key_prefixing_enabled - - def attach_generator_files( - self, - to_copy: t.Optional[t.List[str]] = None, - to_symlink: t.Optional[t.List[str]] = None, - to_configure: t.Optional[t.List[str]] = None, - ) -> None: - """Attach files to an entity for generation - - Attach files needed for the entity that, upon generation, - will be located in the path of the entity. Invoking this method - after files have already been attached will overwrite - the previous list of entity files. - - During generation, files "to_copy" are copied into - the path of the entity, and files "to_symlink" are - symlinked into the path of the entity. 
- - Files "to_configure" are text based model input files where - parameters for the model are set. Note that only models - support the "to_configure" field. These files must have - fields tagged that correspond to the values the user - would like to change. The tag is settable but defaults - to a semicolon e.g. THERMO = ;10; - - :param to_copy: files to copy - :param to_symlink: files to symlink - :param to_configure: input files with tagged parameters - """ - to_copy = to_copy or [] - to_symlink = to_symlink or [] - to_configure = to_configure or [] - - # Check that no file collides with the parameter file written - # by Generator. We check the basename, even though it is more - # restrictive than what we need (but it avoids relative path issues) - for strategy in [to_copy, to_symlink, to_configure]: - if strategy is not None and any( - osp.basename(filename) == "smartsim_params.txt" for filename in strategy - ): - raise ValueError( - "`smartsim_params.txt` is a file automatically " - + "generated by SmartSim and cannot be ovewritten." - ) - - self.files = EntityFiles(to_configure, to_copy, to_symlink) - - @property - def attached_files_table(self) -> str: - """Return a list of attached files as a plain text table - - :returns: String version of table - """ - if not self.files: - return "No file attached to this model." - return str(self.files) - - def print_attached_files(self) -> None: - """Print a table of the attached files on std out""" - print(self.attached_files_table) - - def colocate_db(self, *args: t.Any, **kwargs: t.Any) -> None: - """An alias for ``Model.colocate_db_tcp``""" - warnings.warn( - ( - "`colocate_db` has been deprecated and will be removed in a \n" - "future release. Please use `colocate_db_tcp` or `colocate_db_uds`." 
- ), - FutureWarning, - ) - self.colocate_db_tcp(*args, **kwargs) - - def colocate_db_uds( - self, - unix_socket: str = "/tmp/redis.socket", - socket_permissions: int = 755, - db_cpus: int = 1, - custom_pinning: t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]] = None, - debug: bool = False, - db_identifier: str = "", - **kwargs: t.Any, - ) -> None: - """Colocate an Orchestrator instance with this Model over UDS. - - This method will initialize settings which add an unsharded - database to this Model instance. Only this Model will be able to communicate - with this colocated database by using Unix Domain sockets. - - Extra parameters for the db can be passed through kwargs. This includes - many performance, caching and inference settings. - - .. highlight:: python - .. code-block:: python - - example_kwargs = { - "maxclients": 100000, - "threads_per_queue": 1, - "inter_op_threads": 1, - "intra_op_threads": 1, - "server_threads": 2 # keydb only - } - - Generally these don't need to be changed. - - :param unix_socket: path to where the socket file will be created - :param socket_permissions: permissions for the socketfile - :param db_cpus: number of cpus to use for orchestrator - :param custom_pinning: CPUs to pin the orchestrator to. Passing an empty - iterable disables pinning - :param debug: launch Model with extra debug information about the colocated db - :param kwargs: additional keyword arguments to pass to the orchestrator database - """ - - if not re.match(r"^[a-zA-Z0-9.:\,_\-/]*$", unix_socket): - raise ValueError( - f"Invalid name for unix socket: {unix_socket}. Must only " - "contain alphanumeric characters or . 
: _ - /" - ) - uds_options: t.Dict[str, t.Union[int, str]] = { - "unix_socket": unix_socket, - "socket_permissions": socket_permissions, - # This is hardcoded to 0 as recommended by redis for UDS - "port": 0, - } - - common_options = { - "cpus": db_cpus, - "custom_pinning": custom_pinning, - "debug": debug, - "db_identifier": db_identifier, - } - self._set_colocated_db_settings(uds_options, common_options, **kwargs) - - def colocate_db_tcp( - self, - port: int = 6379, - ifname: t.Union[str, list[str]] = "lo", - db_cpus: int = 1, - custom_pinning: t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]] = None, - debug: bool = False, - db_identifier: str = "", - **kwargs: t.Any, - ) -> None: - """Colocate an Orchestrator instance with this Model over TCP/IP. - - This method will initialize settings which add an unsharded - database to this Model instance. Only this Model will be able to communicate - with this colocated database by using the loopback TCP interface. - - Extra parameters for the db can be passed through kwargs. This includes - many performance, caching and inference settings. - - .. highlight:: python - .. code-block:: python - - ex. kwargs = { - maxclients: 100000, - threads_per_queue: 1, - inter_op_threads: 1, - intra_op_threads: 1, - server_threads: 2 # keydb only - } - - Generally these don't need to be changed. - - :param port: port to use for orchestrator database - :param ifname: interface to use for orchestrator - :param db_cpus: number of cpus to use for orchestrator - :param custom_pinning: CPUs to pin the orchestrator to. 
Passing an empty - iterable disables pinning - :param debug: launch Model with extra debug information about the colocated db - :param kwargs: additional keyword arguments to pass to the orchestrator database - """ - - tcp_options = {"port": port, "ifname": ifname} - common_options = { - "cpus": db_cpus, - "custom_pinning": custom_pinning, - "debug": debug, - "db_identifier": db_identifier, - } - self._set_colocated_db_settings(tcp_options, common_options, **kwargs) - - def _set_colocated_db_settings( - self, - connection_options: t.Mapping[str, t.Union[int, t.List[str], str]], - common_options: t.Dict[ - str, - t.Union[ - t.Union[t.Iterable[t.Union[int, t.Iterable[int]]], None], - bool, - int, - str, - None, - ], - ], - **kwargs: t.Union[int, None], - ) -> None: - """ - Ingest the connection-specific options (UDS/TCP) and set the final settings - for the colocated database - """ - - if hasattr(self.run_settings, "mpmd") and len(self.run_settings.mpmd) > 0: - raise SSUnsupportedError( - "Models colocated with databases cannot be run as a mpmd workload" - ) - - if hasattr(self.run_settings, "_prep_colocated_db"): - # pylint: disable-next=protected-access - self.run_settings._prep_colocated_db(common_options["cpus"]) - - if "limit_app_cpus" in kwargs: - raise SSUnsupportedError( - "Pinning app CPUs via limit_app_cpus is not supported. Modify " - "RunSettings using the correct binding option for your launcher." 
- ) - - # TODO list which db settings can be extras - custom_pinning_ = t.cast( - t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]], - common_options.get("custom_pinning"), - ) - cpus_ = t.cast(int, common_options.get("cpus")) - common_options["custom_pinning"] = self._create_pinning_string( - custom_pinning_, cpus_ - ) - - colo_db_config: t.Dict[ - str, - t.Union[ - bool, - int, - str, - None, - t.List[str], - t.Iterable[t.Union[int, t.Iterable[int]]], - t.List[DBModel], - t.List[DBScript], - t.Dict[str, t.Union[int, None]], - t.Dict[str, str], - ], - ] = {} - colo_db_config.update(connection_options) - colo_db_config.update(common_options) - - redis_ai_temp = { - "threads_per_queue": kwargs.get("threads_per_queue", None), - "inter_op_parallelism": kwargs.get("inter_op_parallelism", None), - "intra_op_parallelism": kwargs.get("intra_op_parallelism", None), - } - # redisai arguments for inference settings - colo_db_config["rai_args"] = redis_ai_temp - colo_db_config["extra_db_args"] = { - k: str(v) for k, v in kwargs.items() if k not in redis_ai_temp - } - - self._check_db_objects_colo() - colo_db_config["db_models"] = self._db_models - colo_db_config["db_scripts"] = self._db_scripts - - self.run_settings.colocated_db_settings = colo_db_config - - @staticmethod - def _create_pinning_string( - pin_ids: t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]], cpus: int - ) -> t.Optional[str]: - """Create a comma-separated string of CPU ids. By default, ``None`` - returns 0,1,...,cpus-1; an empty iterable will disable pinning - altogether, and an iterable constructs a comma separated string of - integers (e.g. 
``[0, 2, 5]`` -> ``"0,2,5"``) - """ - - def _stringify_id(_id: int) -> str: - """Return the cPU id as a string if an int, otherwise raise a ValueError""" - if isinstance(_id, int): - if _id < 0: - raise ValueError("CPU id must be a nonnegative number") - return str(_id) - - raise TypeError(f"Argument is of type '{type(_id)}' not 'int'") - - try: - pin_ids = tuple(pin_ids) if pin_ids is not None else None - except TypeError: - raise TypeError( - "Expected a cpu pinning specification of type iterable of ints or " - f"iterables of ints. Instead got type `{type(pin_ids)}`" - ) from None - - # Deal with MacOSX limitations first. The "None" (default) disables pinning - # and is equivalent to []. The only invalid option is a non-empty pinning - if sys.platform == "darwin": - if pin_ids: - warnings.warn( - "CPU pinning is not supported on MacOSX. Ignoring pinning " - "specification.", - RuntimeWarning, - ) - return None - - # Flatten the iterable into a list and check to make sure that the resulting - # elements are all ints - if pin_ids is None: - return ",".join(_stringify_id(i) for i in range(cpus)) - if not pin_ids: - return None - pin_ids = ((x,) if isinstance(x, int) else x for x in pin_ids) - to_fmt = itertools.chain.from_iterable(pin_ids) - return ",".join(sorted({_stringify_id(x) for x in to_fmt})) - - def params_to_args(self) -> None: - """Convert parameters to command line arguments and update run settings.""" - if self.params_as_args is not None: - for param in self.params_as_args: - if not param in self.params: - raise ValueError( - f"Tried to convert {param} to command line argument for Model " - f"{self.name}, but its value was not found in model params" - ) - if self.run_settings is None: - raise ValueError( - "Tried to configure command line parameter for Model " - f"{self.name}, but no RunSettings are set." 
- ) - self.run_settings.add_exe_args( - cat_arg_and_value(param, self.params[param]) - ) - - def add_ml_model( - self, - name: str, - backend: str, - model: t.Optional[bytes] = None, - model_path: t.Optional[str] = None, - device: str = Device.CPU.value.upper(), - devices_per_node: int = 1, - first_device: int = 0, - batch_size: int = 0, - min_batch_size: int = 0, - min_batch_timeout: int = 0, - tag: str = "", - inputs: t.Optional[t.List[str]] = None, - outputs: t.Optional[t.List[str]] = None, - ) -> None: - """A TF, TF-lite, PT, or ONNX model to load into the DB at runtime - - Each ML Model added will be loaded into an - orchestrator (converged or not) prior to the execution - of this Model instance - - One of either model (in memory representation) or model_path (file) - must be provided - - :param name: key to store model under - :param backend: name of the backend (TORCH, TF, TFLITE, ONNX) - :param model: A model in memory (only supported for non-colocated orchestrators) - :param model_path: serialized model - :param device: name of device for execution - :param devices_per_node: The number of GPU devices available on the host. - This parameter only applies to GPU devices and will be ignored if device - is specified as CPU. - :param first_device: The first GPU device to use on the host. - This parameter only applies to GPU devices and will be ignored if device - is specified as CPU. 
- :param batch_size: batch size for execution - :param min_batch_size: minimum batch size for model execution - :param min_batch_timeout: time to wait for minimum batch size - :param tag: additional tag for model information - :param inputs: model inputs (TF only) - :param outputs: model outupts (TF only) - """ - db_model = DBModel( - name=name, - backend=backend, - model=model, - model_file=model_path, - device=device, - devices_per_node=devices_per_node, - first_device=first_device, - batch_size=batch_size, - min_batch_size=min_batch_size, - min_batch_timeout=min_batch_timeout, - tag=tag, - inputs=inputs, - outputs=outputs, - ) - self.add_ml_model_object(db_model) - - def add_script( - self, - name: str, - script: t.Optional[str] = None, - script_path: t.Optional[str] = None, - device: str = Device.CPU.value.upper(), - devices_per_node: int = 1, - first_device: int = 0, - ) -> None: - """TorchScript to launch with this Model instance - - Each script added to the model will be loaded into an - orchestrator (converged or not) prior to the execution - of this Model instance - - Device selection is either "GPU" or "CPU". If many devices are - present, a number can be passed for specification e.g. "GPU:1". - - Setting ``devices_per_node=N``, with N greater than one will result - in the script being stored in the first N devices of type ``device``; - alternatively, setting ``first_device=M`` will result in the script - being stored on nodes M through M + N - 1. - - One of either script (in memory string representation) or script_path (file) - must be provided - - :param name: key to store script under - :param script: TorchScript code (only supported for non-colocated orchestrators) - :param script_path: path to TorchScript code - :param device: device for script execution - :param devices_per_node: The number of GPU devices available on the host. - This parameter only applies to GPU devices and will be ignored if device - is specified as CPU. 
- :param first_device: The first GPU device to use on the host. - This parameter only applies to GPU devices and will be ignored if device - is specified as CPU. - """ - db_script = DBScript( - name=name, - script=script, - script_path=script_path, - device=device, - devices_per_node=devices_per_node, - first_device=first_device, - ) - self.add_script_object(db_script) - - def add_function( - self, - name: str, - function: t.Optional[str] = None, - device: str = Device.CPU.value.upper(), - devices_per_node: int = 1, - first_device: int = 0, - ) -> None: - """TorchScript function to launch with this Model instance - - Each script function to the model will be loaded into a - non-converged orchestrator prior to the execution - of this Model instance. - - For converged orchestrators, the :meth:`add_script` method should be used. - - Device selection is either "GPU" or "CPU". If many devices are - present, a number can be passed for specification e.g. "GPU:1". - - Setting ``devices_per_node=N``, with N greater than one will result - in the model being stored in the first N devices of type ``device``. - - :param name: key to store function under - :param function: TorchScript function code - :param device: device for script execution - :param devices_per_node: The number of GPU devices available on the host. - This parameter only applies to GPU devices and will be ignored if device - is specified as CPU. - :param first_device: The first GPU device to use on the host. - This parameter only applies to GPU devices and will be ignored if device - is specified as CPU. 
- """ - db_script = DBScript( - name=name, - script=function, - device=device, - devices_per_node=devices_per_node, - first_device=first_device, - ) - self.add_script_object(db_script) - - def __hash__(self) -> int: - return hash(self.name) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Model): - return False - - if self.name == other.name: - return True - return False - - def __str__(self) -> str: # pragma: no cover - entity_str = "Name: " + self.name + "\n" - entity_str += "Type: " + self.type + "\n" - entity_str += str(self.run_settings) + "\n" - if self._db_models: - entity_str += "DB Models: \n" + str(len(self._db_models)) + "\n" - if self._db_scripts: - entity_str += "DB Scripts: \n" + str(len(self._db_scripts)) + "\n" - return entity_str - - def add_ml_model_object(self, db_model: DBModel) -> None: - if not db_model.is_file and self.colocated: - err_msg = "ML model can not be set from memory for colocated databases.\n" - err_msg += ( - f"Please store the ML model named {db_model.name} in binary format " - ) - err_msg += "and add it to the SmartSim Model as file." - raise SSUnsupportedError(err_msg) - - self._db_models.append(db_model) - - def add_script_object(self, db_script: DBScript) -> None: - if db_script.func and self.colocated: - if not isinstance(db_script.func, str): - err_msg = ( - "Functions can not be set from memory for colocated databases.\n" - f"Please convert the function named {db_script.name} " - "to a string or store it as a text file and add it to the " - "SmartSim Model with add_script." - ) - raise SSUnsupportedError(err_msg) - self._db_scripts.append(db_script) - - def _check_db_objects_colo(self) -> None: - for db_model in self._db_models: - if not db_model.is_file: - err_msg = ( - "ML model can not be set from memory for colocated databases.\n" - f"Please store the ML model named {db_model.name} in binary " - "format and add it to the SmartSim Model as file." 
- ) - raise SSUnsupportedError(err_msg) - - for db_script in self._db_scripts: - if db_script.func: - if not isinstance(db_script.func, str): - err_msg = ( - "Functions can not be set from memory for colocated " - "databases.\nPlease convert the function named " - f"{db_script.name} to a string or store it as a text" - "file and add it to the SmartSim Model with add_script." - ) - raise SSUnsupportedError(err_msg) diff --git a/smartsim/error/errors.py b/smartsim/error/errors.py index 0cb38d7e6b..54536281e9 100644 --- a/smartsim/error/errors.py +++ b/smartsim/error/errors.py @@ -44,7 +44,7 @@ class EntityExistsError(SmartSimError): class UserStrategyError(SmartSimError): - """Raised when there is an error with model creation inside an ensemble + """Raised when there is an error with application creation inside an ensemble that is from a user provided permutation strategy """ @@ -60,7 +60,7 @@ def create_message(perm_strat: str) -> str: class ParameterWriterError(SmartSimError): - """Raised in the event that input parameter files for a model + """Raised in the event that input parameter files for a application could not be written. """ @@ -82,13 +82,13 @@ class SSReservedKeywordError(SmartSimError): class SSDBIDConflictError(SmartSimError): - """Raised in the event that a database identifier - is not unique when multiple databases are created + """Raised in the event that a feature store identifier + is not unique when multiple feature stores are created """ class SSDBFilesNotParseable(SmartSimError): - """Raised when the files related to the database cannot be parsed. + """Raised when the files related to the feature store cannot be parsed. Includes the case when the files do not exist. 
""" @@ -112,6 +112,14 @@ class LauncherUnsupportedFeature(LauncherError): """Raised when the launcher does not support a given method""" +class LauncherNotFoundError(LauncherError): + """A requested launcher could not be found""" + + +class LauncherJobNotFound(LauncherError): + """Launcher was asked to get information about a job it did not start""" + + class AllocationError(LauncherError): """Raised when there is a problem with the user WLM allocation""" diff --git a/smartsim/experiment.py b/smartsim/experiment.py index 9a14eecdc8..4db503819a 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -24,32 +24,36 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# pylint: disable=too-many-lines +from __future__ import annotations -import os +import datetime +import itertools import os.path as osp +import pathlib import typing as t from os import environ, getcwd from tabulate import tabulate +from smartsim._core import dispatch from smartsim._core.config import CONFIG -from smartsim.error.errors import SSUnsupportedError -from smartsim.status import SmartSimStatus - -from ._core import Controller, Generator, Manifest, previewrenderer -from .database import Orchestrator -from .entity import ( - Ensemble, - EntitySequence, - Model, - SmartSimEntity, - TelemetryConfiguration, -) +from smartsim._core.control import interval as _interval +from smartsim._core.control import preview_renderer +from smartsim._core.control.launch_history import LaunchHistory as _LaunchHistory +from smartsim._core.utils import helpers as _helpers +from smartsim.error import errors +from smartsim.launchable.job import Job +from smartsim.status import TERMINAL_STATUSES, InvalidJobStatus, JobStatus + +from ._core import Generator, Manifest +from ._core.generation.generator import Job_Path +from .entity import TelemetryConfiguration from .error import SmartSimError from .log import 
ctx_exp_path, get_logger, method_contextualizer -from .settings import Container, base, settings -from .wlm import detect_launcher + +if t.TYPE_CHECKING: + from smartsim.launchable.job import Job + from smartsim.types import LaunchedJobID logger = get_logger(__name__) @@ -82,73 +86,55 @@ def _on_disable(self) -> None: # pylint: disable=no-self-use class Experiment: - """Experiment is a factory class that creates stages of a workflow - and manages their execution. - - The instances created by an Experiment represent executable code - that is either user-specified, like the ``Model`` instance created - by ``Experiment.create_model``, or pre-configured, like the ``Orchestrator`` - instance created by ``Experiment.create_database``. - - Experiment methods that accept a variable list of arguments, such as - ``Experiment.start`` or ``Experiment.stop``, accept any number of the - instances created by the Experiment. - - In general, the Experiment class is designed to be initialized once - and utilized throughout runtime. + """The Experiment class is used to schedule, launch, track, and manage + jobs and job groups. Also, it is the SmartSim class that manages + internal data structures, processes, and infrastructure for interactive + capabilities such as the SmartSim dashboard and historical lookback on + launched jobs and job groups. The Experiment class is designed to be + initialized once and utilized throughout the entirety of a workflow. """ - def __init__( - self, - name: str, - exp_path: t.Optional[str] = None, - launcher: str = "local", - ): + def __init__(self, name: str, exp_path: str | None = None): """Initialize an Experiment instance. - With the default settings, the Experiment will use the - local launcher, which will start all Experiment created - instances on the localhost. - Example of initializing an Experiment with the local launcher + Example of initializing an Experiment - .. highlight:: python - .. 
code-block:: python - - exp = Experiment(name="my_exp", launcher="local") - - SmartSim supports multiple launchers which also can be specified - based on the type of system you are running on. .. highlight:: python .. code-block:: python - exp = Experiment(name="my_exp", launcher="slurm") + exp = Experiment(name="my_exp") - If you want your Experiment driver script to be run across - multiple system with different schedulers (workload managers) - you can also use the `auto` argument to have the Experiment detect - which launcher to use based on system installed binaries and libraries. + The name of a SmartSim ``Experiment`` will determine the + name of the ``Experiment`` directory that is created inside of the + current working directory. + + If a different ``Experiment`` path is desired, the ``exp_path`` + parameter can be set as shown in the example below. .. highlight:: python .. code-block:: python - exp = Experiment(name="my_exp", launcher="auto") + exp = Experiment(name="my_exp", exp_path="/full/path/to/exp") - The Experiment path will default to the current working directory - and if the ``Experiment.generate`` method is called, a directory - with the Experiment name will be created to house the output - from the Experiment. + Note that the provided path must exist prior to ``Experiment`` + construction and that an experiment name subdirectory will not be + created inside of the provide path. :param name: name for the ``Experiment`` :param exp_path: path to location of ``Experiment`` directory - :param launcher: type of launcher being used, options are "slurm", "pbs", - "lsf", "sge", or "local". If set to "auto", - an attempt will be made to find an available launcher - on the system. 
""" + + if name: + if not isinstance(name, str): + raise TypeError("name argument was not of type str") + else: + raise TypeError("Experiment name must be non-empty string") + self.name = name + if exp_path: if not isinstance(exp_path, str): raise TypeError("exp_path argument was not of type str") @@ -159,21 +145,18 @@ def __init__( exp_path = osp.join(getcwd(), name) self.exp_path = exp_path + """The path under which the experiment operate""" - self._launcher = launcher.lower() - - if self._launcher == "auto": - self._launcher = detect_launcher() - if self._launcher == "cobalt": - raise SSUnsupportedError("Cobalt launcher is no longer supported.") - - if launcher == "dragon": - self._set_dragon_server_path() - - self._control = Controller(launcher=self._launcher) - - self.db_identifiers: t.Set[str] = set() + self._launch_history = _LaunchHistory() + """A cache of launchers used and which ids they have issued""" + self._fs_identifiers: t.Set[str] = set() + """Set of feature store identifiers currently in use by this + experiment + """ self._telemetry_cfg = ExperimentTelemetryConfiguration() + """Switch to specify if telemetry data should be produced for this + experiment + """ def _set_dragon_server_path(self) -> None: """Set path for dragon server through environment varialbes""" @@ -182,638 +165,220 @@ def _set_dragon_server_path(self) -> None: self.exp_path, CONFIG.dragon_default_subdir ) - @_contextualize - def start( - self, - *args: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]], - block: bool = True, - summary: bool = False, - kill_on_interrupt: bool = True, - ) -> None: - """Start passed instances using Experiment launcher - - Any instance ``Model``, ``Ensemble`` or ``Orchestrator`` - instance created by the Experiment can be passed as - an argument to the start method. - - .. highlight:: python - .. 
code-block:: python - - exp = Experiment(name="my_exp", launcher="slurm") - settings = exp.create_run_settings(exe="./path/to/binary") - model = exp.create_model("my_model", settings) - exp.start(model) + def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[LaunchedJobID, ...]: + """Execute a collection of `Job` instances. - Multiple entity instances can also be passed to the start method - at once no matter which type of instance they are. These will - all be launched together. - - .. highlight:: python - .. code-block:: python - - exp.start(model_1, model_2, db, ensemble, block=True) - # alternatively - stage_1 = [model_1, model_2, db, ensemble] - exp.start(*stage_1, block=True) - - - If `block==True` the Experiment will poll the launched instances - at runtime until all non-database jobs have completed. Database - jobs *must* be killed by the user by passing them to - ``Experiment.stop``. This allows for multiple stages of a workflow - to produce to and consume from the same Orchestrator database. - - If `kill_on_interrupt=True`, then all jobs launched by this - experiment are guaranteed to be killed when ^C (SIGINT) signal is - received. If `kill_on_interrupt=False`, then it is not guaranteed - that all jobs launched by this experiment will be killed, and the - zombie processes will need to be manually killed. - - :param block: block execution until all non-database - jobs are finished - :param summary: print a launch summary prior to launch - :param kill_on_interrupt: flag for killing jobs when ^C (SIGINT) - signal is received. + :param jobs: A collection of other job instances to start + :raises TypeError: If jobs provided are not the correct type + :raises ValueError: No Jobs were provided. + :returns: A sequence of ids with order corresponding to the sequence of + jobs that can be used to query or alter the status of that + particular execution of the job. 
""" - start_manifest = Manifest(*args) - self._create_entity_dir(start_manifest) - try: - if summary: - self._launch_summary(start_manifest) - self._control.start( - exp_name=self.name, - exp_path=self.exp_path, - manifest=start_manifest, - block=block, - kill_on_interrupt=kill_on_interrupt, - ) - except SmartSimError as e: - logger.error(e) - raise - - @_contextualize - def stop( - self, *args: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> None: - """Stop specific instances launched by this ``Experiment`` - - Instances of ``Model``, ``Ensemble`` and ``Orchestrator`` - can all be passed as arguments to the stop method. - - Whichever launcher was specified at Experiment initialization - will be used to stop the instance. For example, which using - the slurm launcher, this equates to running `scancel` on the - instance. - - Example - - .. highlight:: python - .. code-block:: python - - exp.stop(model) - # multiple - exp.stop(model_1, model_2, db, ensemble) - - :param args: One or more SmartSimEntity or EntitySequence objects. - :raises TypeError: if wrong type - :raises SmartSimError: if stop request fails - """ - stop_manifest = Manifest(*args) - try: - for entity in stop_manifest.models: - self._control.stop_entity(entity) - for entity_list in stop_manifest.ensembles: - self._control.stop_entity_list(entity_list) - dbs = stop_manifest.dbs - for db in dbs: - self._control.stop_db(db) - except SmartSimError as e: - logger.error(e) - raise - - @_contextualize - def generate( - self, - *args: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]], - tag: t.Optional[str] = None, - overwrite: bool = False, - verbose: bool = False, - ) -> None: - """Generate the file structure for an ``Experiment`` - - ``Experiment.generate`` creates directories for each entity - passed to organize Experiments that launch many entities. 
- - If files or directories are attached to ``Model`` objects - using ``Model.attach_generator_files()``, those files or - directories will be symlinked, copied, or configured and - written into the created directory for that instance. - - Instances of ``Model``, ``Ensemble`` and ``Orchestrator`` - can all be passed as arguments to the generate method. - - :param tag: tag used in `to_configure` generator files - :param overwrite: overwrite existing folders and contents - :param verbose: log parameter settings to std out - """ - try: - generator = Generator(self.exp_path, overwrite=overwrite, verbose=verbose) - if tag: - generator.set_tag(tag) - generator.generate_experiment(*args) - except SmartSimError as e: - logger.error(e) - raise - - @_contextualize - def poll( - self, interval: int = 10, verbose: bool = True, kill_on_interrupt: bool = True - ) -> None: - """Monitor jobs through logging to stdout. - - This method should only be used if jobs were launched - with ``Experiment.start(block=False)`` - - The internal specified will control how often the - logging is performed, not how often the polling occurs. - By default, internal polling is set to every second for - local launcher jobs and every 10 seconds for all other - launchers. - - If internal polling needs to be slower or faster based on - system or site standards, set the ``SMARTSIM_JM_INTERNAL`` - environment variable to control the internal polling interval - for SmartSim. - - For more verbose logging output, the ``SMARTSIM_LOG_LEVEL`` - environment variable can be set to `debug` - - If `kill_on_interrupt=True`, then all jobs launched by this - experiment are guaranteed to be killed when ^C (SIGINT) signal is - received. If `kill_on_interrupt=False`, then it is not guaranteed - that all jobs launched by this experiment will be killed, and the - zombie processes will need to be manually killed. 
- - :param interval: frequency (in seconds) of logging to stdout - :param verbose: set verbosity - :param kill_on_interrupt: flag for killing jobs when SIGINT is received - :raises SmartSimError: if poll request fails - """ - try: - self._control.poll(interval, verbose, kill_on_interrupt=kill_on_interrupt) - except SmartSimError as e: - logger.error(e) - raise - - @_contextualize - def finished(self, entity: SmartSimEntity) -> bool: - """Query if a job has completed. - - An instance of ``Model`` or ``Ensemble`` can be passed - as an argument. - - Passing ``Orchestrator`` will return an error as a - database deployment is never finished until stopped - by the user. - - :param entity: object launched by this ``Experiment`` - :returns: True if the job has finished, False otherwise - :raises SmartSimError: if entity has not been launched - by this ``Experiment`` - """ - try: - return self._control.finished(entity) - except SmartSimError as e: - logger.error(e) - raise - - @_contextualize - def get_status( - self, *args: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> t.List[SmartSimStatus]: - """Query the status of launched entity instances - Return a smartsim.status string representing - the status of the launched instance. + if not jobs: + raise ValueError("No jobs provided to start") - .. highlight:: python - .. code-block:: python - - exp.get_status(model) + # Create the run id + jobs_ = list(_helpers.unpack(jobs)) - As with an Experiment method, multiple instance of - varying types can be passed to and all statuses will - be returned at once. - - .. highlight:: python - .. 
code-block:: python - - statuses = exp.get_status(model, ensemble, orchestrator) - complete = [s == smartsim.status.STATUS_COMPLETED for s in statuses] - assert all(complete) - - :returns: status of the instances passed as arguments - :raises SmartSimError: if status retrieval fails - """ - try: - manifest = Manifest(*args) - statuses: t.List[SmartSimStatus] = [] - for entity in manifest.models: - statuses.append(self._control.get_entity_status(entity)) - for entity_list in manifest.all_entity_lists: - statuses.extend(self._control.get_entity_list_status(entity_list)) - return statuses - except SmartSimError as e: - logger.error(e) - raise + run_id = datetime.datetime.now().replace(microsecond=0).isoformat() + root = pathlib.Path(self.exp_path, run_id.replace(":", ".")) + return self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs_) - @_contextualize - def create_ensemble( + def _dispatch( self, - name: str, - params: t.Optional[t.Dict[str, t.Any]] = None, - batch_settings: t.Optional[base.BatchSettings] = None, - run_settings: t.Optional[base.RunSettings] = None, - replicas: t.Optional[int] = None, - perm_strategy: str = "all_perm", - path: t.Optional[str] = None, - **kwargs: t.Any, - ) -> Ensemble: - """Create an ``Ensemble`` of ``Model`` instances - - Ensembles can be launched sequentially or as a batch - if using a non-local launcher. e.g. slurm - - Ensembles require one of the following combinations - of arguments: - - - ``run_settings`` and ``params`` - - ``run_settings`` and ``replicas`` - - ``batch_settings`` - - ``batch_settings``, ``run_settings``, and ``params`` - - ``batch_settings``, ``run_settings``, and ``replicas`` - - If given solely batch settings, an empty ensemble - will be created that Models can be added to manually - through ``Ensemble.add_model()``. - The entire Ensemble will launch as one batch. 
- - Provided batch and run settings, either ``params`` - or ``replicas`` must be passed and the entire ensemble - will launch as a single batch. - - Provided solely run settings, either ``params`` - or ``replicas`` must be passed and the Ensemble members - will each launch sequentially. - - The kwargs argument can be used to pass custom input - parameters to the permutation strategy. - - :param name: name of the ``Ensemble`` - :param params: parameters to expand into ``Model`` members - :param batch_settings: describes settings for ``Ensemble`` as batch workload - :param run_settings: describes how each ``Model`` should be executed - :param replicas: number of replicas to create - :param perm_strategy: strategy for expanding ``params`` into - ``Model`` instances from params argument - options are "all_perm", "step", "random" - or a callable function. - :raises SmartSimError: if initialization fails - :return: ``Ensemble`` instance + generator: Generator, + dispatcher: dispatch.Dispatcher, + job: Job, + *jobs: Job, + ) -> tuple[LaunchedJobID, ...]: + """Dispatch a series of jobs with a particular dispatcher + + :param generator: The generator is responsible for creating the + job run and log directory. + :param dispatcher: The dispatcher that should be used to determine how + to start a job based on its launch settings. + :param job: The first job instance to dispatch + :param jobs: A collection of other job instances to dispatch + :returns: A sequence of ids with order corresponding to the sequence of + jobs that can be used to query or alter the status of that + particular dispatch of the job. """ - if name is None: - raise AttributeError("Entity has no name. 
Please set name attribute.") - check_path = path or osp.join(self.exp_path, name) - entity_path: str = osp.abspath(check_path) - try: - new_ensemble = Ensemble( - name=name, - params=params or {}, - path=entity_path, - batch_settings=batch_settings, - run_settings=run_settings, - perm_strat=perm_strategy, - replicas=replicas, - **kwargs, + def execute_dispatch(generator: Generator, job: Job, idx: int) -> LaunchedJobID: + args = job.launch_settings.launch_args + env = job.launch_settings.env_vars + exe = job.entity.as_executable_sequence() + dispatch = dispatcher.get_dispatch(args) + try: + # Check to see if one of the existing launchers can be + # configured to handle the launch arguments ... + launch_config = dispatch.configure_first_compatible_launcher( + from_available_launchers=self._launch_history.iter_past_launchers(), + with_arguments=args, + ) + except errors.LauncherNotFoundError: + # ... otherwise create a new launcher that _can_ handle the + # launch arguments and configure _that_ one + launch_config = dispatch.create_new_launcher_configuration( + for_experiment=self, with_arguments=args + ) + # Generate the job directory and return the generated job path + job_paths = self._generate(generator, job, idx) + id_ = launch_config.start( + exe, job_paths.run_path, env, job_paths.out_path, job_paths.err_path ) - return new_ensemble - except SmartSimError as e: - logger.error(e) - raise - - @_contextualize - def create_model( - self, - name: str, - run_settings: base.RunSettings, - params: t.Optional[t.Dict[str, t.Any]] = None, - path: t.Optional[str] = None, - enable_key_prefixing: bool = False, - batch_settings: t.Optional[base.BatchSettings] = None, - ) -> Model: - """Create a general purpose ``Model`` - - The ``Model`` class is the most general encapsulation of - executable code in SmartSim. ``Model`` instances are named - references to pieces of a workflow that can be parameterized, - and executed. 
- - ``Model`` instances can be launched sequentially, as a batch job, - or as a group by adding them into an ``Ensemble``. - - All ``Models`` require a reference to run settings to specify which - executable to launch as well provide options for how to launch - the executable with the underlying WLM. Furthermore, batch a - reference to a batch settings can be added to launch the ``Model`` - as a batch job through ``Experiment.start``. If a ``Model`` with - a reference to a set of batch settings is added to a larger - entity with its own set of batch settings (for e.g. an - ``Ensemble``) the batch settings of the larger entity will take - precedence and the batch setting of the ``Model`` will be - strategically ignored. - - Parameters supplied in the `params` argument can be written into - configuration files supplied at runtime to the ``Model`` through - ``Model.attach_generator_files``. `params` can also be turned - into executable arguments by calling ``Model.params_to_args`` - - By default, ``Model`` instances will be executed in the - exp_path/model_name directory if no `path` argument is supplied. - If a ``Model`` instance is passed to ``Experiment.generate``, - a directory within the ``Experiment`` directory will be created - to house the input and output files from the ``Model``. - - Example initialization of a ``Model`` instance - - .. highlight:: python - .. code-block:: python + # Save the underlying launcher instance and launched job id. That + # way we do not need to spin up a launcher instance for each + # individual job, and the experiment can monitor job statuses. 
+ # pylint: disable-next=protected-access + self._launch_history.save_launch(launch_config._adapted_launcher, id_) + return id_ + + return execute_dispatch(generator, job, 0), *( + execute_dispatch(generator, job, idx) for idx, job in enumerate(jobs, 1) + ) - from smartsim import Experiment - run_settings = exp.create_run_settings("python", "run_pytorch_model.py") - model = exp.create_model("pytorch_model", run_settings) - - # adding parameters to a model - run_settings = exp.create_run_settings("python", "run_pytorch_model.py") - train_params = { - "batch": 32, - "epoch": 10, - "lr": 0.001 - } - model = exp.create_model("pytorch_model", run_settings, params=train_params) - model.attach_generator_files(to_configure="./train.cfg") - exp.generate(model) - - New in 0.4.0, ``Model`` instances can be colocated with an - Orchestrator database shard through ``Model.colocate_db``. This - will launch a single ``Orchestrator`` instance on each compute - host used by the (possibly distributed) application. This is - useful for performant online inference or processing - at runtime. - - New in 0.4.2, ``Model`` instances can now be colocated with - an Orchestrator database over either TCP or UDS using the - ``Model.colocate_db_tcp`` or ``Model.colocate_db_uds`` method - respectively. The original ``Model.colocate_db`` method is now - deprecated, but remains as an alias for ``Model.colocate_db_tcp`` - for backward compatibility. - - :param name: name of the ``Model`` - :param run_settings: defines how ``Model`` should be run - :param params: ``Model`` parameters for writing into configuration files - :param path: path to where the ``Model`` should be executed at runtime - :param enable_key_prefixing: If True, data sent to the ``Orchestrator`` - using SmartRedis from this ``Model`` will - be prefixed with the ``Model`` name. - :param batch_settings: Settings to run ``Model`` individually as a batch job. 
- :raises SmartSimError: if initialization fails - :return: the created ``Model`` + def get_status( + self, *ids: LaunchedJobID + ) -> tuple[JobStatus | InvalidJobStatus, ...]: + """Get the status of jobs launched through the `Experiment` from their + launched job id returned when calling `Experiment.start`. + + The `Experiment` will map the launched ID back to the launcher that + started the job and request a status update. The order of the returned + statuses exactly matches the order of the launched job ids. + + If the `Experiment` cannot find any launcher that started the job + associated with the launched job id, then a + `InvalidJobStatus.NEVER_STARTED` status is returned for that id. + + If the experiment maps the launched job id to multiple launchers, then + a `ValueError` is raised. This should only happen in the case when + launched job ids issued by user defined launcher are not sufficiently + unique. + + :param ids: A sequence of launched job ids issued by the experiment. + :raises TypeError: If ids provided are not the correct type + :raises ValueError: No IDs were provided. + :returns: A tuple of statuses with order respective of the order of the + calling arguments. """ - if name is None: - raise AttributeError("Entity has no name. 
Please set name attribute.") - check_path = path or osp.join(self.exp_path, name) - entity_path: str = osp.abspath(check_path) - if params is None: - params = {} - - try: - new_model = Model( - name=name, - params=params, - path=entity_path, - run_settings=run_settings, - batch_settings=batch_settings, - ) - if enable_key_prefixing: - new_model.enable_key_prefixing() - return new_model - except SmartSimError as e: - logger.error(e) - raise - - @_contextualize - def create_run_settings( - self, - exe: str, - exe_args: t.Optional[t.List[str]] = None, - run_command: str = "auto", - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - container: t.Optional[Container] = None, - **kwargs: t.Any, - ) -> settings.RunSettings: - """Create a ``RunSettings`` instance. - - run_command="auto" will attempt to automatically - match a run command on the system with a ``RunSettings`` - class in SmartSim. If found, the class corresponding - to that run_command will be created and returned. - - If the local launcher is being used, auto detection will - be turned off. - - If a recognized run command is passed, the ``RunSettings`` - instance will be a child class such as ``SrunSettings`` - - If not supported by smartsim, the base ``RunSettings`` class - will be created and returned with the specified run_command and run_args - will be evaluated literally. 
- - Run Commands with implemented helper classes: - - aprun (ALPS) - - srun (SLURM) - - mpirun (OpenMPI) - - jsrun (LSF) - - :param run_command: command to run the executable - :param exe: executable to run - :param exe_args: arguments to pass to the executable - :param run_args: arguments to pass to the ``run_command`` - :param env_vars: environment variables to pass to the executable - :param container: if execution environment is containerized - :return: the created ``RunSettings`` + if not ids: + raise ValueError("No job ids provided to get status") + if not all(isinstance(id, str) for id in ids): + raise TypeError("ids argument was not of type LaunchedJobID") + + to_query = self._launch_history.group_by_launcher( + set(ids), unknown_ok=True + ).items() + stats_iter = (launcher.get_status(*ids).items() for launcher, ids in to_query) + stats_map = dict(itertools.chain.from_iterable(stats_iter)) + stats = (stats_map.get(i, InvalidJobStatus.NEVER_STARTED) for i in ids) + return tuple(stats) + + def wait( + self, *ids: LaunchedJobID, timeout: float | None = None, verbose: bool = True + ) -> None: + """Block execution until all of the provided launched jobs, represented + by an ID, have entered a terminal status. + + :param ids: The ids of the launched jobs to wait for. + :param timeout: The max time to wait for all of the launched jobs to end. + :param verbose: Whether found statuses should be displayed in the console. + :raises TypeError: If IDs provided are not the correct type + :raises ValueError: No IDs were provided. 
""" + if ids: + if not all(isinstance(id, str) for id in ids): + raise TypeError("ids argument was not of type LaunchedJobID") + else: + raise ValueError("No job ids to wait on provided") + self._poll_for_statuses( + ids, TERMINAL_STATUSES, timeout=timeout, verbose=verbose + ) - try: - return settings.create_run_settings( - self._launcher, - exe, - exe_args=exe_args, - run_command=run_command, - run_args=run_args, - env_vars=env_vars, - container=container, - **kwargs, - ) - except SmartSimError as e: - logger.error(e) - raise - - @_contextualize - def create_batch_settings( + def _poll_for_statuses( self, - nodes: int = 1, - time: str = "", - queue: str = "", - account: str = "", - batch_args: t.Optional[t.Dict[str, str]] = None, - **kwargs: t.Any, - ) -> base.BatchSettings: - """Create a ``BatchSettings`` instance - - Batch settings parameterize batch workloads. The result of this - function can be passed to the ``Ensemble`` initialization. - - the `batch_args` parameter can be used to pass in a dictionary - of additional batch command arguments that aren't supported through - the smartsim interface - - - .. highlight:: python - .. code-block:: python - - # i.e. 
    def _poll_for_statuses(
        self,
        ids: t.Sequence[LaunchedJobID],
        statuses: t.Collection[JobStatus],
        timeout: float | None = None,
        interval: float = 5.0,
        verbose: bool = True,
    ) -> dict[LaunchedJobID, JobStatus | InvalidJobStatus]:
        """Poll the experiment's launchers for the statuses of the launched
        jobs with the provided ids, until the status of the jobs changes to
        one of the provided statuses.

        :param ids: The ids of the launched jobs to wait for.
        :param statuses: A collection of statuses to poll for.
        :param timeout: The minimum amount of time to spend polling all jobs to
            reach one of the supplied statuses. If not supplied or `None`, the
            experiment will poll indefinitely.
        :param interval: The minimum time between polling launchers.
        :param verbose: Whether or not to log polled states to the console.
        :raises ValueError: The interval between polling launchers is infinite
        :raises TimeoutError: The polling interval was exceeded.
        :returns: A mapping of ids to the status they entered that ended
            polling.
        """
        # Any requested status -- or any invalid status (e.g. a job this
        # experiment never started) -- terminates polling for that id
        terminal = frozenset(itertools.chain(statuses, InvalidJobStatus))
        # When not verbose, logging calls become a no-op
        log = logger.info if verbose else lambda *_, **__: None
        # NOTE(review): a `None` timeout presumably yields an infinite
        # interval (cf. the `infinite` check below) -- confirm in `_interval`
        method_timeout = _interval.SynchronousTimeInterval(timeout)
        iter_timeout = _interval.SynchronousTimeInterval(interval)
        final: dict[LaunchedJobID, JobStatus | InvalidJobStatus] = {}

        def is_finished(
            id_: LaunchedJobID, status: JobStatus | InvalidJobStatus
        ) -> bool:
            # Log the observed status and report whether it is terminal
            job_title = f"Job({id_}): "
            if done := status in terminal:
                log(f"{job_title}Finished with status '{status.value}'")
            else:
                log(f"{job_title}Running with status '{status.value}'")
            return done

        if iter_timeout.infinite:
            raise ValueError("Polling interval cannot be infinite")
        # Poll until every id reaches a terminal status or the overall
        # timeout expires
        while ids and not method_timeout.expired:
            # Restart the per-round timer before each polling pass
            iter_timeout = iter_timeout.new_interval()
            stats = zip(ids, self.get_status(*ids))
            # Partition (id, status) pairs: True -> finished, False -> pending
            is_done = _helpers.group_by(_helpers.pack_params(is_finished), stats)
            final |= dict(is_done.get(True, ()))
            ids = tuple(id_ for id_, _ in is_done.get(False, ()))
            if ids:
                # Block on whichever deadline is closer so neither the polling
                # interval nor the overall timeout is overshot
                (
                    iter_timeout
                    if iter_timeout.remaining < method_timeout.remaining
                    else method_timeout
                ).block()
        if ids:
            raise TimeoutError(
                f"Job ID(s) {', '.join(map(str, ids))} failed to reach "
                "terminal status before timeout"
            )
        return final
database - - The ``Orchestrator`` database is a key-value store based - on Redis that can be launched together with other ``Experiment`` - created instances for online data storage. - - When launched, ``Orchestrator`` can be used to communicate - data between Fortran, Python, C, and C++ applications. - - Machine Learning models in Pytorch, Tensorflow, and ONNX (i.e. scikit-learn) - can also be stored within the ``Orchestrator`` database where they - can be called remotely and executed on CPU or GPU where - the database is hosted. - - To enable a SmartSim ``Model`` to communicate with the database - the workload must utilize the SmartRedis clients. For more - information on the database, and SmartRedis clients see the - documentation at https://www.craylabs.org/docs/smartredis.html - - :param port: TCP/IP port - :param db_nodes: number of database shards - :param batch: run as a batch workload - :param hosts: specify hosts to launch on - :param run_command: specify launch binary or detect automatically - :param interface: Network interface - :param account: account to run batch on - :param time: walltime for batch 'HH:MM:SS' format - :param queue: queue to run the batch on - :param single_cmd: run all shards with one (MPMD) command - :param db_identifier: an identifier to distinguish this orchestrator in - multiple-database experiments - :raises SmartSimError: if detection of launcher or of run command fails - :raises SmartSimError: if user indicated an incompatible run command - for the launcher - :return: Orchestrator or derived class - """ - - self._append_to_db_identifier_list(db_identifier) - check_path = path or osp.join(self.exp_path, db_identifier) - entity_path: str = osp.abspath(check_path) - return Orchestrator( - port=port, - path=entity_path, - db_nodes=db_nodes, - batch=batch, - hosts=hosts, - run_command=run_command, - interface=interface, - account=account, - time=time, - queue=queue, - single_cmd=single_cmd, - launcher=self._launcher, - 
db_identifier=db_identifier, - **kwargs, - ) - - @_contextualize - def reconnect_orchestrator(self, checkpoint: str) -> Orchestrator: - """Reconnect to a running ``Orchestrator`` - - This method can be used to connect to a ``Orchestrator`` deployment - that was launched by a previous ``Experiment``. This can be - helpful in the case where separate runs of an ``Experiment`` - wish to use the same ``Orchestrator`` instance currently - running on a system. - - :param checkpoint: the `smartsim_db.dat` file created - when an ``Orchestrator`` is launched + def _generate(self, generator: Generator, job: Job, job_index: int) -> Job_Path: + """Generate the directory structure and files for a ``Job`` + + If files or directories are attached to an ``Application`` object + associated with the Job using ``Application.attach_generator_files()``, + those files or directories will be symlinked, copied, or configured and + written into the created job directory. + + :param generator: The generator is responsible for creating the job + run and log directory. + :param job: The Job instance for which the output is generated. + :param job_index: The index of the Job instance (used for naming). + :returns: The paths to the generated output for the Job instance. + :raises: A SmartSimError if an error occurs during the generation process. 
""" try: - orc = self._control.reload_saved_db(checkpoint) - return orc + job_paths = generator.generate_job(job, job_index) + return job_paths except SmartSimError as e: logger.error(e) raise @@ -821,14 +386,14 @@ def reconnect_orchestrator(self, checkpoint: str) -> Orchestrator: def preview( self, *args: t.Any, - verbosity_level: previewrenderer.Verbosity = previewrenderer.Verbosity.INFO, - output_format: previewrenderer.Format = previewrenderer.Format.PLAINTEXT, + verbosity_level: preview_renderer.Verbosity = preview_renderer.Verbosity.INFO, + output_format: preview_renderer.Format = preview_renderer.Format.PLAINTEXT, output_filename: t.Optional[str] = None, ) -> None: """Preview entity information prior to launch. This method aggregates multiple pieces of information to give users insight into what and how entities will be launched. Any instance of - ``Model``, ``Ensemble``, or ``Orchestrator`` created by the + ``Model``, ``Ensemble``, or ``Feature Store`` created by the Experiment can be passed as an argument to the preview method. Verbosity levels: @@ -847,24 +412,16 @@ def preview( output to stdout. Defaults to None. 
""" - # Retrieve any active orchestrator jobs - active_dbjobs = self._control.active_orchestrator_jobs - preview_manifest = Manifest(*args) - previewrenderer.render( + preview_renderer.render( self, preview_manifest, verbosity_level, output_format, output_filename, - active_dbjobs, ) - @property - def launcher(self) -> str: - return self._launcher - @_contextualize def summary(self, style: str = "github") -> str: """Return a summary of the ``Experiment`` @@ -877,7 +434,6 @@ def summary(self, style: str = "github") -> str: https://github.com/astanin/python-tabulate :return: tabulate string of ``Experiment`` history """ - values = [] headers = [ "Name", "Entity-Type", @@ -887,21 +443,8 @@ def summary(self, style: str = "github") -> str: "Status", "Returncode", ] - for job in self._control.get_jobs().values(): - for run in range(job.history.runs + 1): - values.append( - [ - job.entity.name, - job.entity.type, - job.history.jids[run], - run, - f"{job.history.job_times[run]:.4f}", - job.history.statuses[run], - job.history.returns[run], - ] - ) return tabulate( - values, + [], headers, showindex=True, tablefmt=style, @@ -909,6 +452,29 @@ def summary(self, style: str = "github") -> str: disable_numparse=True, ) + def stop(self, *ids: LaunchedJobID) -> tuple[JobStatus | InvalidJobStatus, ...]: + """Cancel the execution of a previously launched job. + + :param ids: The ids of the launched jobs to stop. + :raises TypeError: If ids provided are not the correct type + :raises ValueError: No job ids were provided. + :returns: A tuple of job statuses upon cancellation with order + respective of the order of the calling arguments. 
+ """ + if ids: + if not all(isinstance(id, str) for id in ids): + raise TypeError("ids argument was not of type LaunchedJobID") + else: + raise ValueError("No job ids provided") + by_launcher = self._launch_history.group_by_launcher(set(ids), unknown_ok=True) + id_to_stop_stat = ( + launcher.stop_jobs(*launched).items() + for launcher, launched in by_launcher.items() + ) + stats_map = dict(itertools.chain.from_iterable(id_to_stop_stat)) + stats = (stats_map.get(id_, InvalidJobStatus.NEVER_STARTED) for id_ in ids) + return tuple(stats) + @property def telemetry(self) -> TelemetryConfiguration: """Return the telemetry configuration for this entity. @@ -917,57 +483,16 @@ def telemetry(self) -> TelemetryConfiguration: """ return self._telemetry_cfg - def _launch_summary(self, manifest: Manifest) -> None: - """Experiment pre-launch summary of entities that will be launched - - :param manifest: Manifest of deployables. - """ - - summary = "\n\n=== Launch Summary ===\n" - summary += f"Experiment: {self.name}\n" - summary += f"Experiment Path: {self.exp_path}\n" - summary += f"Launcher: {self._launcher}\n" - if manifest.models: - summary += f"Models: {len(manifest.models)}\n" - - if self._control.orchestrator_active: - summary += "Database Status: active\n" - elif manifest.dbs: - summary += "Database Status: launching\n" - else: - summary += "Database Status: inactive\n" - - summary += f"\n{str(manifest)}" - - logger.info(summary) - - def _create_entity_dir(self, start_manifest: Manifest) -> None: - def create_entity_dir(entity: t.Union[Orchestrator, Model, Ensemble]) -> None: - if not os.path.isdir(entity.path): - os.makedirs(entity.path) - - for model in start_manifest.models: - create_entity_dir(model) - - for orch in start_manifest.dbs: - create_entity_dir(orch) - - for ensemble in start_manifest.ensembles: - create_entity_dir(ensemble) - - for member in ensemble.models: - create_entity_dir(member) - def __str__(self) -> str: return self.name - def 
_append_to_db_identifier_list(self, db_identifier: str) -> None: - """Check if db_identifier already exists when calling create_database""" - if db_identifier in self.db_identifiers: + def _append_to_fs_identifier_list(self, fs_identifier: str) -> None: + """Check if fs_identifier already exists when calling create_feature_store""" + if fs_identifier in self._fs_identifiers: logger.warning( - f"A database with the identifier {db_identifier} has already been made " - "An error will be raised if multiple databases are started " + f"A feature store with the identifier {fs_identifier} has already been made " + "An error will be raised if multiple Feature Stores are started " "with the same identifier" ) # Otherwise, add - self.db_identifiers.add(db_identifier) + self._fs_identifiers.add(fs_identifier) diff --git a/smartsim/launchable/__init__.py b/smartsim/launchable/__init__.py new file mode 100644 index 0000000000..383b458f09 --- /dev/null +++ b/smartsim/launchable/__init__.py @@ -0,0 +1,34 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .base_job import BaseJob +from .base_job_group import BaseJobGroup +from .colocated_job_group import ColocatedJobGroup +from .job import Job +from .job_group import JobGroup +from .launchable import Launchable +from .mpmd_job import MPMDJob +from .mpmd_pair import MPMDPair diff --git a/smartsim/launchable/base_job.py b/smartsim/launchable/base_job.py new file mode 100644 index 0000000000..878a59e532 --- /dev/null +++ b/smartsim/launchable/base_job.py @@ -0,0 +1,43 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t +from abc import ABC, abstractmethod + +from smartsim.launchable.launchable import Launchable + +if t.TYPE_CHECKING: + from smartsim._core.commands.launch_commands import LaunchCommands + + +class BaseJob(ABC, Launchable): + """The highest level abstract base class for a single job that can be launched""" + + @abstractmethod + def get_launch_steps(self) -> "LaunchCommands": + """Return the launch steps corresponding to the + internal data. + """ diff --git a/smartsim/launchable/base_job_group.py b/smartsim/launchable/base_job_group.py new file mode 100644 index 0000000000..9031705f39 --- /dev/null +++ b/smartsim/launchable/base_job_group.py @@ -0,0 +1,91 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import annotations
+
+import typing as t
+from abc import ABC, abstractmethod
+from collections.abc import MutableSequence
+from copy import deepcopy
+
+from smartsim.launchable.launchable import Launchable
+
+from .base_job import BaseJob
+
+
+class BaseJobGroup(Launchable, MutableSequence[BaseJob], ABC):
+    """Highest level ABC of a group of jobs that can be
+    launched
+    """
+
+    @property
+    @abstractmethod
+    def jobs(self) -> t.List[BaseJob]:
+        """This property method returns a list of BaseJob objects.
+        It represents the collection of jobs associated with an
+        instance of the BaseJobGroup abstract class.
+        """
+        pass
+
+    def insert(self, idx: int, value: BaseJob) -> None:
+        """Inserts the given value at the specified index (idx) in
+        the list of jobs. Out-of-range indices are clamped to the
+        nearest end, matching ``list.insert`` semantics.
+        """
+        self.jobs.insert(idx, value)
+
+    def __iter__(self) -> t.Iterator[BaseJob]:
+        """Allows iteration over the jobs in the collection."""
+        return iter(self.jobs)
+
+    @t.overload
+    def __setitem__(self, idx: int, value: BaseJob) -> None: ...
+    @t.overload
+    def __setitem__(self, idx: slice, value: t.Iterable[BaseJob]) -> None: ...
+    def __setitem__(
+        self, idx: int | slice, value: BaseJob | t.Iterable[BaseJob]
+    ) -> None:
+        """Sets the job at the specified index (idx) to the given value."""
+        if isinstance(idx, int):
+            if not isinstance(value, BaseJob):
+                raise TypeError("Can only assign a `BaseJob`")
+            self.jobs[idx] = deepcopy(value)
+        else:
+            if not isinstance(value, t.Iterable):
+                raise TypeError("Can only assign an iterable")
+            self.jobs[idx] = (deepcopy(val) for val in value)
+
+    def __delitem__(self, idx: int | slice) -> None:
+        """Deletes the job at the specified index (idx)."""
+        del self.jobs[idx]
+
+    def __len__(self) -> int:
+        """Returns the total number of jobs in the collection."""
+        return len(self.jobs)
+
+    def __str__(self) -> str:  # pragma: no cover
+        """Returns a string representation of the collection of jobs."""
+        return f"Jobs: {self.jobs}"
diff --git a/smartsim/launchable/colocated_job_group.py b/smartsim/launchable/colocated_job_group.py
new file mode 100644
index 0000000000..db187a46c0
--- /dev/null
+++ b/smartsim/launchable/colocated_job_group.py
@@ -0,0 +1,75 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import annotations
+
+import typing as t
+from copy import deepcopy
+
+from .base_job import BaseJob
+from .base_job_group import BaseJobGroup
+
+if t.TYPE_CHECKING:
+    from typing_extensions import Self
+
+
+class ColocatedJobGroup(BaseJobGroup):
+    """A colocated job group holds references to multiple jobs that
+    will be executed all at the same time when resources
+    permit. Execution is blocked until resources are available.
+    """
+
+    def __init__(
+        self,
+        jobs: t.List[BaseJob],
+    ) -> None:
+        super().__init__()
+        self._jobs = deepcopy(jobs)
+
+    @property
+    def jobs(self) -> t.List[BaseJob]:
+        """This property method returns a list of BaseJob objects.
+        It represents the collection of jobs associated with an
+        instance of the BaseJobGroup abstract class.
+        """
+        return self._jobs
+
+    @t.overload
+    def __getitem__(self, idx: int) -> BaseJob: ...
+    @t.overload
+    def __getitem__(self, idx: slice) -> Self: ...
+    def __getitem__(self, idx: int | slice) -> BaseJob | Self:
+        """Retrieves the job at the specified index (idx)."""
+        jobs = self.jobs[idx]
+        if isinstance(jobs, BaseJob):
+            return jobs
+        return type(self)(jobs)
+
+    def __str__(self) -> str:  # pragma: no cover
+        """Returns a string representation of the collection of
+        colocated job groups.
+ """ + return f"Colocated Jobs: {self.jobs}" diff --git a/smartsim/launchable/job.py b/smartsim/launchable/job.py new file mode 100644 index 0000000000..6082ba61d7 --- /dev/null +++ b/smartsim/launchable/job.py @@ -0,0 +1,160 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+from __future__ import annotations
+
+import typing as t
+from copy import deepcopy
+
+from smartsim._core.commands.launch_commands import LaunchCommands
+from smartsim._core.utils.helpers import check_name
+from smartsim.launchable.base_job import BaseJob
+from smartsim.log import get_logger
+from smartsim.settings import LaunchSettings
+
+logger = get_logger(__name__)
+
+if t.TYPE_CHECKING:
+    from smartsim.entity.entity import SmartSimEntity
+
+
+@t.final
+class Job(BaseJob):
+    """A Job holds a reference to a SmartSimEntity and associated
+    LaunchSettings prior to launch. It is responsible for turning
+    the stored SmartSimEntity and LaunchSettings into commands that can be
+    executed by a launcher. Jobs are designed to be started by the Experiment.
+    """
+
+    def __init__(
+        self,
+        entity: SmartSimEntity,
+        launch_settings: LaunchSettings,
+        name: str | None = None,
+    ):
+        """Initialize a ``Job``
+
+        Jobs require a SmartSimEntity and a LaunchSettings. Optionally, users may provide
+        a name. To create a simple Job that echoes `Hello World!`, consider the example below:
+
+        .. highlight:: python
+        .. code-block:: python
+
+            # Create an application that runs the 'echo' command
+            my_app = Application(name="my_app", exe="echo", exe_args="Hello World!")
+            # Define the launch settings using SLURM
+            srun_settings = LaunchSettings(launcher="slurm")
+
+            # Create a Job with the `my_app` and `srun_settings`
+            my_job = Job(my_app, srun_settings, name="my_job")
+
+        :param entity: the SmartSimEntity object
+        :param launch_settings: the LaunchSettings object
+        :param name: the Job name
+        """
+        super().__init__()
+        """Initialize the parent class BaseJob"""
+        self.entity = entity
+        """Deepcopy of the SmartSimEntity object"""
+        self.launch_settings = launch_settings
+        """Deepcopy of the LaunchSettings object"""
+        self._name = name if name else entity.name
+        """Name of the Job"""
+        check_name(self._name)
+
+    @property
+    def name(self) -> str:
+        """Return the name of the Job.
+
+        :return: the name of the Job
+        """
+        return self._name
+
+    @name.setter
+    def name(self, name: str) -> None:
+        """Set the name of the Job.
+
+        :param name: the name of the Job
+        """
+        check_name(name)
+        logger.debug(f'Overwriting the Job name from "{self._name}" to "{name}"')
+        self._name = name
+
+    @property
+    def entity(self) -> SmartSimEntity:
+        """Return the attached entity.
+
+        :return: the attached SmartSimEntity
+        """
+        return deepcopy(self._entity)
+
+    @entity.setter
+    def entity(self, value: SmartSimEntity) -> None:
+        """Set the Job entity.
+
+        :param value: the SmartSimEntity
+        :raises TypeError: if entity is not SmartSimEntity
+        """
+        from smartsim.entity.entity import SmartSimEntity
+
+        if not isinstance(value, SmartSimEntity):
+            raise TypeError("entity argument was not of type SmartSimEntity")
+
+        self._entity = deepcopy(value)
+
+    @property
+    def launch_settings(self) -> LaunchSettings:
+        """Return the attached LaunchSettings.
+
+        :return: the attached LaunchSettings
+        """
+        return deepcopy(self._launch_settings)
+
+    @launch_settings.setter
+    def launch_settings(self, value: LaunchSettings) -> None:
+        """Set the Jobs LaunchSettings.
+
+        :param value: the LaunchSettings
+        :raises TypeError: if launch_settings is not a LaunchSettings
+        """
+        if not isinstance(value, LaunchSettings):
+            raise TypeError("launch_settings argument was not of type LaunchSettings")
+
+        self._launch_settings = deepcopy(value)
+
+    def get_launch_steps(self) -> LaunchCommands:
+        """Return the launch steps corresponding to the
+        internal data.
+ + :returns: The Jobs launch steps + """ + # TODO: return JobWarehouseRunner.run(self) + raise NotImplementedError + + def __str__(self) -> str: # pragma: no cover + string = f"SmartSim Entity: {self.entity}\n" + string += f"Launch Settings: {self.launch_settings}" + return string diff --git a/smartsim/launchable/job_group.py b/smartsim/launchable/job_group.py new file mode 100644 index 0000000000..f06313dd8d --- /dev/null +++ b/smartsim/launchable/job_group.py @@ -0,0 +1,96 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+from __future__ import annotations
+
+import typing as t
+from copy import deepcopy
+
+from smartsim.log import get_logger
+
+from .._core.utils.helpers import check_name
+from .base_job import BaseJob
+from .base_job_group import BaseJobGroup
+
+logger = get_logger(__name__)
+
+if t.TYPE_CHECKING:
+    from typing_extensions import Self
+
+
+@t.final
+class JobGroup(BaseJobGroup):
+    """A job group holds references to multiple jobs that
+    will be executed all at the same time when resources
+    permit. Execution is blocked until resources are available.
+    """
+
+    def __init__(
+        self,
+        jobs: t.List[BaseJob],
+        name: str = "job_group",
+    ) -> None:
+        super().__init__()
+        self._jobs = deepcopy(jobs)
+        self._name = name
+        check_name(self._name)
+
+    @property
+    def name(self) -> str:
+        """Retrieves the name of the JobGroup."""
+        return self._name
+
+    @name.setter
+    def name(self, name: str) -> None:
+        """Sets the name of the JobGroup."""
+        check_name(name)
+        logger.debug(f'Overwriting Job name from "{self._name}" to "{name}"')
+        self._name = name
+
+    @property
+    def jobs(self) -> t.List[BaseJob]:
+        """This property method returns a list of BaseJob objects.
+        It represents the collection of jobs associated with an
+        instance of the BaseJobGroup abstract class.
+        """
+        return self._jobs
+
+    @t.overload
+    def __getitem__(self, idx: int) -> BaseJob: ...
+    @t.overload
+    def __getitem__(self, idx: slice) -> Self: ...
+    def __getitem__(self, idx: int | slice) -> BaseJob | Self:
+        """Retrieves the job at the specified index (idx)."""
+        jobs = self.jobs[idx]
+        if isinstance(jobs, BaseJob):
+            return jobs
+        return type(self)(jobs)
+
+    def __str__(self) -> str:  # pragma: no cover
+        """Returns a string representation of the collection of
+        job groups.
+ """ + return f"Job Groups: {self.jobs}" diff --git a/smartsim/launchable/launchable.py b/smartsim/launchable/launchable.py new file mode 100644 index 0000000000..7a8af2c19a --- /dev/null +++ b/smartsim/launchable/launchable.py @@ -0,0 +1,38 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +class SmartSimObject: + """Base Class for SmartSim Objects""" + + ... + + +class Launchable(SmartSimObject): + """Base Class for anything than can be passed + into Experiment.start()""" + + ... 
diff --git a/smartsim/launchable/mpmd_job.py b/smartsim/launchable/mpmd_job.py new file mode 100644 index 0000000000..e526f10746 --- /dev/null +++ b/smartsim/launchable/mpmd_job.py @@ -0,0 +1,118 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +import textwrap +import typing as t +from copy import deepcopy + +from smartsim.error.errors import SSUnsupportedError +from smartsim.launchable.base_job import BaseJob +from smartsim.launchable.mpmd_pair import MPMDPair +from smartsim.settings.launch_settings import LaunchSettings + +if t.TYPE_CHECKING: + from smartsim._core.commands.launch_commands import LaunchCommands + from smartsim.entity.entity import SmartSimEntity + + +def _check_launcher(mpmd_pairs: t.List[MPMDPair]) -> None: + """Enforce all pairs have the same launcher""" + flag = 0 + ret = None + for mpmd_pair in mpmd_pairs: + if flag == 1: + if ret == mpmd_pair.launch_settings.launcher: + flag = 0 + else: + raise SSUnsupportedError("MPMD pairs must all share the same launcher.") + ret = mpmd_pair.launch_settings.launcher + flag = 1 + + +def _check_entity(mpmd_pairs: t.List[MPMDPair]) -> None: + """Enforce all pairs have the same entity types""" + flag = 0 + ret: SmartSimEntity | None = None + for mpmd_pair in mpmd_pairs: + if flag == 1: + if type(ret) == type(mpmd_pair.entity): + flag = 0 + else: + raise SSUnsupportedError( + "MPMD pairs must all share the same entity type." + ) + ret = mpmd_pair.entity + flag = 1 + + +class MPMDJob(BaseJob): + """An MPMDJob holds references to SmartSimEntity and + LaunchSettings pairs. 
It is responsible for turning + The stored pairs into an MPMD command(s) + """ + + def __init__(self, mpmd_pairs: t.List[MPMDPair] | None = None) -> None: + super().__init__() + self._mpmd_pairs = deepcopy(mpmd_pairs) if mpmd_pairs else [] + _check_launcher(self._mpmd_pairs) + _check_entity(self._mpmd_pairs) + # TODO: self.warehouse_runner = MPMDJobWarehouseRunner + + @property + def mpmd_pairs(self) -> list[MPMDPair]: + return deepcopy(self._mpmd_pairs) + + @mpmd_pairs.setter + def mpmd_pairs(self, value: list[MPMDPair]) -> None: + self._mpmd_pair = deepcopy(value) + + def add_mpmd_pair( + self, entity: SmartSimEntity, launch_settings: LaunchSettings + ) -> None: + """ + Add a mpmd pair to the mpmd job + """ + self._mpmd_pairs.append(MPMDPair(entity, launch_settings)) + _check_launcher(self.mpmd_pairs) + _check_entity(self.mpmd_pairs) + + def get_launch_steps(self) -> LaunchCommands: + """Return the launch steps corresponding to the + internal data. + """ + # TODO: return MPMDJobWarehouseRunner.run(self) + raise NotImplementedError + + def __str__(self) -> str: # pragma: no cover + """returns A user-readable string of a MPMD Job""" + fmt = lambda mpmd_pair: textwrap.dedent(f"""\ + == MPMD Pair == + {mpmd_pair.entity} + {mpmd_pair.launch_settings} + """) + return "\n".join(map(fmt, self.mpmd_pairs)) diff --git a/smartsim/launchable/mpmd_pair.py b/smartsim/launchable/mpmd_pair.py new file mode 100644 index 0000000000..722a16cdee --- /dev/null +++ b/smartsim/launchable/mpmd_pair.py @@ -0,0 +1,43 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import copy +import typing as t + +from smartsim.settings.launch_settings import LaunchSettings + +if t.TYPE_CHECKING: + from smartsim.entity.entity import SmartSimEntity + + +class MPMDPair: + """Class to store MPMD Pairs""" + + def __init__(self, entity: SmartSimEntity, launch_settings: LaunchSettings): + self.entity = copy.deepcopy(entity) + self.launch_settings = copy.deepcopy(launch_settings) diff --git a/smartsim/ml/data.py b/smartsim/ml/data.py index 6175259b25..21e4e33a5d 100644 --- a/smartsim/ml/data.py +++ b/smartsim/ml/data.py @@ -29,8 +29,6 @@ from os import environ import numpy as np -from smartredis import Client, Dataset -from smartredis.error import RedisReplyError from ..error import SSInternalError from ..log import get_logger @@ -75,48 +73,25 @@ def __init__( self.num_classes = num_classes self._ds_name = form_name(self.list_name, "info") - def publish(self, client: Client) -> None: - """Upload DataInfo information to 
Orchestrator + def publish(self) -> None: + """Upload DataInfo information to FeatureStore The information is put on the DB as a DataSet, with strings stored as metastrings and integers stored as metascalars. - :param client: Client to connect to Database + :param client: Client to connect to Feature Store """ - info_ds = Dataset(self._ds_name) - info_ds.add_meta_string("sample_name", self.sample_name) - if self.target_name: - info_ds.add_meta_string("target_name", self.target_name) - if self.num_classes: - info_ds.add_meta_scalar("num_classes", self.num_classes) - client.put_dataset(info_ds) + ... - def download(self, client: Client) -> None: - """Download DataInfo information from Orchestrator + def download(self) -> None: + """Download DataInfo information from FeatureStore The information retrieved from the DB is used to populate this object's members. If the information is not available on the DB, the object members are not modified. - :param client: Client to connect to Database + :param client: Client to connect to Feature Store """ - try: - info_ds = client.get_dataset(self._ds_name) - except RedisReplyError as e: - # If the info was not published, proceed with default parameters - logger.warning( - "Could not retrieve data for DataInfo object, the following " - "values will be kept." - ) - logger.error(f"Original error from Redis was {e}") - logger.warning(str(self)) - return - self.sample_name = info_ds.get_meta_strings("sample_name")[0] - field_names = info_ds.get_metadata_field_names() - if "target_name" in field_names: - self.target_name = info_ds.get_meta_strings("target_name")[0] - if "num_classes" in field_names: - self.num_classes = int(info_ds.get_meta_scalars("num_classes")[0]) def __repr__(self) -> str: strings = ["DataInfo object"] @@ -134,7 +109,7 @@ class TrainingDataUploader: This class can be used to upload samples following a simple convention for naming. 
Once created, the function `publish_info` can be used - to put all details about the data set on the Orchestrator. A training + to put all details about the data set on the FeatureStore. A training process can thus access them and get all relevant information to download the batches which are uploaded. @@ -142,12 +117,12 @@ class TrainingDataUploader: and the data will be stored following the naming convention specified by the attributes of this class. - :param list_name: Name of the dataset as stored on the Orchestrator + :param list_name: Name of the dataset as stored on the FeatureStore :param sample_name: Name of samples tensor in uploaded Datasets :param target_name: Name of targets tensor (if needed) in uploaded Datasets :param num_classes: Number of classes of targets, if categorical - :param cluster: Whether the SmartSim Orchestrator is being run as a cluster - :param address: Address of Redis DB as : + :param cluster: Whether the SmartSim FeatureStore is being run as a cluster + :param address: :param rank: Rank of DataUploader in multi-process application (e.g. MPI rank). :param verbose: If output should be logged to screen. 
@@ -169,7 +144,6 @@ def __init__( if not sample_name: raise ValueError("Sample name can not be empty") - self.client = Client(cluster, address=address) self.verbose = verbose self.batch_idx = 0 self.rank = rank @@ -192,7 +166,7 @@ def num_classes(self) -> t.Optional[int]: return self._info.num_classes def publish_info(self) -> None: - self._info.publish(self.client) + self._info.publish() def put_batch( self, @@ -200,25 +174,20 @@ def put_batch( targets: t.Optional[np.ndarray] = None, # type: ignore[type-arg] ) -> None: batch_ds_name = form_name("training_samples", self.rank, self.batch_idx) - batch_ds = Dataset(batch_ds_name) - batch_ds.add_tensor(self.sample_name, samples) if ( targets is not None and self.target_name and (self.target_name != self.sample_name) ): - batch_ds.add_tensor(self.target_name, targets) if self.verbose: logger.info(f"Putting dataset {batch_ds_name} with samples and targets") else: if self.verbose: logger.info(f"Putting dataset {batch_ds_name} with samples") - self.client.put_dataset(batch_ds) - self.client.append_to_list(self.list_name, batch_ds) if self.verbose: logger.info(f"Added dataset to list {self.list_name}") - logger.info(f"List length {self.client.get_list_length(self.list_name)}") + logger.info(f"List length") self.batch_idx += 1 @@ -261,8 +230,8 @@ class DataDownloader: download, if a string is passed, it is used to download DataInfo data from DB, assuming it was stored with ``list_name=data_info_or_list_name`` :param list_name: Name of aggregation list used to upload data - :param cluster: Whether the Orchestrator will be run as a cluster - :param address: Address of Redis client as : + :param cluster: Whether the FeatureStore will be run as a cluster + :param address: :param replica_rank: When StaticDataDownloader is used distributedly, indicates the rank of this object :param num_replicas: When BatchDownlaoder is used distributedly, indicates @@ -301,11 +270,9 @@ def __init__( self._info = data_info_or_list_name elif 
isinstance(data_info_or_list_name, str): self._info = DataInfo(list_name=data_info_or_list_name) - client = Client(self.cluster, self.address) - self._info.download(client) + self._info.download() else: raise TypeError("data_info_or_list_name must be either DataInfo or str") - self._client: t.Optional[Client] = None sskeyin = environ.get("SSKEYIN", "") self.uploader_keys = sskeyin.split(",") @@ -314,12 +281,6 @@ def __init__( if init_samples: self.init_samples(max_fetch_trials, wait_interval) - @property - def client(self) -> Client: - if self._client is None: - raise ValueError("Client not initialized") - return self._client - def log(self, message: str) -> None: if self.verbose: logger.info(message) @@ -387,7 +348,6 @@ def init_samples(self, init_trials: int = -1, wait_interval: float = 10.0) -> No :param init_trials: maximum number of attempts to fetch data """ - self._client = Client(self.cluster, self.address) num_trials = 0 max_trials = init_trials or -1 @@ -406,73 +366,15 @@ def init_samples(self, init_trials: int = -1, wait_interval: float = 10.0) -> No if self.shuffle: np.random.shuffle(self.indices) - def _data_exists(self, batch_name: str, target_name: str) -> bool: - if self.need_targets: - return all( - self.client.tensor_exists(datum) for datum in [batch_name, target_name] - ) - - return bool(self.client.tensor_exists(batch_name)) + def _data_exists(self, batch_name: str, target_name: str) -> None: + pass def _add_samples(self, indices: t.List[int]) -> None: - datasets: t.List[Dataset] = [] - - if self.num_replicas == 1: - datasets = self.client.get_dataset_list_range( - self.list_name, start_index=indices[0], end_index=indices[-1] - ) - else: - for idx in indices: - datasets += self.client.get_dataset_list_range( - self.list_name, start_index=idx, end_index=idx - ) - - if self.samples is None: - self.samples = datasets[0].get_tensor(self.sample_name) - if self.need_targets: - self.targets = datasets[0].get_tensor(self.target_name) - - if 
len(datasets) > 1: - datasets = datasets[1:] - - if self.samples is not None: - for dataset in datasets: - self.samples = np.concatenate( - ( - t.cast("npt.NDArray[t.Any]", self.samples), - dataset.get_tensor(self.sample_name), - ) - ) - if self.need_targets: - self.targets = np.concatenate( - ( - t.cast("npt.NDArray[t.Any]", self.targets), - dataset.get_tensor(self.target_name), - ) - ) - - self.num_samples = t.cast("npt.NDArray[t.Any]", self.samples).shape[0] - self.indices = np.arange(self.num_samples) - - self.log(f"New dataset size: {self.num_samples}, batches: {len(self)}") + pass def _update_samples_and_targets(self) -> None: self.log(f"Rank {self.replica_rank} out of {self.num_replicas} replicas") - for uploader_idx, uploader_key in enumerate(self.uploader_keys): - if uploader_key: - self.client.use_list_ensemble_prefix(True) - self.client.set_data_source(uploader_key) - - list_length = self.client.get_list_length(self.list_name) - - # Strictly greater, because next_index is 0-based - if list_length > self.next_indices[uploader_idx]: - start = self.next_indices[uploader_idx] - indices = list(range(start, list_length, self.num_replicas)) - self._add_samples(indices) - self.next_indices[uploader_idx] = indices[-1] + self.num_replicas - def update_data(self) -> None: if self.dynamic: self._update_samples_and_targets() diff --git a/smartsim/ml/tf/utils.py b/smartsim/ml/tf/utils.py index 4e45f18475..74e39d35b2 100644 --- a/smartsim/ml/tf/utils.py +++ b/smartsim/ml/tf/utils.py @@ -39,12 +39,8 @@ def freeze_model( ) -> t.Tuple[str, t.List[str], t.List[str]]: """Freeze a Keras or TensorFlow Graph - to use a Keras or TensorFlow model in SmartSim, the model - must be frozen and the inputs and outputs provided to the - smartredis.client.set_model_from_file() method. 
- This utiliy function provides everything users need to take - a trained model and put it inside an ``orchestrator`` instance + a trained model and put it inside an ``featurestore`` instance :param model: TensorFlow or Keras model :param output_dir: output dir to save model file to @@ -81,12 +77,8 @@ def freeze_model( def serialize_model(model: keras.Model) -> t.Tuple[str, t.List[str], t.List[str]]: """Serialize a Keras or TensorFlow Graph - to use a Keras or TensorFlow model in SmartSim, the model - must be frozen and the inputs and outputs provided to the - smartredis.client.set_model() method. - This utiliy function provides everything users need to take - a trained model and put it inside an ``orchestrator`` instance. + a trained model and put it inside an ``featurestore`` instance. :param model: TensorFlow or Keras model :return: serialized model, model input layer names, model output layer names diff --git a/smartsim/ml/torch/data.py b/smartsim/ml/torch/data.py index c6a8e6eac5..71addd04e6 100644 --- a/smartsim/ml/torch/data.py +++ b/smartsim/ml/torch/data.py @@ -28,11 +28,23 @@ import numpy as np import torch -from smartredis import Client, Dataset +from smartsim.entity._mock import Mock from smartsim.ml.data import DataDownloader +class Client(Mock): + """Mock Client""" + + pass + + +class Dataset(Mock): + """Mock Dataset""" + + pass + + class _TorchDataGenerationCommon(DataDownloader, torch.utils.data.IterableDataset): def __init__(self, **kwargs: t.Any) -> None: init_samples = kwargs.pop("init_samples", False) diff --git a/smartsim/settings/__init__.py b/smartsim/settings/__init__.py index 8052121e25..59aeeffbd8 100644 --- a/smartsim/settings/__init__.py +++ b/smartsim/settings/__init__.py @@ -24,32 +24,65 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from .alpsSettings import AprunSettings -from .base import RunSettings, SettingsBase -from .containers import Container, Singularity -from .dragonRunSettings import DragonRunSettings -from .lsfSettings import BsubBatchSettings, JsrunSettings -from .mpiSettings import MpiexecSettings, MpirunSettings, OrterunSettings -from .palsSettings import PalsMpiexecSettings -from .pbsSettings import QsubBatchSettings -from .sgeSettings import SgeQsubBatchSettings -from .slurmSettings import SbatchSettings, SrunSettings - -__all__ = [ - "AprunSettings", - "BsubBatchSettings", - "JsrunSettings", - "MpirunSettings", - "MpiexecSettings", - "OrterunSettings", - "QsubBatchSettings", - "RunSettings", - "SettingsBase", - "SbatchSettings", - "SgeQsubBatchSettings", - "SrunSettings", - "PalsMpiexecSettings", - "DragonRunSettings", - "Container", - "Singularity", -] +import typing as t + +from .base_settings import BaseSettings +from .batch_settings import BatchSettings +from .launch_settings import LaunchSettings + +__all__ = ["LaunchSettings", "BaseSettings", "BatchSettings"] + + +# TODO Mock imports for compiling tests +class SettingsBase: + def __init__(self, *_: t.Any, **__: t.Any) -> None: ... + def __getattr__(self, _: str) -> t.Any: ... + + +class QsubBatchSettings(SettingsBase): ... + + +class SgeQsubBatchSettings(SettingsBase): ... + + +class SbatchSettings(SettingsBase): ... + + +class Singularity: ... + + +class AprunSettings(SettingsBase): ... + + +class RunSettings(SettingsBase): ... + + +class DragonRunSettings(RunSettings): ... + + +class OrterunSettings(RunSettings): ... + + +class MpirunSettings(RunSettings): ... + + +class MpiexecSettings(RunSettings): ... + + +class JsrunSettings(RunSettings): ... + + +class BsubBatchSettings(RunSettings): ... + + +class PalsMpiexecSettings(RunSettings): ... + + +class SrunSettings(RunSettings): ... + + +class Container: ... + + +def create_batch_settings(*_: t.Any, **__: t.Any) -> t.Any: ... 
+def create_run_settings(*_: t.Any, **__: t.Any) -> t.Any: ... diff --git a/smartsim/settings/arguments/__init__.py b/smartsim/settings/arguments/__init__.py new file mode 100644 index 0000000000..f79a3b4bf9 --- /dev/null +++ b/smartsim/settings/arguments/__init__.py @@ -0,0 +1,30 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from .batch_arguments import BatchArguments +from .launch_arguments import LaunchArguments + +__all__ = ["LaunchArguments", "BatchArguments"] diff --git a/smartsim/settings/arguments/batch/__init__.py b/smartsim/settings/arguments/batch/__init__.py new file mode 100644 index 0000000000..e6dc055ead --- /dev/null +++ b/smartsim/settings/arguments/batch/__init__.py @@ -0,0 +1,35 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from .lsf import BsubBatchArguments +from .pbs import QsubBatchArguments +from .slurm import SlurmBatchArguments + +__all__ = [ + "BsubBatchArguments", + "QsubBatchArguments", + "SlurmBatchArguments", +] diff --git a/smartsim/settings/arguments/batch/lsf.py b/smartsim/settings/arguments/batch/lsf.py new file mode 100644 index 0000000000..23f948bd09 --- /dev/null +++ b/smartsim/settings/arguments/batch/lsf.py @@ -0,0 +1,163 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +import typing as t + +from smartsim.log import get_logger + +from ...batch_command import BatchSchedulerType +from ...common import StringArgument +from ..batch_arguments import BatchArguments + +logger = get_logger(__name__) + + +class BsubBatchArguments(BatchArguments): + """A class to represent the arguments required for submitting batch + jobs using the bsub command. + """ + + def scheduler_str(self) -> str: + """Get the string representation of the scheduler + + :returns: The string representation of the scheduler + """ + return BatchSchedulerType.Lsf.value + + def set_walltime(self, walltime: str) -> None: + """Set the walltime + + This sets ``-W``. + + :param walltime: Time in hh:mm format, e.g. "10:00" for 10 hours, + if time is supplied in hh:mm:ss format, seconds + will be ignored and walltime will be set as ``hh:mm`` + """ + # For compatibility with other launchers, as explained in docstring + if walltime: + if len(walltime.split(":")) > 2: + walltime = ":".join(walltime.split(":")[:2]) + self.set("W", walltime) + + def set_smts(self, smts: int) -> None: + """Set SMTs + + This sets ``-alloc_flags``. If the user sets + SMT explicitly through ``-alloc_flags``, then that + takes precedence. + + :param smts: SMT (e.g on Summit: 1, 2, or 4) + """ + self.set("alloc_flags", str(smts)) + + def set_project(self, project: str) -> None: + """Set the project + + This sets ``-P``. + + :param time: project name + """ + self.set("P", project) + + def set_account(self, account: str) -> None: + """Set the project + + this function is an alias for `set_project`. + + :param account: project name + """ + return self.set_project(account) + + def set_nodes(self, num_nodes: int) -> None: + """Set the number of nodes for this batch job + + This sets ``-nnodes``. 
+ + :param nodes: number of nodes + """ + self.set("nnodes", str(num_nodes)) + + def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + """Specify the hostlist for this job + + :param host_list: hosts to launch on + :raises TypeError: if not str or list of str + """ + if isinstance(host_list, str): + host_list = [host_list.strip()] + if not isinstance(host_list, list): + raise TypeError("host_list argument must be a list of strings") + if not all(isinstance(host, str) for host in host_list): + raise TypeError("host_list argument must be list of strings") + self.set("m", '"' + " ".join(host_list) + '"') + + def set_tasks(self, tasks: int) -> None: + """Set the number of tasks for this job + + This sets ``-n`` + + :param tasks: number of tasks + """ + self.set("n", str(tasks)) + + def set_queue(self, queue: str) -> None: + """Set the queue for this job + + This sets ``-q`` + + :param queue: The queue to submit the job on + """ + self.set("q", queue) + + def format_batch_args(self) -> t.List[str]: + """Get the formatted batch arguments for a preview + + :return: list of batch arguments for `bsub` + """ + opts = [] + + for opt, value in self._batch_args.items(): + + prefix = "-" # LSF only uses single dashses + + if value is None: + opts += [prefix + opt] + else: + opts += [f"{prefix}{opt}", str(value)] + + return opts + + def set(self, key: str, value: str | None) -> None: + """Set an arbitrary scheduler argument + + :param key: The launch argument + :param value: A string representation of the value for the launch + argument (if applicable), otherwise `None` + """ + # Store custom arguments in the launcher_args + self._batch_args[key] = value diff --git a/smartsim/settings/arguments/batch/pbs.py b/smartsim/settings/arguments/batch/pbs.py new file mode 100644 index 0000000000..1262076656 --- /dev/null +++ b/smartsim/settings/arguments/batch/pbs.py @@ -0,0 +1,186 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# 
All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import typing as t +from copy import deepcopy + +from smartsim.log import get_logger + +from ....error import SSConfigError +from ...batch_command import BatchSchedulerType +from ...common import StringArgument +from ..batch_arguments import BatchArguments + +logger = get_logger(__name__) + + +class QsubBatchArguments(BatchArguments): + """A class to represent the arguments required for submitting batch + jobs using the qsub command. 
+ """ + + def scheduler_str(self) -> str: + """Get the string representation of the scheduler + + :returns: The string representation of the scheduler + """ + return BatchSchedulerType.Pbs.value + + def set_nodes(self, num_nodes: int) -> None: + """Set the number of nodes for this batch job + + In PBS, 'select' is the more primitive way of describing how + many nodes to allocate for the job. 'nodes' is equivalent to + 'select' with a 'place' statement. Assuming that only advanced + users would use 'set_resource' instead, defining the number of + nodes here is sets the 'nodes' resource. + + :param num_nodes: number of nodes + """ + + self.set("nodes", str(num_nodes)) + + def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + """Specify the hostlist for this job + + :param host_list: hosts to launch on + :raises TypeError: if not str or list of str + """ + if isinstance(host_list, str): + host_list = [host_list.strip()] + if not isinstance(host_list, list): + raise TypeError("host_list argument must be a list of strings") + if not all(isinstance(host, str) for host in host_list): + raise TypeError("host_list argument must be a list of strings") + self.set("hostname", ",".join(host_list)) + + def set_walltime(self, walltime: str) -> None: + """Set the walltime of the job + + format = "HH:MM:SS" + + If a walltime argument is provided in + ``QsubBatchSettings.resources``, then + this value will be overridden + + :param walltime: wall time + """ + self.set("walltime", walltime) + + def set_queue(self, queue: str) -> None: + """Set the queue for the batch job + + :param queue: queue name + """ + self.set("q", str(queue)) + + def set_ncpus(self, num_cpus: int) -> None: + """Set the number of cpus obtained in each node. 
+ + If a select argument is provided in + ``QsubBatchSettings.resources``, then + this value will be overridden + + :param num_cpus: number of cpus per node in select + """ + self.set("ppn", str(num_cpus)) + + def set_account(self, account: str) -> None: + """Set the account for this batch job + + :param acct: account id + """ + self.set("A", str(account)) + + def format_batch_args(self) -> t.List[str]: + """Get the formatted batch arguments for a preview + + :return: batch arguments for `qsub` + :raises ValueError: if options are supplied without values + """ + opts, batch_arg_copy = self._create_resource_list(self._batch_args) + for opt, value in batch_arg_copy.items(): + prefix = "-" + if not value: + raise ValueError("PBS options without values are not allowed") + opts += [f"{prefix}{opt}", str(value)] + return opts + + @staticmethod + def _sanity_check_resources(batch_args: t.Dict[str, str | None]) -> None: + """Check that only select or nodes was specified in resources + + Note: For PBS Pro, nodes is equivalent to 'select' and 'place' so + they are not quite synonyms. Here we assume that + """ + + has_select = batch_args.get("select", None) + has_nodes = batch_args.get("nodes", None) + + if has_select and has_nodes: + raise SSConfigError( + "'select' and 'nodes' cannot both be specified. This can happen " + "if nodes were specified using the 'set_nodes' method and " + "'select' was set using 'set_resource'. Please only specify one." 
+ ) + + def _create_resource_list( + self, batch_args: t.Dict[str, str | None] + ) -> t.Tuple[t.List[str], t.Dict[str, str | None]]: + self._sanity_check_resources(batch_args) + res = [] + + batch_arg_copy = deepcopy(batch_args) + # Construct the basic select/nodes statement + if select := batch_arg_copy.pop("select", None): + select_command = f"-l select={select}" + elif nodes := batch_arg_copy.pop("nodes", None): + select_command = f"-l nodes={nodes}" + else: + raise SSConfigError( + "Insufficient resource specification: no nodes or select statement" + ) + if ncpus := batch_arg_copy.pop("ppn", None): + select_command += f":ncpus={ncpus}" + if hosts := batch_arg_copy.pop("hostname", None): + hosts_list = ["=".join(("host", str(host))) for host in hosts.split(",")] + select_command += f":{'+'.join(hosts_list)}" + res += select_command.split() + if walltime := batch_arg_copy.pop("walltime", None): + res += ["-l", f"walltime={walltime}"] + + return res, batch_arg_copy + + def set(self, key: str, value: str | None) -> None: + """Set an arbitrary launch argument + + :param key: The launch argument + :param value: A string representation of the value for the launch + argument (if applicable), otherwise `None` + """ + self._batch_args[key] = value diff --git a/smartsim/settings/arguments/batch/slurm.py b/smartsim/settings/arguments/batch/slurm.py new file mode 100644 index 0000000000..26f9cf8549 --- /dev/null +++ b/smartsim/settings/arguments/batch/slurm.py @@ -0,0 +1,156 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import re +import typing as t + +from smartsim.log import get_logger + +from ...batch_command import BatchSchedulerType +from ...common import StringArgument +from ..batch_arguments import BatchArguments + +logger = get_logger(__name__) + + +class SlurmBatchArguments(BatchArguments): + """A class to represent the arguments required for submitting batch + jobs using the sbatch command. + """ + + def scheduler_str(self) -> str: + """Get the string representation of the scheduler + + :returns: The string representation of the scheduler + """ + return BatchSchedulerType.Slurm.value + + def set_walltime(self, walltime: str) -> None: + """Set the walltime of the job + + format = "HH:MM:SS" + + :param walltime: wall time + """ + pattern = r"^\d{2}:\d{2}:\d{2}$" + if walltime and re.match(pattern, walltime): + self.set("time", str(walltime)) + else: + raise ValueError("Invalid walltime format. 
Please use 'HH:MM:SS' format.") + + def set_nodes(self, num_nodes: int) -> None: + """Set the number of nodes for this batch job + + This sets ``--nodes``. + + :param num_nodes: number of nodes + """ + self.set("nodes", str(num_nodes)) + + def set_account(self, account: str) -> None: + """Set the account for this batch job + + This sets ``--account``. + + :param account: account id + """ + self.set("account", account) + + def set_partition(self, partition: str) -> None: + """Set the partition for the batch job + + This sets ``--partition``. + + :param partition: partition name + """ + self.set("partition", str(partition)) + + def set_queue(self, queue: str) -> None: + """alias for set_partition + + Sets the partition for the slurm batch job + + :param queue: the partition to run the batch job on + """ + return self.set_partition(queue) + + def set_cpus_per_task(self, cpus_per_task: int) -> None: + """Set the number of cpus to use per task + + This sets ``--cpus-per-task`` + + :param num_cpus: number of cpus to use per task + """ + self.set("cpus-per-task", str(cpus_per_task)) + + def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + """Specify the hostlist for this job + + This sets ``--nodelist``. 
+ + :param host_list: hosts to launch on + :raises TypeError: if not str or list of str + """ + if isinstance(host_list, str): + host_list = [host_list.strip()] + if not isinstance(host_list, list): + raise TypeError("host_list argument must be a list of strings") + if not all(isinstance(host, str) for host in host_list): + raise TypeError("host_list argument must be list of strings") + self.set("nodelist", ",".join(host_list)) + + def format_batch_args(self) -> t.List[str]: + """Get the formatted batch arguments for a preview + + :return: batch arguments for `sbatch` + """ + opts = [] + # TODO add restricted here + for opt, value in self._batch_args.items(): + # attach "-" prefix if argument is 1 character otherwise "--" + short_arg = len(opt) == 1 + prefix = "-" if short_arg else "--" + + if not value: + opts += [prefix + opt] + else: + if short_arg: + opts += [prefix + opt, str(value)] + else: + opts += ["=".join((prefix + opt, str(value)))] + return opts + + def set(self, key: str, value: str | None) -> None: + """Set an arbitrary scheduler argument + + :param key: The launch argument + :param value: A string representation of the value for the launch + argument (if applicable), otherwise `None` + """ + # Store custom arguments in the launcher_args + self._batch_args[key] = value diff --git a/smartsim/settings/arguments/batch_arguments.py b/smartsim/settings/arguments/batch_arguments.py new file mode 100644 index 0000000000..0fa8d39640 --- /dev/null +++ b/smartsim/settings/arguments/batch_arguments.py @@ -0,0 +1,109 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import copy +import typing as t +from abc import ABC, abstractmethod + +from smartsim.log import get_logger + +from ..._core.utils.helpers import fmt_dict + +logger = get_logger(__name__) + + +class BatchArguments(ABC): + """Abstract base class that defines all generic scheduler + argument methods that are not supported. It is the + responsibility of child classes for each launcher to translate + the input parameter to a properly formatted launcher argument. 
+ """ + + def __init__(self, batch_args: t.Dict[str, str | None] | None) -> None: + self._batch_args = copy.deepcopy(batch_args) or {} + """A dictionary of batch arguments""" + + @abstractmethod + def scheduler_str(self) -> str: + """Get the string representation of the launcher""" + pass + + @abstractmethod + def set_account(self, account: str) -> None: + """Set the account for this batch job + + :param account: account id + """ + pass + + @abstractmethod + def set_queue(self, queue: str) -> None: + """alias for set_partition + + Sets the partition for the slurm batch job + + :param queue: the partition to run the batch job on + """ + pass + + @abstractmethod + def set_walltime(self, walltime: str) -> None: + """Set the walltime of the job + + :param walltime: wall time + """ + pass + + @abstractmethod + def set_nodes(self, num_nodes: int) -> None: + """Set the number of nodes for this batch job + + :param num_nodes: number of nodes + """ + pass + + @abstractmethod + def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + """Specify the hostlist for this job + + :param host_list: hosts to launch on + :raises TypeError: if not str or list of str + """ + pass + + @abstractmethod + def format_batch_args(self) -> t.List[str]: + """Get the formatted batch arguments for a preview + + :return: batch arguments for Sbatch + """ + pass + + def __str__(self) -> str: # pragma: no-cover + string = f"\nScheduler Arguments:\n{fmt_dict(self._batch_args)}" + return string diff --git a/smartsim/settings/arguments/launch/__init__.py b/smartsim/settings/arguments/launch/__init__.py new file mode 100644 index 0000000000..629d45f679 --- /dev/null +++ b/smartsim/settings/arguments/launch/__init__.py @@ -0,0 +1,19 @@ +from .alps import AprunLaunchArguments +from .dragon import DragonLaunchArguments +from .local import LocalLaunchArguments +from .lsf import JsrunLaunchArguments +from .mpi import MpiexecLaunchArguments, MpirunLaunchArguments, OrterunLaunchArguments +from 
.pals import PalsMpiexecLaunchArguments +from .slurm import SlurmLaunchArguments + +__all__ = [ + "AprunLaunchArguments", + "DragonLaunchArguments", + "LocalLaunchArguments", + "JsrunLaunchArguments", + "MpirunLaunchArguments", + "MpiexecLaunchArguments", + "OrterunLaunchArguments", + "PalsMpiexecLaunchArguments", + "SlurmLaunchArguments", +] diff --git a/smartsim/settings/alpsSettings.py b/smartsim/settings/arguments/launch/alps.py similarity index 63% rename from smartsim/settings/alpsSettings.py rename to smartsim/settings/arguments/launch/alps.py index 54b9c7525b..356a443d65 100644 --- a/smartsim/settings/alpsSettings.py +++ b/smartsim/settings/arguments/launch/alps.py @@ -28,55 +28,33 @@ import typing as t -from ..error import SSUnsupportedError -from .base import RunSettings - - -class AprunSettings(RunSettings): - def __init__( - self, - exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - **kwargs: t.Any, - ): - """Settings to run job with ``aprun`` command - - ``AprunSettings`` can be used for the `pbs` launcher. - - :param exe: executable - :param exe_args: executable arguments - :param run_args: arguments for run command - :param env_vars: environment vars to launch job with - """ - super().__init__( - exe, - exe_args, - run_command="aprun", - run_args=run_args, - env_vars=env_vars, - **kwargs, - ) - self.mpmd: t.List[RunSettings] = [] - - def make_mpmd(self, settings: RunSettings) -> None: - """Make job an MPMD job - - This method combines two ``AprunSettings`` - into a single MPMD command joined with ':' - - :param settings: ``AprunSettings`` instance - """ - if self.colocated_db_settings: - raise SSUnsupportedError( - "Colocated models cannot be run as a mpmd workload" - ) - if self.container: - raise SSUnsupportedError( - "Containerized MPMD workloads are not yet supported." 
- ) - self.mpmd.append(settings) +from smartsim._core.arguments.shell import ShellLaunchArguments +from smartsim._core.dispatch import dispatch +from smartsim._core.shell.shell_launcher import ShellLauncher, make_shell_format_fn +from smartsim.log import get_logger + +from ...common import set_check_input +from ...launch_command import LauncherType + +logger = get_logger(__name__) +_as_aprun_command = make_shell_format_fn(run_command="aprun") + + +@dispatch(with_format=_as_aprun_command, to_launcher=ShellLauncher) +class AprunLaunchArguments(ShellLaunchArguments): + def _reserved_launch_args(self) -> set[str]: + """Return reserved launch arguments. + + :returns: The set of reserved launcher arguments + """ + return {"wdir"} + + def launcher_str(self) -> str: + """Get the string representation of the launcher + + :returns: The string representation of the launcher + """ + return LauncherType.Alps.value def set_cpus_per_task(self, cpus_per_task: int) -> None: """Set the number of cpus to use per task @@ -85,7 +63,7 @@ def set_cpus_per_task(self, cpus_per_task: int) -> None: :param cpus_per_task: number of cpus to use per task """ - self.run_args["cpus-per-pe"] = int(cpus_per_task) + self.set("cpus-per-pe", str(cpus_per_task)) def set_tasks(self, tasks: int) -> None: """Set the number of tasks for this job @@ -94,7 +72,7 @@ def set_tasks(self, tasks: int) -> None: :param tasks: number of tasks """ - self.run_args["pes"] = int(tasks) + self.set("pes", str(tasks)) def set_tasks_per_node(self, tasks_per_node: int) -> None: """Set the number of tasks for this job @@ -103,11 +81,13 @@ def set_tasks_per_node(self, tasks_per_node: int) -> None: :param tasks_per_node: number of tasks per node """ - self.run_args["pes-per-node"] = int(tasks_per_node) + self.set("pes-per-node", str(tasks_per_node)) def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: """Specify the hostlist for this job + This sets ``--node-list`` + :param host_list: hosts to launch on :raises 
TypeError: if not str or list of str """ @@ -117,7 +97,7 @@ def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: raise TypeError("host_list argument must be a list of strings") if not all(isinstance(host, str) for host in host_list): raise TypeError("host_list argument must be list of strings") - self.run_args["node-list"] = ",".join(host_list) + self.set("node-list", ",".join(host_list)) def set_hostlist_from_file(self, file_path: str) -> None: """Use the contents of a file to set the node list @@ -126,11 +106,13 @@ def set_hostlist_from_file(self, file_path: str) -> None: :param file_path: Path to the hostlist file """ - self.run_args["node-list-file"] = file_path + self.set("node-list-file", file_path) def set_excluded_hosts(self, host_list: t.Union[str, t.List[str]]) -> None: """Specify a list of hosts to exclude for launching this job + This sets ``--exclude-node-list`` + :param host_list: hosts to exclude :raises TypeError: if not str or list of str """ @@ -140,7 +122,7 @@ def set_excluded_hosts(self, host_list: t.Union[str, t.List[str]]) -> None: raise TypeError("host_list argument must be a list of strings") if not all(isinstance(host, str) for host in host_list): raise TypeError("host_list argument must be list of strings") - self.run_args["exclude-node-list"] = ",".join(host_list) + self.set("exclude-node-list", ",".join(host_list)) def set_cpu_bindings(self, bindings: t.Union[int, t.List[int]]) -> None: """Specifies the cores to which MPI processes are bound @@ -151,7 +133,7 @@ def set_cpu_bindings(self, bindings: t.Union[int, t.List[int]]) -> None: """ if isinstance(bindings, int): bindings = [bindings] - self.run_args["cpu-binding"] = ",".join(str(int(num)) for num in bindings) + self.set("cpu-binding", ",".join(str(num) for num in bindings)) def set_memory_per_node(self, memory_per_node: int) -> None: """Specify the real memory required per node @@ -160,7 +142,16 @@ def set_memory_per_node(self, memory_per_node: int) -> None: :param 
memory_per_node: Per PE memory limit in megabytes """ - self.run_args["memory-per-pe"] = int(memory_per_node) + self.set("memory-per-pe", str(memory_per_node)) + + def set_walltime(self, walltime: str) -> None: + """Set the walltime of the job + + Walltime is given in total number of seconds + + :param walltime: wall time + """ + self.set("cpu-time-limit", str(walltime)) def set_verbose_launch(self, verbose: bool) -> None: """Set the job to run in verbose mode @@ -170,9 +161,9 @@ def set_verbose_launch(self, verbose: bool) -> None: :param verbose: Whether the job should be run verbosely """ if verbose: - self.run_args["debug"] = 7 + self.set("debug", "7") else: - self.run_args.pop("debug", None) + self._launch_args.pop("debug", None) def set_quiet_launch(self, quiet: bool) -> None: """Set the job to run in quiet mode @@ -182,48 +173,56 @@ def set_quiet_launch(self, quiet: bool) -> None: :param quiet: Whether the job should be run quietly """ if quiet: - self.run_args["quiet"] = None + self._launch_args["quiet"] = None else: - self.run_args.pop("quiet", None) + self._launch_args.pop("quiet", None) - def format_run_args(self) -> t.List[str]: - """Return a list of ALPS formatted run arguments - - :return: list of ALPS arguments for these settings - """ - # args launcher uses - args = [] - restricted = ["wdir"] - - for opt, value in self.run_args.items(): - if opt not in restricted: - short_arg = bool(len(str(opt)) == 1) - prefix = "-" if short_arg else "--" - if not value: - args += [prefix + opt] - else: - if short_arg: - args += [prefix + opt, str(value)] - else: - args += ["=".join((prefix + opt, str(value)))] - return args - - def format_env_vars(self) -> t.List[str]: + def format_env_vars(self, env_vars: t.Mapping[str, str | None]) -> list[str]: """Format the environment variables for aprun :return: list of env vars """ formatted = [] - if self.env_vars: - for name, value in self.env_vars.items(): + if env_vars: + for name, value in env_vars.items(): formatted += 
["-e", name + "=" + str(value)] return formatted - def set_walltime(self, walltime: str) -> None: - """Set the walltime of the job + def format_launch_args(self) -> t.List[str]: + """Return a list of ALPS formatted run arguments - Walltime is given in total number of seconds + :return: list of ALPS arguments for these settings + """ + # args launcher uses + args = [] + for opt, value in self._launch_args.items(): + short_arg = len(opt) == 1 + prefix = "-" if short_arg else "--" + if not value: + args += [prefix + opt] + else: + if short_arg: + args += [prefix + opt, str(value)] + else: + args += ["=".join((prefix + opt, str(value)))] + return args - :param walltime: wall time + def set(self, key: str, value: str | None) -> None: + """Set an arbitrary launch argument + + :param key: The launch argument + :param value: A string representation of the value for the launch + argument (if applicable), otherwise `None` """ - self.run_args["cpu-time-limit"] = str(walltime) + set_check_input(key, value) + if key in self._reserved_launch_args(): + logger.warning( + ( + f"Could not set argument '{key}': " + f"it is a reserved argument of '{type(self).__name__}'" + ) + ) + return + if key in self._launch_args and key != self._launch_args[key]: + logger.warning(f"Overwritting argument '{key}' with value '{value}'") + self._launch_args[key] = value diff --git a/smartsim/settings/dragonRunSettings.py b/smartsim/settings/arguments/launch/dragon.py similarity index 70% rename from smartsim/settings/dragonRunSettings.py rename to smartsim/settings/arguments/launch/dragon.py index 15e5855448..d8044267e6 100644 --- a/smartsim/settings/dragonRunSettings.py +++ b/smartsim/settings/arguments/launch/dragon.py @@ -30,58 +30,50 @@ from typing_extensions import override -from ..log import get_logger -from .base import RunSettings - -logger = get_logger(__name__) +from smartsim.log import get_logger +from ...common import set_check_input +from ...launch_command import LauncherType +from 
..launch_arguments import LaunchArguments -class DragonRunSettings(RunSettings): - def __init__( - self, - exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - **kwargs: t.Any, - ) -> None: - """Initialize run parameters for a Dragon process +logger = get_logger(__name__) - ``DragonRunSettings`` should only be used on systems where Dragon - is available and installed in the current environment. - If an allocation is specified, the instance receiving these run - parameters will launch on that allocation. +class DragonLaunchArguments(LaunchArguments): + def launcher_str(self) -> str: + """Get the string representation of the launcher - :param exe: executable to run - :param exe_args: executable arguments, defaults to None - :param env_vars: environment variables for job, defaults to None - :param alloc: allocation ID if running on existing alloc, defaults to None + :returns: The string representation of the launcher """ - super().__init__( - exe, - exe_args, - run_command="", - env_vars=env_vars, - **kwargs, - ) + return LauncherType.Dragon.value - @override def set_nodes(self, nodes: int) -> None: """Set the number of nodes :param nodes: number of nodes to run with """ - self.run_args["nodes"] = nodes + self.set("nodes", str(nodes)) - @override def set_tasks_per_node(self, tasks_per_node: int) -> None: """Set the number of tasks for this job :param tasks_per_node: number of tasks per node """ - self.run_args["tasks-per-node"] = tasks_per_node + self.set("tasks_per_node", str(tasks_per_node)) @override + def set(self, key: str, value: str | None) -> None: + """Set an arbitrary launch argument + + :param key: The launch argument + :param value: A string representation of the value for the launch + argument (if applicable), otherwise `None` + """ + set_check_input(key, value) + if key in self._launch_args and key != self._launch_args[key]: + logger.warning(f"Overwritting argument '{key}' with 
value '{value}'") + self._launch_args[key] = value + def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: """Specify the node feature for this job @@ -92,10 +84,8 @@ def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: feature_list = feature_list.strip().split() elif not all(isinstance(feature, str) for feature in feature_list): raise TypeError("feature_list must be string or list of strings") + self.set("node-feature", ",".join(feature_list)) - self.run_args["node-feature"] = ",".join(feature_list) - - @override def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: """Specify the hostlist for this job @@ -112,19 +102,18 @@ def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: cleaned_list = [host.strip() for host in host_list if host and host.strip()] if not len(cleaned_list) == len(host_list): raise ValueError(f"invalid names found in hostlist: {host_list}") - - self.run_args["host-list"] = ",".join(cleaned_list) + self.set("host-list", ",".join(cleaned_list)) def set_cpu_affinity(self, devices: t.List[int]) -> None: """Set the CPU affinity for this job :param devices: list of CPU indices to execute on """ - self.run_args["cpu-affinity"] = ",".join(str(device) for device in devices) + self.set("cpu-affinity", ",".join(str(device) for device in devices)) def set_gpu_affinity(self, devices: t.List[int]) -> None: """Set the GPU affinity for this job :param devices: list of GPU indices to execute on. """ - self.run_args["gpu-affinity"] = ",".join(str(device) for device in devices) + self.set("gpu-affinity", ",".join(str(device) for device in devices)) diff --git a/smartsim/settings/arguments/launch/local.py b/smartsim/settings/arguments/launch/local.py new file mode 100644 index 0000000000..2c589cb48d --- /dev/null +++ b/smartsim/settings/arguments/launch/local.py @@ -0,0 +1,87 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +import typing as t + +from smartsim._core.arguments.shell import ShellLaunchArguments +from smartsim._core.dispatch import dispatch +from smartsim._core.shell.shell_launcher import ShellLauncher, make_shell_format_fn +from smartsim.log import get_logger + +from ...common import StringArgument, set_check_input +from ...launch_command import LauncherType + +logger = get_logger(__name__) +_as_local_command = make_shell_format_fn(run_command=None) + + +@dispatch(with_format=_as_local_command, to_launcher=ShellLauncher) +class LocalLaunchArguments(ShellLaunchArguments): + def launcher_str(self) -> str: + """Get the string representation of the launcher + + :returns: The string representation of the launcher + """ + return LauncherType.Local.value + + def format_env_vars(self, env_vars: t.Mapping[str, str | None]) -> list[str]: + """Build bash compatible sequence of strings to specify an environment + + :param env_vars: An environment mapping + :returns: the formatted string of environment variables + """ + formatted = [] + for key, val in env_vars.items(): + if val is None: + formatted.append(f"{key}=") + else: + formatted.append(f"{key}={val}") + return formatted + + def format_launch_args(self) -> t.List[str]: + """Build launcher argument string + + :returns: formatted list of launcher arguments + """ + formatted = [] + for arg, value in self._launch_args.items(): + formatted.append(arg) + formatted.append(str(value)) + return formatted + + def set(self, key: str, value: str | None) -> None: + """Set an arbitrary launch argument + + :param key: The launch argument + :param value: A string representation of the value for the launch + argument (if applicable), otherwise `None` + """ + set_check_input(key, value) + if key in self._launch_args and key != self._launch_args[key]: + logger.warning(f"Overwritting argument '{key}' with value '{value}'") + self._launch_args[key] = value diff --git 
a/smartsim/settings/arguments/launch/lsf.py b/smartsim/settings/arguments/launch/lsf.py new file mode 100644 index 0000000000..ed24271985 --- /dev/null +++ b/smartsim/settings/arguments/launch/lsf.py @@ -0,0 +1,152 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +import pathlib +import subprocess +import typing as t + +from smartsim._core.arguments.shell import ShellLaunchArguments +from smartsim._core.dispatch import EnvironMappingType, dispatch +from smartsim._core.shell.shell_launcher import ShellLauncher, ShellLauncherCommand +from smartsim.log import get_logger + +from ...common import set_check_input +from ...launch_command import LauncherType + +logger = get_logger(__name__) + + +def _as_jsrun_command( + args: ShellLaunchArguments, + exe: t.Sequence[str], + path: pathlib.Path, + env: EnvironMappingType, + stdout_path: pathlib.Path, + stderr_path: pathlib.Path, +) -> ShellLauncherCommand: + command_tuple = ( + "jsrun", + *(args.format_launch_args() or ()), + f"--stdio_stdout={stdout_path}", + f"--stdio_stderr={stderr_path}", + "--", + *exe, + ) + return ShellLauncherCommand( + env, path, subprocess.DEVNULL, subprocess.DEVNULL, command_tuple + ) + + +@dispatch(with_format=_as_jsrun_command, to_launcher=ShellLauncher) +class JsrunLaunchArguments(ShellLaunchArguments): + def launcher_str(self) -> str: + """Get the string representation of the launcher + + :returns: The string representation of the launcher + """ + return LauncherType.Lsf.value + + def _reserved_launch_args(self) -> set[str]: + """Return reserved launch arguments. + + :returns: The set of reserved launcher arguments + """ + return {"chdir", "h", "stdio_stdout", "o", "stdio_stderr", "k"} + + def set_tasks(self, tasks: int) -> None: + """Set the number of tasks for this job + + This sets ``--np`` + + :param tasks: number of tasks + """ + self.set("np", str(tasks)) + + def set_binding(self, binding: str) -> None: + """Set binding + + This sets ``--bind`` + + :param binding: Binding, e.g. `packed:21` + """ + self.set("bind", binding) + + def format_env_vars(self, env_vars: t.Mapping[str, str | None]) -> list[str]: + """Format environment variables. Each variable needs + to be passed with ``--env``. 
If a variable is set to ``None``, + its value is propagated from the current environment. + + :returns: formatted list of strings to export variables + """ + format_str = [] + for k, v in env_vars.items(): + if v: + format_str += ["-E", f"{k}={v}"] + else: + format_str += ["-E", f"{k}"] + return format_str + + def format_launch_args(self) -> t.List[str]: + """Return a list of LSF formatted run arguments + + :return: list of LSF arguments for these settings + """ + # args launcher uses + args = [] + + for opt, value in self._launch_args.items(): + short_arg = bool(len(str(opt)) == 1) + prefix = "-" if short_arg else "--" + if value is None: + args += [prefix + opt] + else: + if short_arg: + args += [prefix + opt, str(value)] + else: + args += ["=".join((prefix + opt, str(value)))] + return args + + def set(self, key: str, value: str | None) -> None: + """Set an arbitrary launch argument + + :param key: The launch argument + :param value: A string representation of the value for the launch + argument (if applicable), otherwise `None` + """ + set_check_input(key, value) + if key in self._reserved_launch_args(): + logger.warning( + ( + f"Could not set argument '{key}': " + f"it is a reserved argument of '{type(self).__name__}'" + ) + ) + return + if key in self._launch_args and value != self._launch_args[key]: + logger.warning(f"Overwriting argument '{key}' with value '{value}'") + self._launch_args[key] = value diff --git a/smartsim/settings/arguments/launch/mpi.py b/smartsim/settings/arguments/launch/mpi.py new file mode 100644 index 0000000000..ce8c43aa5c --- /dev/null +++ b/smartsim/settings/arguments/launch/mpi.py @@ -0,0 +1,255 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1.
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import typing as t + +from smartsim._core.arguments.shell import ShellLaunchArguments +from smartsim._core.dispatch import dispatch +from smartsim._core.shell.shell_launcher import ShellLauncher, make_shell_format_fn +from smartsim.log import get_logger + +from ...common import set_check_input +from ...launch_command import LauncherType + +logger = get_logger(__name__) +_as_mpirun_command = make_shell_format_fn("mpirun") +_as_mpiexec_command = make_shell_format_fn("mpiexec") +_as_orterun_command = make_shell_format_fn("orterun") + + +class _BaseMPILaunchArguments(ShellLaunchArguments): + def _reserved_launch_args(self) -> set[str]: + """Return reserved launch arguments. 
+ + :returns: The set of reserved launcher arguments + """ + return {"wd", "wdir"} + + def set_task_map(self, task_mapping: str) -> None: + """Set ``mpirun`` task mapping + + this sets ``--map-by `` + + For examples, see the man page for ``mpirun`` + + :param task_mapping: task mapping + """ + self.set("map-by", task_mapping) + + def set_cpus_per_task(self, cpus_per_task: int) -> None: + """Set the number of tasks for this job + + This sets ``--cpus-per-proc`` for MPI compliant implementations + + note: this option has been deprecated in openMPI 4.0+ + and will soon be replaced. + + :param cpus_per_task: number of tasks + """ + self.set("cpus-per-proc", str(cpus_per_task)) + + def set_executable_broadcast(self, dest_path: str) -> None: + """Copy the specified executable(s) to remote machines + + This sets ``--preload-binary`` + + :param dest_path: Destination path (Ignored) + """ + if dest_path is not None and isinstance(dest_path, str): + logger.warning( + ( + f"{type(self)} cannot set a destination path during broadcast. 
" + "Using session directory instead" + ) + ) + self.set("preload-binary", dest_path) + + def set_cpu_binding_type(self, bind_type: str) -> None: + """Specifies the cores to which MPI processes are bound + + This sets ``--bind-to`` for MPI compliant implementations + + :param bind_type: binding type + """ + self.set("bind-to", bind_type) + + def set_tasks_per_node(self, tasks_per_node: int) -> None: + """Set the number of tasks per node + + :param tasks_per_node: number of tasks to launch per node + """ + self.set("npernode", str(tasks_per_node)) + + def set_tasks(self, tasks: int) -> None: + """Set the number of tasks for this job + + This sets ``-n`` for MPI compliant implementations + + :param tasks: number of tasks + """ + self.set("n", str(tasks)) + + def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + """Set the hostlist for the ``mpirun`` command + + This sets ``--host`` + + :param host_list: list of host names + :raises TypeError: if not str or list of str + """ + if isinstance(host_list, str): + host_list = [host_list.strip()] + if not isinstance(host_list, list): + raise TypeError("host_list argument must be a list of strings") + if not all(isinstance(host, str) for host in host_list): + raise TypeError("host_list argument must be list of strings") + self.set("host", ",".join(host_list)) + + def set_hostlist_from_file(self, file_path: str) -> None: + """Use the contents of a file to set the hostlist + + This sets ``--hostfile`` + + :param file_path: Path to the hostlist file + """ + self.set("hostfile", file_path) + + def set_verbose_launch(self, verbose: bool) -> None: + """Set the job to run in verbose mode + + This sets ``--verbose`` + + :param verbose: Whether the job should be run verbosely + """ + if verbose: + self.set("verbose", None) + else: + self._launch_args.pop("verbose", None) + + def set_walltime(self, walltime: str) -> None: + """Set the maximum number of seconds that a job will run + + This sets ``--timeout`` + + 
:param walltime: number like string of seconds that a job will run in secs + """ + self.set("timeout", walltime) + + def set_quiet_launch(self, quiet: bool) -> None: + """Set the job to run in quiet mode + + This sets ``--quiet`` + + :param quiet: Whether the job should be run quietly + """ + if quiet: + self.set("quiet", None) + else: + self._launch_args.pop("quiet", None) + + def format_env_vars(self, env_vars: t.Mapping[str, str | None]) -> list[str]: + """Format the environment variables for mpirun + + :return: list of env vars + """ + formatted = [] + env_string = "-x" + + if env_vars: + for name, value in env_vars.items(): + if value: + formatted += [env_string, "=".join((name, str(value)))] + else: + formatted += [env_string, name] + return formatted + + def format_launch_args(self) -> t.List[str]: + """Return a list of MPI-standard formatted run arguments + + :return: list of MPI-standard arguments for these settings + """ + # args launcher uses + args = [] + + for opt, value in self._launch_args.items(): + prefix = "--" + if not value: + args += [prefix + opt] + else: + args += [prefix + opt, str(value)] + return args + + def set(self, key: str, value: str | None) -> None: + """Set an arbitrary launch argument + + :param key: The launch argument + :param value: A string representation of the value for the launch + argument (if applicable), otherwise `None` + """ + set_check_input(key, value) + if key in self._reserved_launch_args(): + logger.warning( + ( + f"Could not set argument '{key}': " + f"it is a reserved argument of '{type(self).__name__}'" + ) + ) + return + if key in self._launch_args and value != self._launch_args[key]: + logger.warning(f"Overwriting argument '{key}' with value '{value}'") + self._launch_args[key] = value + + +@dispatch(with_format=_as_mpirun_command, to_launcher=ShellLauncher) +class MpirunLaunchArguments(_BaseMPILaunchArguments): + def launcher_str(self) -> str: + """Get the string representation of the launcher + + :returns:
The string representation of the launcher + """ + return LauncherType.Mpirun.value + + +@dispatch(with_format=_as_mpiexec_command, to_launcher=ShellLauncher) +class MpiexecLaunchArguments(_BaseMPILaunchArguments): + def launcher_str(self) -> str: + """Get the string representation of the launcher + + :returns: The string representation of the launcher + """ + return LauncherType.Mpiexec.value + + +@dispatch(with_format=_as_orterun_command, to_launcher=ShellLauncher) +class OrterunLaunchArguments(_BaseMPILaunchArguments): + def launcher_str(self) -> str: + """Get the string representation of the launcher + + :returns: The string representation of the launcher + """ + return LauncherType.Orterun.value diff --git a/smartsim/settings/arguments/launch/pals.py b/smartsim/settings/arguments/launch/pals.py new file mode 100644 index 0000000000..d48dc799b9 --- /dev/null +++ b/smartsim/settings/arguments/launch/pals.py @@ -0,0 +1,162 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import typing as t + +from smartsim._core.arguments.shell import ShellLaunchArguments +from smartsim._core.dispatch import dispatch +from smartsim._core.shell.shell_launcher import ShellLauncher, make_shell_format_fn +from smartsim.log import get_logger + +from ...common import set_check_input +from ...launch_command import LauncherType + +logger = get_logger(__name__) +_as_pals_command = make_shell_format_fn(run_command="mpiexec") + + +@dispatch(with_format=_as_pals_command, to_launcher=ShellLauncher) +class PalsMpiexecLaunchArguments(ShellLaunchArguments): + def launcher_str(self) -> str: + """Get the string representation of the launcher + + :returns: The string representation of the launcher + """ + return LauncherType.Pals.value + + def _reserved_launch_args(self) -> set[str]: + """Return reserved launch arguments. 
+ + :returns: The set of reserved launcher arguments + """ + return {"wdir", "wd"} + + def set_cpu_binding_type(self, bind_type: str) -> None: + """Specifies the cores to which MPI processes are bound + + This sets ``--bind-to`` for MPI compliant implementations + + :param bind_type: binding type + """ + self.set("bind-to", bind_type) + + def set_tasks(self, tasks: int) -> None: + """Set the number of tasks + + :param tasks: number of total tasks to launch + """ + self.set("np", str(tasks)) + + def set_executable_broadcast(self, dest_path: str) -> None: + """Copy the specified executable(s) to remote machines + + This sets ``--transfer`` + + :param dest_path: Destination path (Ignored) + """ + self.set("transfer", dest_path) + + def set_tasks_per_node(self, tasks_per_node: int) -> None: + """Set the number of tasks per node + + This sets ``--ppn`` + + :param tasks_per_node: number of tasks to launch per node + """ + self.set("ppn", str(tasks_per_node)) + + def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + """Set the hostlist for the PALS ``mpiexec`` command + + This sets ``hosts`` + + :param host_list: list of host names + :raises TypeError: if not str or list of str + """ + if isinstance(host_list, str): + host_list = [host_list.strip()] + if not isinstance(host_list, list): + raise TypeError("host_list argument must be a list of strings") + if not all(isinstance(host, str) for host in host_list): + raise TypeError("host_list argument must be list of strings") + self.set("hosts", ",".join(host_list)) + + def format_env_vars(self, env_vars: t.Mapping[str, str | None]) -> list[str]: + """Format the environment variables for mpirun + + :return: list of env vars + """ + formatted = [] + + export_vars = [] + if env_vars: + for name, value in env_vars.items(): + if value: + formatted += ["--env", "=".join((name, str(value)))] + else: + export_vars.append(name) + + if export_vars: + formatted += ["--envlist", ",".join(export_vars)] + + return 
formatted + + def format_launch_args(self) -> t.List[str]: + """Return a list of MPI-standard formatted launcher arguments + + :return: list of MPI-standard arguments for these settings + """ + # args launcher uses + args = [] + + for opt, value in self._launch_args.items(): + prefix = "--" + if not value: + args += [prefix + opt] + else: + args += [prefix + opt, str(value)] + + return args + + def set(self, key: str, value: str | None) -> None: + """Set an arbitrary launch argument + + :param key: The launch argument + :param value: A string representation of the value for the launch + argument (if applicable), otherwise `None` + """ + set_check_input(key, value) + if key in self._reserved_launch_args(): + logger.warning( + f"Could not set argument '{key}': " + f"it is a reserved argument of '{type(self).__name__}'" + ) + return + if key in self._launch_args and value != self._launch_args[key]: + logger.warning(f"Overwriting argument '{key}' with value '{value}'") + self._launch_args[key] = value diff --git a/smartsim/settings/arguments/launch/slurm.py b/smartsim/settings/arguments/launch/slurm.py new file mode 100644 index 0000000000..c5dceff628 --- /dev/null +++ b/smartsim/settings/arguments/launch/slurm.py @@ -0,0 +1,353 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import os +import pathlib +import re +import subprocess +import typing as t + +from smartsim._core.arguments.shell import ShellLaunchArguments +from smartsim._core.dispatch import EnvironMappingType, dispatch +from smartsim._core.shell.shell_launcher import ShellLauncher, ShellLauncherCommand +from smartsim.log import get_logger + +from ...common import set_check_input +from ...launch_command import LauncherType + +logger = get_logger(__name__) + + +def _as_srun_command( + args: ShellLaunchArguments, + exe: t.Sequence[str], + path: pathlib.Path, + env: EnvironMappingType, + stdout_path: pathlib.Path, + stderr_path: pathlib.Path, +) -> ShellLauncherCommand: + command_tuple = ( + "srun", + *(args.format_launch_args() or ()), + f"--output={stdout_path}", + f"--error={stderr_path}", + "--", + *exe, + ) + return ShellLauncherCommand( + env, path, subprocess.DEVNULL, subprocess.DEVNULL, command_tuple + ) + + +@dispatch(with_format=_as_srun_command, to_launcher=ShellLauncher) +class SlurmLaunchArguments(ShellLaunchArguments): + def launcher_str(self) -> str: + """Get the string representation of the launcher + + :returns: The string representation of the 
launcher + """ + return LauncherType.Slurm.value + + def _reserved_launch_args(self) -> set[str]: + """Return reserved launch arguments. + + :returns: The set of reserved launcher arguments + """ + return {"chdir", "D"} + + def set_nodes(self, nodes: int) -> None: + """Set the number of nodes + + Effectively this is setting: ``srun --nodes `` + + :param nodes: nodes to launch on + :return: launcher argument + """ + self.set("nodes", str(nodes)) + + def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + """Specify the hostlist for this job + + This sets ``--nodelist`` + + :param host_list: hosts to launch on + :raises TypeError: if not str or list of str + """ + if isinstance(host_list, str): + host_list = [host_list.strip()] + elif not isinstance(host_list, list): + raise TypeError("host_list argument must be a string or list of strings") + elif not all(isinstance(host, str) for host in host_list): + raise TypeError("host_list argument must be list of strings") + self.set("nodelist", ",".join(host_list)) + + def set_hostlist_from_file(self, file_path: str) -> None: + """Use the contents of a file to set the node list + + This sets ``--nodefile`` + + :param file_path: Path to the nodelist file + """ + self.set("nodefile", file_path) + + def set_excluded_hosts(self, host_list: t.Union[str, t.List[str]]) -> None: + """Specify a list of hosts to exclude for launching this job + + :param host_list: hosts to exclude + :raises TypeError: if not str or list of str + """ + if isinstance(host_list, str): + host_list = [host_list.strip()] + if not isinstance(host_list, list): + raise TypeError("host_list argument must be a list of strings") + if not all(isinstance(host, str) for host in host_list): + raise TypeError("host_list argument must be list of strings") + self.set("exclude", ",".join(host_list)) + + def set_cpus_per_task(self, cpus_per_task: int) -> None: + """Set the number of cpus to use per task + + This sets ``--cpus-per-task`` + + :param 
num_cpus: number of cpus to use per task + """ + self.set("cpus-per-task", str(cpus_per_task)) + + def set_tasks(self, tasks: int) -> None: + """Set the number of tasks for this job + + This sets ``--ntasks`` + + :param tasks: number of tasks + """ + self.set("ntasks", str(tasks)) + + def set_tasks_per_node(self, tasks_per_node: int) -> None: + """Set the number of tasks for this job + + This sets ``--ntasks-per-node`` + + :param tasks_per_node: number of tasks per node + """ + self.set("ntasks-per-node", str(tasks_per_node)) + + def set_cpu_bindings(self, bindings: t.Union[int, t.List[int]]) -> None: + """Bind by setting CPU masks on tasks + + This sets ``--cpu-bind`` using the ``map_cpu:`` option + + :param bindings: List specifing the cores to which MPI processes are bound + """ + if isinstance(bindings, int): + bindings = [bindings] + self.set("cpu_bind", "map_cpu:" + ",".join(str(num) for num in bindings)) + + def set_memory_per_node(self, memory_per_node: int) -> None: + """Specify the real memory required per node + + This sets ``--mem`` in megabytes + + :param memory_per_node: Amount of memory per node in megabytes + """ + self.set("mem", f"{memory_per_node}M") + + def set_executable_broadcast(self, dest_path: str) -> None: + """Copy executable file to allocated compute nodes + + This sets ``--bcast`` + + :param dest_path: Path to copy an executable file + """ + self.set("bcast", dest_path) + + def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: + """Specify the node feature for this job + + This sets ``-C`` + + :param feature_list: node feature to launch on + :raises TypeError: if not str or list of str + """ + if isinstance(feature_list, str): + feature_list = [feature_list.strip()] + elif not all(isinstance(feature, str) for feature in feature_list): + raise TypeError("node_feature argument must be string or list of strings") + self.set("C", ",".join(feature_list)) + + def set_walltime(self, walltime: str) -> None: + """Set the 
walltime of the job + + format = "HH:MM:SS" + + :param walltime: wall time + """ + pattern = r"^\d{2}:\d{2}:\d{2}$" + if walltime and re.match(pattern, walltime): + self.set("time", str(walltime)) + else: + raise ValueError("Invalid walltime format. Please use 'HH:MM:SS' format.") + + def set_het_group(self, het_group: t.Iterable[int]) -> None: + """Set the heterogeneous group for this job + + this sets `--het-group` + + :param het_group: list of heterogeneous groups + """ + het_size_env = os.getenv("SLURM_HET_SIZE") + if het_size_env is None: + msg = "Requested to set het group, but the allocation is not a het job" + raise ValueError(msg) + het_size = int(het_size_env) + if any(group >= het_size for group in het_group): + msg = ( + f"Het group {max(het_group)} requested, " + f"but max het group in allocation is {het_size-1}" + ) + raise ValueError(msg) + self.set("het-group", ",".join(str(group) for group in het_group)) + + def set_verbose_launch(self, verbose: bool) -> None: + """Set the job to run in verbose mode + + This sets ``--verbose`` + + :param verbose: Whether the job should be run verbosely + """ + if verbose: + self.set("verbose", None) + else: + self._launch_args.pop("verbose", None) + + def set_quiet_launch(self, quiet: bool) -> None: + """Set the job to run in quiet mode + + This sets ``--quiet`` + + :param quiet: Whether the job should be run quietly + """ + if quiet: + self.set("quiet", None) + else: + self._launch_args.pop("quiet", None) + + def format_launch_args(self) -> t.List[str]: + """Return a list of slurm formatted launch arguments + + :return: list of slurm arguments for these settings + """ + formatted = [] + for key, value in self._launch_args.items(): + short_arg = bool(len(str(key)) == 1) + prefix = "-" if short_arg else "--" + if not value: + formatted += [prefix + key] + else: + if short_arg: + formatted += [prefix + key, str(value)] + else: + formatted += ["=".join((prefix + key, str(value)))] + return formatted + + def 
format_env_vars(self, env_vars: t.Mapping[str, str | None]) -> list[str]: + """Build bash compatible environment variable string for Slurm + + :returns: the formatted string of environment variables + """ + self._check_env_vars(env_vars) + return [f"{k}={v}" for k, v in env_vars.items() if "," not in str(v)] + + def format_comma_sep_env_vars( + self, env_vars: t.Dict[str, t.Optional[str]] + ) -> t.Union[t.Tuple[str, t.List[str]], None]: + """Build environment variable string for Slurm + + Slurm takes exports in comma separated lists + the list starts with all as to not disturb the rest of the environment + for more information on this, see the slurm documentation for srun + + :param env_vars: An environment mapping + :returns: the formatted string of environment variables + """ + self._check_env_vars(env_vars) + exportable_env, compound_env, key_only = [], [], [] + + for k, v in env_vars.items(): + kvp = f"{k}={v}" + + if "," in str(v): + key_only.append(k) + compound_env.append(kvp) + else: + exportable_env.append(kvp) + + # Append keys to exportable KVPs, e.g. `--export x1=v1,KO1,KO2` + fmt_exported_env = ",".join(v for v in exportable_env + key_only) + + return fmt_exported_env, compound_env + + def _check_env_vars(self, env_vars: t.Mapping[str, str | None]) -> None: + """Warn a user trying to set a variable which is set in the environment + + Given Slurm's env var precedence, trying to export a variable which is already + present in the environment will not work. + """ + for k, v in env_vars.items(): + if "," not in str(v): + # If a variable is defined, it will take precedence over --export + # we warn the user + preexisting_var = os.environ.get(k, None) + if preexisting_var is not None and preexisting_var != v: + msg = ( + f"Variable {k} is set to {preexisting_var} in current " + "environment. If the job is running in an interactive " + f"allocation, the value {v} will not be set. 
Please " + "consider removing the variable from the environment " + "and re-run the experiment." + ) + logger.warning(msg) + + def set(self, key: str, value: str | None) -> None: + """Set an arbitrary launch argument + + :param key: The launch argument + :param value: A string representation of the value for the launch + argument (if applicable), otherwise `None` + """ + set_check_input(key, value) + if key in self._reserved_launch_args(): + logger.warning( + ( + f"Could not set argument '{key}': " + f"it is a reserved argument of '{type(self).__name__}'" + ) + ) + return + if key in self._launch_args and value != self._launch_args[key]: + logger.warning(f"Overwriting argument '{key}' with value '{value}'") + self._launch_args[key] = value diff --git a/smartsim/settings/arguments/launch_arguments.py b/smartsim/settings/arguments/launch_arguments.py new file mode 100644 index 0000000000..6ec741d914 --- /dev/null +++ b/smartsim/settings/arguments/launch_arguments.py @@ -0,0 +1,75 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import copy +import textwrap +import typing as t +from abc import ABC, abstractmethod + +from smartsim.log import get_logger + +from ..._core.utils.helpers import fmt_dict + +logger = get_logger(__name__) + + +class LaunchArguments(ABC): + """Abstract base class for launcher arguments. It is the responsibility of + child classes for each launcher to add methods to set input parameters and + to maintain valid state between parameters set by a user. + """ + + def __init__(self, launch_args: t.Dict[str, str | None] | None) -> None: + """Initialize a new `LaunchArguments` instance. + + :param launch_args: A mapping of arguments to (optional) values + """ + self._launch_args = copy.deepcopy(launch_args) or {} + """A dictionary of launch arguments""" + + @abstractmethod + def launcher_str(self) -> str: + """Get the string representation of the launcher""" + + @abstractmethod + def set(self, arg: str, val: str | None) -> None: + """Set a launch argument + + :param arg: The argument name to set + :param val: The value to set the argument to as a `str` (if + applicable). 
Otherwise `None` + """ + + def __str__(self) -> str: # pragma: no-cover + return textwrap.dedent(f"""\ + Launch Arguments: + Launcher: {self.launcher_str()} + Name: {type(self).__name__} + Arguments: + {fmt_dict(self._launch_args)} + """) diff --git a/smartsim/settings/base.py b/smartsim/settings/base.py deleted file mode 100644 index da3edb4917..0000000000 --- a/smartsim/settings/base.py +++ /dev/null @@ -1,689 +0,0 @@ -# BSD 2-Clause License # -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from __future__ import annotations - -import copy -import typing as t - -from smartsim.settings.containers import Container - -from .._core.utils.helpers import expand_exe_path, fmt_dict, is_valid_cmd -from ..entity.dbobject import DBModel, DBScript -from ..log import get_logger - -logger = get_logger(__name__) - -# fmt: off -class SettingsBase: - ... -# fmt: on - - -# pylint: disable=too-many-public-methods -class RunSettings(SettingsBase): - # pylint: disable=unused-argument - - def __init__( - self, - exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_command: str = "", - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - container: t.Optional[Container] = None, - **_kwargs: t.Any, - ) -> None: - """Run parameters for a ``Model`` - - The base ``RunSettings`` class should only be used with the `local` - launcher on single node, workstations, or laptops. - - If no ``run_command`` is specified, the executable will be launched - locally. - - ``run_args`` passed as a dict will be interpreted literally for - local ``RunSettings`` and added directly to the ``run_command`` - e.g. run_args = {"-np": 2} will be "-np 2" - - Example initialization - - .. highlight:: python - .. code-block:: python - - rs = RunSettings("echo", "hello", "mpirun", run_args={"-np": "2"}) - - :param exe: executable to run - :param exe_args: executable arguments - :param run_command: launch binary (e.g. "srun") - :param run_args: arguments for run command (e.g. `-np` for `mpiexec`) - :param env_vars: environment vars to launch job with - :param container: container type for workload (e.g. 
"singularity") - """ - # Do not expand executable if running within a container - self.exe = [exe] if container else [expand_exe_path(exe)] - self.exe_args = exe_args or [] - self.run_args = run_args or {} - self.env_vars = env_vars or {} - self.container = container - self._run_command = run_command - self.in_batch = False - self.colocated_db_settings: t.Optional[ - t.Dict[ - str, - t.Union[ - bool, - int, - str, - None, - t.List[str], - t.Iterable[t.Union[int, t.Iterable[int]]], - t.List[DBModel], - t.List[DBScript], - t.Dict[str, t.Union[int, None]], - t.Dict[str, str], - ], - ] - ] = None - - @property - def exe_args(self) -> t.Union[str, t.List[str]]: - """Return an immutable list of attached executable arguments. - - :returns: attached executable arguments - """ - return self._exe_args - - @exe_args.setter - def exe_args(self, value: t.Union[str, t.List[str], None]) -> None: - """Set the executable arguments. - - :param value: executable arguments - """ - self._exe_args = self._build_exe_args(value) - - @property - def run_args(self) -> t.Dict[str, t.Union[int, str, float, None]]: - """Return an immutable list of attached run arguments. - - :returns: attached run arguments - """ - return self._run_args - - @run_args.setter - def run_args(self, value: t.Dict[str, t.Union[int, str, float, None]]) -> None: - """Set the run arguments. - - :param value: run arguments - """ - self._run_args = copy.deepcopy(value) - - @property - def env_vars(self) -> t.Dict[str, t.Optional[str]]: - """Return an immutable list of attached environment variables. - - :returns: attached environment variables - """ - return self._env_vars - - @env_vars.setter - def env_vars(self, value: t.Dict[str, t.Optional[str]]) -> None: - """Set the environment variables. - - :param value: environment variables - """ - self._env_vars = copy.deepcopy(value) - - # To be overwritten by subclasses. 
Set of reserved args a user cannot change - reserved_run_args = set() # type: set[str] - - def set_nodes(self, nodes: int) -> None: - """Set the number of nodes - - :param nodes: number of nodes to run with - """ - logger.warning( - ( - "Node specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_tasks(self, tasks: int) -> None: - """Set the number of tasks to launch - - :param tasks: number of tasks to launch - """ - logger.warning( - ( - "Task specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_tasks_per_node(self, tasks_per_node: int) -> None: - """Set the number of tasks per node - - :param tasks_per_node: number of tasks to launch per node - """ - logger.warning( - ( - "Task per node specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_task_map(self, task_mapping: str) -> None: - """Set a task mapping - - :param task_mapping: task mapping - """ - logger.warning( - ( - "Task mapping specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_cpus_per_task(self, cpus_per_task: int) -> None: - """Set the number of cpus per task - - :param cpus_per_task: number of cpus per task - """ - logger.warning( - ( - "CPU per node specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: - """Specify the hostlist for this job - - :param host_list: hosts to launch on - """ - logger.warning( - ( - "Hostlist specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_hostlist_from_file(self, file_path: str) -> None: - """Use the contents of a file to specify the hostlist for this job - - :param file_path: Path to the hostlist file - """ - logger.warning( - ( - "Hostlist from file specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def 
set_excluded_hosts(self, host_list: t.Union[str, t.List[str]]) -> None: - """Specify a list of hosts to exclude for launching this job - - :param host_list: hosts to exclude - """ - logger.warning( - ( - "Excluded host list specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_cpu_bindings(self, bindings: t.Union[int, t.List[int]]) -> None: - """Set the cores to which MPI processes are bound - - :param bindings: List specifing the cores to which MPI processes are bound - """ - logger.warning( - ( - "CPU binding specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_memory_per_node(self, memory_per_node: int) -> None: - """Set the amount of memory required per node in megabytes - - :param memory_per_node: Number of megabytes per node - """ - logger.warning( - ( - "Memory per node specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_verbose_launch(self, verbose: bool) -> None: - """Set the job to run in verbose mode - - :param verbose: Whether the job should be run verbosely - """ - logger.warning( - ( - "Verbose specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_quiet_launch(self, quiet: bool) -> None: - """Set the job to run in quiet mode - - :param quiet: Whether the job should be run quietly - """ - logger.warning( - ( - "Quiet specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_broadcast(self, dest_path: t.Optional[str] = None) -> None: - """Copy executable file to allocated compute nodes - - :param dest_path: Path to copy an executable file - """ - logger.warning( - ( - "Broadcast specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_time(self, hours: int = 0, minutes: int = 0, seconds: int = 0) -> None: - """Automatically format and set wall time - - :param hours: number of hours to run job - :param 
minutes: number of minutes to run job - :param seconds: number of seconds to run job - """ - return self.set_walltime( - self._fmt_walltime(int(hours), int(minutes), int(seconds)) - ) - - def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: - """Specify the node feature for this job - - :param feature_list: node feature to launch on - """ - logger.warning( - ( - "Feature specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - @staticmethod - def _fmt_walltime(hours: int, minutes: int, seconds: int) -> str: - """Convert hours, minutes, and seconds into valid walltime format - - By defualt the formatted wall time is the total number of seconds. - - :param hours: number of hours to run job - :param minutes: number of minutes to run job - :param seconds: number of seconds to run job - :returns: Formatted walltime - """ - time_ = hours * 3600 - time_ += minutes * 60 - time_ += seconds - return str(time_) - - def set_walltime(self, walltime: str) -> None: - """Set the formatted walltime - - :param walltime: Time in format required by launcher`` - """ - logger.warning( - ( - "Walltime specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_binding(self, binding: str) -> None: - """Set binding - - :param binding: Binding - """ - logger.warning( - ( - "binding specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def set_mpmd_preamble(self, preamble_lines: t.List[str]) -> None: - """Set preamble to a file to make a job MPMD - - :param preamble_lines: lines to put at the beginning of a file. 
- """ - logger.warning( - ( - "MPMD preamble specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - def make_mpmd(self, settings: RunSettings) -> None: - """Make job an MPMD job - - :param settings: ``RunSettings`` instance - """ - logger.warning( - ( - "Make MPMD specification not implemented for this " - f"RunSettings type: {type(self)}" - ) - ) - - @property - def run_command(self) -> t.Optional[str]: - """Return the launch binary used to launch the executable - - Attempt to expand the path to the executable if possible - - :returns: launch binary e.g. mpiexec - """ - cmd = self._run_command - - if cmd: - if is_valid_cmd(cmd): - # command is valid and will be expanded - return expand_exe_path(cmd) - # command is not valid, so return it as is - # it may be on the compute nodes but not local machine - return cmd - # run without run command - return None - - def update_env(self, env_vars: t.Dict[str, t.Union[str, int, float, bool]]) -> None: - """Update the job environment variables - - To fully inherit the current user environment, add the - workload-manager-specific flag to the launch command through the - :meth:`add_exe_args` method. For example, ``--export=ALL`` for - slurm, or ``-V`` for PBS/aprun. 
- - - :param env_vars: environment variables to update or add - :raises TypeError: if env_vars values cannot be coerced to strings - """ - val_types = (str, int, float, bool) - # Coerce env_vars values to str as a convenience to user - for env, val in env_vars.items(): - if not isinstance(val, val_types): - raise TypeError( - f"env_vars[{env}] was of type {type(val)}, not {val_types}" - ) - - self.env_vars[env] = str(val) - - def add_exe_args(self, args: t.Union[str, t.List[str]]) -> None: - """Add executable arguments to executable - - :param args: executable arguments - """ - args = self._build_exe_args(args) - self._exe_args.extend(args) - - def set( - self, arg: str, value: t.Optional[str] = None, condition: bool = True - ) -> None: - """Allows users to set individual run arguments. - - A method that allows users to set run arguments after object - instantiation. Does basic formatting such as stripping leading dashes. - If the argument has been set previously, this method will log warning - but ultimately comply. - - Conditional expressions may be passed to the conditional parameter. If the - expression evaluates to True, the argument will be set. In not an info - message is logged and no further operation is performed. - - Basic Usage - - .. highlight:: python - .. code-block:: python - - rs = RunSettings("python") - rs.set("an-arg", "a-val") - rs.set("a-flag") - rs.format_run_args() # returns ["an-arg", "a-val", "a-flag", "None"] - - Slurm Example with Conditional Setting - - .. highlight:: python - .. 
code-block:: python - - import socket - - rs = SrunSettings("echo", "hello") - rs.set_tasks(1) - rs.set("exclusive") - - # Only set this argument if condition param evals True - # Otherwise log and NOP - rs.set("partition", "debug", - condition=socket.gethostname()=="testing-system") - - rs.format_run_args() - # returns ["exclusive", "None", "partition", "debug"] iff - socket.gethostname()=="testing-system" - # otherwise returns ["exclusive", "None"] - - :param arg: name of the argument - :param value: value of the argument - :param conditon: set the argument if condition evaluates to True - """ - if not isinstance(arg, str): - raise TypeError("Argument name should be of type str") - if value is not None and not isinstance(value, str): - raise TypeError("Argument value should be of type str or None") - arg = arg.strip().lstrip("-") - - if not condition: - logger.info(f"Could not set argument '{arg}': condition not met") - return - if arg in self.reserved_run_args: - logger.warning( - ( - f"Could not set argument '{arg}': " - f"it is a reserved arguement of '{type(self).__name__}'" - ) - ) - return - - if arg in self.run_args and value != self.run_args[arg]: - logger.warning(f"Overwritting argument '{arg}' with value '{value}'") - self.run_args[arg] = value - - @staticmethod - def _build_exe_args(exe_args: t.Optional[t.Union[str, t.List[str]]]) -> t.List[str]: - """Check and convert exe_args input to a desired collection format""" - if not exe_args: - return [] - - if isinstance(exe_args, list): - exe_args = copy.deepcopy(exe_args) - - if not ( - isinstance(exe_args, str) - or ( - isinstance(exe_args, list) - and all(isinstance(arg, str) for arg in exe_args) - ) - ): - raise TypeError("Executable arguments were not a list of str or a str.") - - if isinstance(exe_args, str): - return exe_args.split() - - return exe_args - - def format_run_args(self) -> t.List[str]: - """Return formatted run arguments - - For ``RunSettings``, the run arguments are passed - literally 
with no formatting. - - :return: list run arguments for these settings - """ - formatted = [] - for arg, value in self.run_args.items(): - formatted.append(arg) - formatted.append(str(value)) - return formatted - - def format_env_vars(self) -> t.List[str]: - """Build environment variable string - - :returns: formatted list of strings to export variables - """ - formatted = [] - for key, val in self.env_vars.items(): - if val is None: - formatted.append(f"{key}=") - else: - formatted.append(f"{key}={val}") - return formatted - - def __str__(self) -> str: # pragma: no-cover - string = f"Executable: {self.exe[0]}\n" - string += f"Executable Arguments: {' '.join((self.exe_args))}" - if self.run_command: - string += f"\nRun Command: {self.run_command}" - if self.run_args: - string += f"\nRun Arguments:\n{fmt_dict(self.run_args)}" - if self.colocated_db_settings: - string += "\nCo-located Database: True" - return string - - -class BatchSettings(SettingsBase): - def __init__( - self, - batch_cmd: str, - batch_args: t.Optional[t.Dict[str, t.Optional[str]]] = None, - **kwargs: t.Any, - ) -> None: - self._batch_cmd = batch_cmd - self.batch_args = batch_args or {} - self._preamble: t.List[str] = [] - nodes = kwargs.get("nodes", None) - if nodes: - self.set_nodes(nodes) - queue = kwargs.get("queue", None) - if queue: - self.set_queue(queue) - self.set_walltime(kwargs.get("time", None)) - self.set_account(kwargs.get("account", None)) - - @property - def batch_cmd(self) -> str: - """Return the batch command - - Tests to see if we can expand the batch command - path. If we can, then returns the expanded batch - command. If we cannot, returns the batch command as is. 
- - :returns: batch command - """ - if is_valid_cmd(self._batch_cmd): - return expand_exe_path(self._batch_cmd) - - return self._batch_cmd - - @property - def batch_args(self) -> t.Dict[str, t.Optional[str]]: - """Retrieve attached batch arguments - - :returns: attached batch arguments - """ - return self._batch_args - - @batch_args.setter - def batch_args(self, value: t.Dict[str, t.Optional[str]]) -> None: - """Attach batch arguments - - :param value: dictionary of batch arguments - """ - self._batch_args = copy.deepcopy(value) if value else {} - - def set_nodes(self, num_nodes: int) -> None: - raise NotImplementedError - - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: - raise NotImplementedError - - def set_queue(self, queue: str) -> None: - raise NotImplementedError - - def set_walltime(self, walltime: str) -> None: - raise NotImplementedError - - def set_account(self, account: str) -> None: - raise NotImplementedError - - def format_batch_args(self) -> t.List[str]: - raise NotImplementedError - - def set_batch_command(self, command: str) -> None: - """Set the command used to launch the batch e.g. ``sbatch`` - - :param command: batch command - """ - self._batch_cmd = command - - def add_preamble(self, lines: t.List[str]) -> None: - """Add lines to the batch file preamble. The lines are just - written (unmodified) at the beginning of the batch file - (after the WLM directives) and can be used to e.g. - start virtual environments before running the executables. - - :param line: lines to add to preamble. 
- """ - if isinstance(lines, str): - self._preamble += [lines] - elif isinstance(lines, list): - self._preamble += lines - else: - raise TypeError("Expected str or List[str] for lines argument") - - @property - def preamble(self) -> t.Iterable[str]: - """Return an iterable of preamble clauses to be prepended to the batch file - - :return: attached preamble clauses - """ - return (clause for clause in self._preamble) - - def __str__(self) -> str: # pragma: no-cover - string = f"Batch Command: {self._batch_cmd}" - if self.batch_args: - string += f"\nBatch arguments:\n{fmt_dict(self.batch_args)}" - return string diff --git a/smartsim/settings/base_settings.py b/smartsim/settings/base_settings.py new file mode 100644 index 0000000000..2e8a87f57f --- /dev/null +++ b/smartsim/settings/base_settings.py @@ -0,0 +1,30 @@ +# BSD 2-Clause License # +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +class BaseSettings: + """ + A base class for LaunchSettings and BatchSettings. + """ diff --git a/smartsim/settings/batch_command.py b/smartsim/settings/batch_command.py new file mode 100644 index 0000000000..a96492d398 --- /dev/null +++ b/smartsim/settings/batch_command.py @@ -0,0 +1,35 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from enum import Enum + + +class BatchSchedulerType(Enum): + """Schedulers supported by SmartSim.""" + + Slurm = "slurm" + Pbs = "pbs" + Lsf = "lsf" diff --git a/smartsim/settings/batch_settings.py b/smartsim/settings/batch_settings.py new file mode 100644 index 0000000000..734e919ce3 --- /dev/null +++ b/smartsim/settings/batch_settings.py @@ -0,0 +1,174 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import copy +import typing as t + +from smartsim.log import get_logger + +from .._core.utils.helpers import fmt_dict +from .arguments import BatchArguments +from .arguments.batch.lsf import BsubBatchArguments +from .arguments.batch.pbs import QsubBatchArguments +from .arguments.batch.slurm import SlurmBatchArguments +from .base_settings import BaseSettings +from .batch_command import BatchSchedulerType +from .common import StringArgument + +logger = get_logger(__name__) + + +class BatchSettings(BaseSettings): + """The BatchSettings class stores scheduler configuration settings and is + used to inject scheduler-specific behavior into a job. + + BatchSettings is designed to be extended by a BatchArguments child class that + corresponds to the scheduler provided during initialization. The supported schedulers + are Slurm, PBS, and LSF. Using the BatchSettings class, users can: + + - Set the scheduler type of a batch job. + - Configure batch arguments and environment variables. + - Access and modify custom batch arguments. + - Update environment variables. + - Retrieve information associated with the ``BatchSettings`` object. + - The scheduler value (BatchSettings.scheduler). + - The derived BatchArguments child class (BatchSettings.batch_args). + - The set environment variables (BatchSettings.env_vars). + - A formatted output of set batch arguments (BatchSettings.format_batch_args). 
+ """ + + def __init__( + self, + batch_scheduler: t.Union[BatchSchedulerType, str], + batch_args: StringArgument | None = None, + env_vars: StringArgument | None = None, + ) -> None: + """Initialize a BatchSettings instance. + + The "batch_scheduler" of SmartSim BatchSettings will determine the + child type assigned to the BatchSettings.batch_args attribute. + For example, to configure a job for SLURM batch jobs, assign BatchSettings.batch_scheduler + to "slurm" or BatchSchedulerType.Slurm: + + .. highlight:: python + .. code-block:: python + + sbatch_settings = BatchSettings(batch_scheduler="slurm") + # OR + sbatch_settings = BatchSettings(batch_scheduler=BatchSchedulerType.Slurm) + + This will assign a SlurmBatchArguments object to ``sbatch_settings.batch_args``. + Using the object, users may access the child class functions to set + batch configurations. For example: + + .. highlight:: python + .. code-block:: python + + sbatch_settings.batch_args.set_nodes(5) + sbatch_settings.batch_args.set_cpus_per_task(2) + + To set customized batch arguments, use the `set()` function provided by + the BatchSettings child class. For example: + + .. highlight:: python + .. code-block:: python + + sbatch_settings.batch_args.set(key="nodes", value="6") + + If the key already exists in the existing batch arguments, the value will + be overwritten. + + :param batch_scheduler: The type of scheduler to initialize (e.g., Slurm, PBS, LSF) + :param batch_args: A dictionary of arguments for the scheduler, where the keys + are strings and the values can be either strings or None. This argument is optional + and defaults to None. + :param env_vars: Environment variables for the batch settings, where the keys + are strings and the values can be either strings or None. This argument is + also optional and defaults to None. + :raises ValueError: Raises if the scheduler provided does not exist. 
+ """ + try: + self._batch_scheduler = BatchSchedulerType(batch_scheduler) + """The scheduler type""" + except ValueError: + raise ValueError(f"Invalid scheduler type: {batch_scheduler}") from None + self._arguments = self._get_arguments(batch_args) + """The BatchSettings child class based on scheduler type""" + self.env_vars = env_vars or {} + """The environment configuration""" + + @property + def batch_scheduler(self) -> str: + """Return the scheduler type.""" + return self._batch_scheduler.value + + @property + def batch_args(self) -> BatchArguments: + """Return the BatchArguments child class.""" + return self._arguments + + @property + def env_vars(self) -> StringArgument: + """Return an immutable list of attached environment variables.""" + return self._env_vars + + @env_vars.setter + def env_vars(self, value: t.Dict[str, str | None]) -> None: + """Set the environment variables.""" + self._env_vars = copy.deepcopy(value) + + def _get_arguments(self, batch_args: StringArgument | None) -> BatchArguments: + """Map the Scheduler to the BatchArguments. This method should only be + called once during construction. + + :param schedule_args: A mapping of arguments names to values to be + used to initialize the arguments + :returns: The appropriate type for the settings instance. + :raises ValueError: An invalid scheduler type was provided. 
+ """ + if self._batch_scheduler == BatchSchedulerType.Slurm: + return SlurmBatchArguments(batch_args) + elif self._batch_scheduler == BatchSchedulerType.Lsf: + return BsubBatchArguments(batch_args) + elif self._batch_scheduler == BatchSchedulerType.Pbs: + return QsubBatchArguments(batch_args) + else: + raise ValueError(f"Invalid scheduler type: {self._batch_scheduler}") + + def format_batch_args(self) -> t.List[str]: + """Get the formatted batch arguments to preview + + :return: formatted batch arguments + """ + return self._arguments.format_batch_args() + + def __str__(self) -> str: # pragma: no-cover + string = f"\nBatch Scheduler: {self.batch_scheduler}{self.batch_args}" + if self.env_vars: + string += f"\nEnvironment variables: \n{fmt_dict(self.env_vars)}" + return string diff --git a/smartsim/settings/common.py b/smartsim/settings/common.py new file mode 100644 index 0000000000..edca5fd52b --- /dev/null +++ b/smartsim/settings/common.py @@ -0,0 +1,49 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import typing as t + +from smartsim.log import get_logger + +IntegerArgument = t.Dict[str, t.Optional[int]] +StringArgument = t.Dict[str, t.Optional[str]] + +logger = get_logger(__name__) + + +def set_check_input(key: str, value: t.Optional[str]) -> None: + if not isinstance(key, str): + raise TypeError(f"Key '{key}' should be of type str") + if not isinstance(value, (str, type(None))): + raise TypeError(f"Value '{value}' should be of type str or None") + if key.startswith("-"): + key = key.lstrip("-") + logger.warning( + "One or more leading `-` characters were provided to the run argument.\n" + "Leading dashes were stripped and the arguments were passed to the run_command." + ) diff --git a/smartsim/settings/containers.py b/smartsim/settings/containers.py deleted file mode 100644 index d2fd4fca27..0000000000 --- a/smartsim/settings/containers.py +++ /dev/null @@ -1,173 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. 
class Container:
    """Base class for container types in SmartSim.

    Container types are used to embed all the information needed to
    launch a workload within a container into a single object.

    :param image: local or remote path to container image
    :param args: arguments to container command
    :param mount: paths to mount (bind) from host machine into image
    :param working_directory: path of the working directory within the container
    """

    def __init__(
        self, image: str, args: str = "", mount: str = "", working_directory: str = ""
    ) -> None:
        # Table-driven validation: check each member against its accepted
        # types before storing anything so a bad configuration fails fast.
        type_checks = (
            (image, (str,), "image must be a str"),
            (args, (str, list), "args must be a str | list"),
            (mount, (str, list, dict), "mount must be a str | list | dict"),
            (working_directory, (str,), "working_directory must be a str"),
        )
        for value, accepted, message in type_checks:
            if not isinstance(value, accepted):
                raise TypeError(message)

        self.image = image
        self.args = args
        self.mount = mount
        self.working_directory = working_directory

    def _containerized_run_command(self, run_command: str) -> str:
        """Return modified run_command with container commands prepended.

        :param run_command: run command from a RunSettings class
        :raises NotImplementedError: always; concrete container types must
            provide their own implementation
        """
        raise NotImplementedError(
            "Containerized run command specification not implemented for this "
            f"Container type: {type(self)}"
        )


class Singularity(Container):
    # pylint: disable=abstract-method
    # todo: determine if _containerized_run_command should be abstract

    """Singularity (apptainer) container type. To be passed into a
    ``RunSettings`` class initializer or ``Experiment.create_run_settings``.

    .. note::

        Singularity integration is currently tested with
        `Apptainer 1.0 `_
        with slurm and PBS workload managers only.

        Also, note that user-defined bind paths (``mount`` argument) may be
        disabled by a
        `system administrator
        `_


    :param image: local or remote path to container image,
        e.g. ``docker://sylabsio/lolcow``
    :param args: arguments to 'singularity exec' command
    :param mount: paths to mount (bind) from host machine into image
    """

    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
        super().__init__(*args, **kwargs)

    def _serialized_args(self) -> str:
        """Collapse ``self.args`` into a single space-separated string.

        :raises TypeError: if ``self.args`` is neither a str nor a list
        """
        if not self.args:
            return ""
        if isinstance(self.args, str):
            return self.args
        if isinstance(self.args, list):
            return " ".join(self.args)
        raise TypeError("self.args must be a str | list")

    def _serialized_mount(self) -> str:
        """Collapse ``self.mount`` into a comma-separated bind specification.

        :raises TypeError: if ``self.mount`` is not a str, list, or dict
        """
        if not self.mount:
            return ""
        if isinstance(self.mount, str):
            return self.mount
        if isinstance(self.mount, list):
            return ",".join(self.mount)
        if isinstance(self.mount, dict):
            # dict entries map host paths to (optional) in-image paths
            bind_points = [
                f"{host_path}:{img_path}" if img_path else host_path
                for host_path, img_path in self.mount.items()
            ]
            return ",".join(bind_points)
        raise TypeError("self.mount must be str | list | dict")

    def _container_cmds(self, default_working_directory: str = "") -> t.List[str]:
        """Return list of container commands to be inserted before exe.
        Container members are validated during this call.

        :param default_working_directory: fallback working directory used
            when ``self.working_directory`` is unset
        :raises TypeError: if object members are invalid types
        """
        serialized_args = self._serialized_args()
        serialized_mount = self._serialized_mount()

        working_directory = self.working_directory or default_working_directory

        # NOTE: substring membership mirrors the original behavior — an empty
        # working directory is "found" in any mount string and is not added.
        if working_directory not in serialized_mount:
            if serialized_mount:
                serialized_mount = ",".join([working_directory, serialized_mount])
            else:
                serialized_mount = working_directory
            logger.warning(
                f"Working directory not specified in mount: \n {working_directory}\n"
                "Automatically adding it to the list of bind points"
            )

        # Find full path to singularity
        singularity = shutil.which("singularity")

        # Some systems have singularity available on compute nodes only,
        # so warn instead of error
        if not singularity:
            logger.warning(
                "Unable to find singularity. Continuing in case singularity is "
                "available on compute node"
            )

        # Construct containerized launch command
        cmd_list = [singularity or "singularity", "exec"]
        if working_directory:
            cmd_list.extend(["--pwd", working_directory])

        if serialized_args:
            cmd_list.append(serialized_args)
        if serialized_mount:
            cmd_list.extend(["--bind", serialized_mount])
        cmd_list.append(self.image)

        return cmd_list
class LauncherType(Enum):
    """Enumeration of the launchers supported by SmartSim.

    Each member's ``value`` is the lowercase string form that users may
    pass anywhere a launcher is accepted, e.g.
    ``LaunchSettings(launcher="slurm")`` is equivalent to
    ``LaunchSettings(launcher=LauncherType.Slurm)``.
    """

    Dragon = "dragon"
    Slurm = "slurm"
    Pals = "pals"
    Alps = "alps"
    Local = "local"
    Mpiexec = "mpiexec"
    Mpirun = "mpirun"
    Orterun = "orterun"
    Lsf = "lsf"
class LaunchSettings(BaseSettings):
    """The LaunchSettings class stores launcher configuration settings and is
    used to inject launcher-specific behavior into a job.

    LaunchSettings is designed to be extended by a LaunchArguments child class
    that corresponds to the launcher provided during initialization. The
    supported launchers are Dragon, Slurm, PALS, ALPS, Local, Mpiexec, Mpirun,
    Orterun, and LSF. Using the LaunchSettings class, users can:

    - Set the launcher type of a job.
    - Configure launch arguments and environment variables.
    - Access and modify custom launch arguments.
    - Update environment variables.
    - Retrieve information associated with the ``LaunchSettings`` object.
        - The launcher value (LaunchSettings.launcher).
        - The derived LaunchSettings child class (LaunchSettings.launch_args).
        - The set environment variables (LaunchSettings.env_vars).
    """

    def __init__(
        self,
        launcher: t.Union[LauncherType, str],
        launch_args: StringArgument | None = None,
        env_vars: StringArgument | None = None,
    ) -> None:
        """Initialize a LaunchSettings instance.

        The "launcher" of SmartSim LaunchSettings will determine the
        child type assigned to the LaunchSettings.launch_args attribute.
        For example, to configure a job for SLURM, assign
        LaunchSettings.launcher to "slurm" or LauncherType.Slurm:

        .. highlight:: python
        .. code-block:: python

            srun_settings = LaunchSettings(launcher="slurm")
            # OR
            srun_settings = LaunchSettings(launcher=LauncherType.Slurm)

        This will assign a SlurmLaunchArguments object to
        ``srun_settings.launch_args``. Using the object, users may access the
        child class functions to set launch configurations, e.g.
        ``srun_settings.launch_args.set_nodes(5)``. To set customized launch
        arguments, use ``srun_settings.launch_args.set(key, value)``; if the
        key already exists in the launch arguments, the value is overwritten.

        :param launcher: The type of launcher to initialize (e.g., Dragon,
            Slurm, PALS, ALPS, Local, Mpiexec, Mpirun, Orterun, LSF)
        :param launch_args: A dictionary of arguments for the launcher, where
            the keys are strings and the values can be either strings or None.
            This argument is optional and defaults to None.
        :param env_vars: Environment variables for the launch settings, where
            the keys are strings and the values can be either strings or None.
            This argument is also optional and defaults to None.
        :raises ValueError: Raises if the launcher provided does not exist.
        """
        try:
            # The launcher type
            self._launcher = LauncherType(launcher)
        except ValueError:
            # `from None`: the Enum's own ValueError adds no information for
            # the caller and would only clutter the traceback as a chained
            # "during handling of the above exception" context.
            raise ValueError(f"Invalid launcher type: {launcher}") from None
        # The LaunchArguments child class based on launcher type
        self._arguments = self._get_arguments(launch_args)
        # The environment configuration (deep-copied by the setter)
        self.env_vars = env_vars or {}

    @property
    def launcher(self) -> str:
        """The launcher type

        :returns: The launcher type's string representation
        """
        return self._launcher.value

    @property
    def launch_args(self) -> LaunchArguments:
        """The launch argument

        :returns: The launch arguments
        """
        return self._arguments

    @property
    def env_vars(self) -> t.Mapping[str, str | None]:
        """A mapping of environment variables to set or remove. This mapping is
        a deep copy of the mapping used by the settings and as such altering
        will not mutate the settings.

        :returns: An environment mapping
        """
        # Return a deep copy so the documented contract holds: previously the
        # live internal dict was returned, so callers mutating the result
        # would silently mutate the settings.
        return copy.deepcopy(self._env_vars)

    @env_vars.setter
    def env_vars(self, value: dict[str, str | None]) -> None:
        """Set the environment variables to a new mapping. This setter will
        make a copy of the mapping and as such altering the original mapping
        will not mutate the settings.

        :param value: The new environment mapping
        """
        self._env_vars = copy.deepcopy(value)

    def _get_arguments(self, launch_args: StringArgument | None) -> LaunchArguments:
        """Map the Launcher to the LaunchArguments. This method should only be
        called once during construction.

        :param launch_args: A mapping of arguments names to values to be used
            to initialize the arguments
        :returns: The appropriate type for the settings instance.
        :raises ValueError: An invalid launcher type was provided.
        """
        # Dispatch table replaces a nine-branch elif chain; adding a launcher
        # now requires only a new entry here.
        arguments_by_launcher: t.Mapping[LauncherType, t.Type[LaunchArguments]] = {
            LauncherType.Slurm: SlurmLaunchArguments,
            LauncherType.Mpiexec: MpiexecLaunchArguments,
            LauncherType.Mpirun: MpirunLaunchArguments,
            LauncherType.Orterun: OrterunLaunchArguments,
            LauncherType.Alps: AprunLaunchArguments,
            LauncherType.Lsf: JsrunLaunchArguments,
            LauncherType.Pals: PalsMpiexecLaunchArguments,
            LauncherType.Dragon: DragonLaunchArguments,
            LauncherType.Local: LocalLaunchArguments,
        }
        try:
            arguments_cls = arguments_by_launcher[self._launcher]
        except KeyError:
            raise ValueError(f"Invalid launcher type: {self._launcher}") from None
        return arguments_cls(launch_args)

    def update_env(self, env_vars: t.Dict[str, str | None]) -> None:
        """Update the job environment variables

        To fully inherit the current user environment, add the
        workload-manager-specific flag to the launch command through the
        :meth:`add_exe_args` method. For example, ``--export=ALL`` for
        slurm, or ``-V`` for PBS/aprun.

        :param env_vars: environment variables to update or add
        :raises TypeError: if env_vars keys are not str or values are not
            str | None
        """
        # Validate the full mapping before mutating so a bad entry cannot
        # leave the environment partially updated.
        for env, val in env_vars.items():
            if not isinstance(env, str):
                raise TypeError(f"The key '{env}' of env_vars should be of type str")
            if not isinstance(val, (str, type(None))):
                raise TypeError(
                    f"The value '{val}' of env_vars should be of type str or None"
                )
        self._env_vars.update(env_vars)

    def __str__(self) -> str:  # pragma: no cover
        # NOTE: coverage.py only recognizes "no cover"; the previous
        # "no-cover" spelling was silently ignored.
        string = f"\nLauncher: {self.launcher}{self.launch_args}"
        if self.env_vars:
            string += f"\nEnvironment variables: \n{fmt_dict(self.env_vars)}"
        return string
class JsrunSettings(RunSettings):
    """Settings to run a job with the LSF ``jsrun`` command.

    ``JsrunSettings`` should only be used on LSF-based systems.
    """

    def __init__(
        self,
        exe: str,
        exe_args: t.Optional[t.Union[str, t.List[str]]] = None,
        run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None,
        env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None,
        **_kwargs: t.Any,
    ) -> None:
        """Settings to run job with ``jsrun`` command

        :param exe: executable
        :param exe_args: executable arguments
        :param run_args: arguments for run command
        :param env_vars: environment vars to launch job with
        """
        super().__init__(
            exe,
            exe_args,
            run_command="jsrun",
            run_args=run_args,
            env_vars=env_vars,
        )

        # Parameters needed for MPMD run
        self.erf_sets = {"host": "*", "cpu": "*", "ranks": "1"}
        self.mpmd_preamble_lines: t.List[str] = []
        self.mpmd: t.List[RunSettings] = []
        self.individual_suffix = ""

    # run arguments SmartSim manages itself; users may not set these directly
    reserved_run_args = {"chdir", "h"}

    def set_num_rs(self, num_rs: t.Union[str, int]) -> None:
        """Set the number of resource sets to use

        This sets ``--nrs``.

        :param num_rs: Number of resource sets or `ALL_HOSTS`
        """
        if isinstance(num_rs, str):
            self.run_args["nrs"] = num_rs
        else:
            self.run_args["nrs"] = int(num_rs)

    def set_cpus_per_rs(self, cpus_per_rs: int) -> None:
        """Set the number of cpus to use per resource set

        This sets ``--cpu_per_rs``

        :param cpus_per_rs: number of cpus to use per resource set or ALL_CPUS
        :raises ValueError: if a colocated DB is configured without db_cpus,
            or if fewer CPUs are requested than the colocated DB requires
        """
        if self.colocated_db_settings:
            db_cpus = int(t.cast(int, self.colocated_db_settings.get("db_cpus", 0)))
            if not db_cpus:
                raise ValueError("db_cpus must be configured on colocated_db_settings")

            # BUGFIX: only compare numeric requests against db_cpus. The
            # documented string form (e.g. "ALL_CPUS") previously reached this
            # comparison and raised TypeError (str < int).
            if not isinstance(cpus_per_rs, str) and cpus_per_rs < db_cpus:
                raise ValueError(
                    f"Cannot set cpus_per_rs ({cpus_per_rs}) to less than "
                    + f"db_cpus ({db_cpus})"
                )
        if isinstance(cpus_per_rs, str):
            self.run_args["cpu_per_rs"] = cpus_per_rs
        else:
            self.run_args["cpu_per_rs"] = int(cpus_per_rs)

    def set_gpus_per_rs(self, gpus_per_rs: int) -> None:
        """Set the number of gpus to use per resource set

        This sets ``--gpu_per_rs``

        :param gpus_per_rs: number of gpus to use per resource set or ALL_GPUS
        """
        if isinstance(gpus_per_rs, str):
            self.run_args["gpu_per_rs"] = gpus_per_rs
        else:
            self.run_args["gpu_per_rs"] = int(gpus_per_rs)

    def set_rs_per_host(self, rs_per_host: int) -> None:
        """Set the number of resource sets to use per host

        This sets ``--rs_per_host``

        :param rs_per_host: number of resource sets to use per host
        """
        self.run_args["rs_per_host"] = int(rs_per_host)

    def set_tasks(self, tasks: int) -> None:
        """Set the number of tasks for this job

        This sets ``--np``

        :param tasks: number of tasks
        """
        self.run_args["np"] = int(tasks)

    def set_tasks_per_rs(self, tasks_per_rs: int) -> None:
        """Set the number of tasks per resource set

        This sets ``--tasks_per_rs``

        :param tasks_per_rs: number of tasks per resource set
        """
        self.run_args["tasks_per_rs"] = int(tasks_per_rs)

    def set_tasks_per_node(self, tasks_per_node: int) -> None:
        """Set the number of tasks per resource set.

        This function is an alias for `set_tasks_per_rs`.

        :param tasks_per_node: number of tasks per resource set
        """
        self.set_tasks_per_rs(int(tasks_per_node))

    def set_cpus_per_task(self, cpus_per_task: int) -> None:
        """Set the number of cpus per tasks.

        This function is an alias for `set_cpus_per_rs`.

        :param cpus_per_task: number of cpus per resource set
        """
        self.set_cpus_per_rs(int(cpus_per_task))

    def set_memory_per_rs(self, memory_per_rs: int) -> None:
        """Specify the number of megabytes of memory to assign to a resource set

        This sets ``--memory_per_rs``

        :param memory_per_rs: Number of megabytes per rs
        """
        self.run_args["memory_per_rs"] = int(memory_per_rs)

    def set_memory_per_node(self, memory_per_node: int) -> None:
        """Specify the number of megabytes of memory to assign to a resource set

        Alias for `set_memory_per_rs`.

        :param memory_per_node: Number of megabytes per rs
        """
        self.set_memory_per_rs(int(memory_per_node))

    def set_binding(self, binding: str) -> None:
        """Set binding

        This sets ``--bind``

        :param binding: Binding, e.g. `packed:21`
        """
        self.run_args["bind"] = binding

    def make_mpmd(self, settings: RunSettings) -> None:
        """Make step an MPMD (or SPMD) job.

        This method will activate job execution through an ERF file.

        Optionally, this method adds an instance of ``JsrunSettings`` to
        the list of settings to be launched in the same ERF file.

        :param settings: ``JsrunSettings`` instance
        :raises SSUnsupportedError: if this step is colocated with a DB
        """
        if self.colocated_db_settings:
            raise SSUnsupportedError(
                "Colocated models cannot be run as a mpmd workload"
            )

        self.mpmd.append(settings)

    def set_mpmd_preamble(self, preamble_lines: t.List[str]) -> None:
        """Set preamble used in ERF file. Typical lines include
        `oversubscribe-cpu : allow` or `overlapping-rs : allow`.
        Can be used to set `launch_distribution`. If it is not present,
        it will be inferred from the settings, or set to `packed` by
        default.

        :param preamble_lines: lines to put at the beginning of the ERF
            file.
        """
        self.mpmd_preamble_lines = preamble_lines

    def set_erf_sets(self, erf_sets: t.Dict[str, str]) -> None:
        """Set resource sets used for ERF (SPMD or MPMD) steps.

        ``erf_sets`` is a dictionary used to fill the ERF
        line representing these settings, e.g.
        `{"host": "1", "cpu": "{0:21}, {21:21}", "gpu": "*"}`
        can be used to specify rank (or rank_count), hosts, cpus, gpus,
        and memory.
        The key `rank` is used to give specific ranks, as in
        `{"rank": "1, 2, 5"}`, while the key `rank_count` is used to specify
        the count only, as in `{"rank_count": "3"}`. If both are specified,
        only `rank` is used.

        :param erf_sets: dictionary of resources
        """
        # deep copy so later caller-side mutation cannot change these settings
        self.erf_sets = copy.deepcopy(erf_sets)

    def format_env_vars(self) -> t.List[str]:
        """Format environment variables. Each variable needs
        to be passed with ``--env``. If a variable is set to ``None``,
        its value is propagated from the current environment.

        :returns: formatted list of strings to export variables
        """
        format_str = []
        for k, v in self.env_vars.items():
            if v:
                format_str += ["-E", f"{k}={v}"]
            else:
                # bare name: jsrun propagates the value from the environment
                format_str += ["-E", f"{k}"]
        return format_str

    def set_individual_output(self, suffix: t.Optional[str] = None) -> None:
        """Set individual std output.

        This sets ``--stdio_mode individual``
        and inserts the suffix into the output name. The resulting
        output name will be ``self.name + suffix + .out``.

        :param suffix: Optional suffix to add to output file names,
            it can contain `%j`, `%h`, `%p`, or `%t`,
            as specified by `jsrun` options.
        """
        self.run_args["stdio_mode"] = "individual"
        if suffix:
            self.individual_suffix = suffix

    def format_run_args(self) -> t.List[str]:
        """Return a list of LSF formatted run arguments

        :return: list of LSF arguments for these settings
        """
        # args launcher uses
        args = []
        restricted = ["chdir", "h", "stdio_stdout", "o", "stdio_stderr", "k"]
        if self.mpmd or "erf_input" in self.run_args.keys():
            # ERF-driven launches carry these options inside the ERF file, so
            # they must not also appear on the command line
            restricted.extend(
                [
                    "tasks_per_rs",
                    "a",
                    "np",
                    "p",
                    "cpu_per_rs",
                    "c",
                    "gpu_per_rs",
                    "g",
                    "latency_priority",
                    "l",
                    "memory_per_rs",
                    "m",
                    "nrs",
                    "n",
                    "rs_per_host",
                    "r",
                    "rs_per_socket",
                    "K",
                    "appfile",
                    "f",
                    "allocate_only",
                    "A",
                    "launch_node_task",
                    "H",
                    "use_reservation",
                    "J",
                    "use_resources",
                    "bind",
                    "b",
                    "launch_distribution",
                    "d",
                ]
            )

        for opt, value in self.run_args.items():
            if opt not in restricted:
                short_arg = len(str(opt)) == 1
                prefix = "-" if short_arg else "--"
                if not value:
                    args += [prefix + opt]
                else:
                    if short_arg:
                        args += [prefix + opt, str(value)]
                    else:
                        args += ["=".join((prefix + opt, str(value)))]
        return args

    def __str__(self) -> str:
        string = super().__str__()
        if self.mpmd:
            string += "\nERF settings: " + pformat(self.erf_sets)
        return string

    def _prep_colocated_db(self, db_cpus: int) -> None:
        """Adjust run arguments so a colocated DB gets the resources it needs.

        Forces ``cpu_per_rs >= db_cpus`` and ``rs_per_host == 1``, logging
        every automatic adjustment.

        :param db_cpus: CPUs required by the colocated DB
        """
        cpus_per_flag_set = False
        for cpu_per_rs_flag in ["cpu_per_rs", "c"]:
            if run_arg_value := self.run_args.get(cpu_per_rs_flag, 0):
                cpus_per_flag_set = True
                cpu_per_rs = int(run_arg_value)
                if cpu_per_rs < db_cpus:
                    msg = (
                        f"{cpu_per_rs_flag} flag was set to {cpu_per_rs}, but "
                        f"colocated DB requires {db_cpus} CPUs per RS. Automatically "
                        f"setting {cpu_per_rs_flag} flag to {db_cpus}"
                    )
                    logger.info(msg)
                    self.run_args[cpu_per_rs_flag] = db_cpus
        if not cpus_per_flag_set:
            msg = f"Colocated DB requires {db_cpus} CPUs per RS. Automatically setting "
            msg += f"--cpus_per_rs=={db_cpus}"
            logger.info(msg)
            self.set_cpus_per_rs(db_cpus)

        rs_per_host_set = False
        for rs_per_host_flag in ["rs_per_host", "r"]:
            if rs_per_host_flag in self.run_args:
                rs_per_host_set = True
                rs_per_host = self.run_args[rs_per_host_flag]
                if rs_per_host != 1:
                    msg = f"{rs_per_host_flag} flag was set to {rs_per_host}, "
                    msg += (
                        "but colocated DB requires running ONE resource set per host. "
                    )
                    msg += f"Automatically setting {rs_per_host_flag} flag to 1"
                    logger.info(msg)
                    self.run_args[rs_per_host_flag] = "1"
        if not rs_per_host_set:
            msg = "Colocated DB requires one resource set per host. "
            msg += " Automatically setting --rs_per_host==1"
            logger.info(msg)
            self.set_rs_per_host(1)


class BsubBatchSettings(BatchSettings):
    """Batch settings for the LSF ``bsub`` command."""

    def __init__(
        self,
        nodes: t.Optional[int] = None,
        time: t.Optional[str] = None,
        project: t.Optional[str] = None,
        batch_args: t.Optional[t.Dict[str, t.Optional[str]]] = None,
        smts: int = 0,
        **kwargs: t.Any,
    ) -> None:
        """Specify ``bsub`` batch parameters for a job

        :param nodes: number of nodes for batch
        :param time: walltime for batch job in format hh:mm
        :param project: project for batch launch
        :param batch_args: overrides for LSF batch arguments
        :param smts: SMTs
        """
        self.project: t.Optional[str] = None

        # `project` and the generic `account` kwarg are aliases; an explicit
        # project wins, otherwise fall back to account
        if project:
            kwargs.pop("account", None)
        else:
            project = kwargs.pop("account", None)

        super().__init__(
            "bsub",
            batch_args=batch_args,
            nodes=nodes,
            account=project,
            time=time,
            **kwargs,
        )

        self.smts = 0
        if smts:
            self.set_smts(smts)

        self.expert_mode = False
        # options dropped from formatted args when expert mode (-csm) is active
        self.easy_settings = ["ln_slots", "ln_mem", "cn_cu", "nnodes"]

    def set_walltime(self, walltime: str) -> None:
        """Set the walltime

        This sets ``-W``.

        :param walltime: Time in hh:mm format, e.g. "10:00" for 10 hours,
            if time is supplied in hh:mm:ss format, seconds
            will be ignored and walltime will be set as ``hh:mm``
        """
        # For compatibility with other launchers, as explained in docstring
        if walltime:
            if len(walltime.split(":")) > 2:
                walltime = ":".join(walltime.split(":")[:2])
            self.walltime = walltime

    def set_smts(self, smts: int) -> None:
        """Set SMTs

        This sets ``-alloc_flags``. If the user sets
        SMT explicitly through ``-alloc_flags``, then that
        takes precedence.

        :param smts: SMT (e.g on Summit: 1, 2, or 4)
        """
        self.smts = smts

    def set_project(self, project: str) -> None:
        """Set the project

        This sets ``-P``.

        :param project: project name
        """
        if project:
            self.project = project

    def set_account(self, account: str) -> None:
        """Set the project

        this function is an alias for `set_project`.

        :param account: project name
        """
        self.set_project(account)

    def set_nodes(self, num_nodes: int) -> None:
        """Set the number of nodes for this batch job

        This sets ``-nnodes``.

        :param num_nodes: number of nodes
        """
        if num_nodes:
            self.batch_args["nnodes"] = str(int(num_nodes))

    def set_expert_mode_req(self, res_req: str, slots: int) -> None:
        """Set allocation for expert mode. This
        will activate expert mode (``-csm``) and
        disregard all other allocation options.

        This sets ``-csm -n slots -R res_req``

        :param res_req: specific resource requirements
        :param slots: number of resources to allocate
        """
        self.expert_mode = True
        self.batch_args["csm"] = "y"
        self.batch_args["R"] = res_req
        self.batch_args["n"] = str(slots)

    def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None:
        """Specify the hostlist for this job

        :param host_list: hosts to launch on
        :raises TypeError: if not str or list of str
        """
        if isinstance(host_list, str):
            host_list = [host_list.strip()]
        if not isinstance(host_list, list):
            raise TypeError("host_list argument must be a list of strings")
        if not all(isinstance(host, str) for host in host_list):
            raise TypeError("host_list argument must be list of strings")
        self.batch_args["m"] = '"' + " ".join(host_list) + '"'

    def set_tasks(self, tasks: int) -> None:
        """Set the number of tasks for this job

        This sets ``-n``

        :param tasks: number of tasks
        """
        self.batch_args["n"] = str(int(tasks))

    def set_queue(self, queue: str) -> None:
        """Set the queue for this job

        :param queue: The queue to submit the job on
        """
        if queue:
            self.batch_args["q"] = queue

    def _format_alloc_flags(self) -> None:
        """Format ``alloc_flags`` checking if user already
        set it. Currently only adds SMT flag if missing
        and ``self.smts`` is set.
        """

        if self.smts:
            if "alloc_flags" not in self.batch_args.keys():
                self.batch_args["alloc_flags"] = f"smt{self.smts}"
            else:
                # Check if smt is in the flag, otherwise add it
                flags: t.List[str] = []
                if flags_arg := self.batch_args.get("alloc_flags", ""):
                    flags = flags_arg.strip('"').split()
                if not any(flag.startswith("smt") for flag in flags):
                    flags.append(f"smt{self.smts}")
                    self.batch_args["alloc_flags"] = " ".join(flags)

        # Check if alloc_flags has to be enclosed in quotes
        if "alloc_flags" in self.batch_args.keys():
            flags = []
            if flags_arg := self.batch_args.get("alloc_flags", ""):
                flags = flags_arg.strip('"').split()
            if len(flags) > 1:
                self.batch_args["alloc_flags"] = '"' + " ".join(flags) + '"'

    def format_batch_args(self) -> t.List[str]:
        """Get the formatted batch arguments for a preview

        :return: list of batch arguments for Qsub
        """
        opts = []

        self._format_alloc_flags()

        for opt, value in self.batch_args.items():
            if self.expert_mode and opt in self.easy_settings:
                continue

            prefix = "-"  # LSF only uses single dashes

            if not value:
                opts += [prefix + opt]
            else:
                opts += [" ".join((prefix + opt, str(value)))]

        return opts
Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import annotations - -import shutil -import subprocess -import typing as t - -from ..error import LauncherError, SSUnsupportedError -from ..log import get_logger -from .base import RunSettings - -logger = get_logger(__name__) - - -class _BaseMPISettings(RunSettings): - """Base class for all common arguments of MPI-standard run commands""" - - def __init__( - self, - exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_command: str = "mpiexec", - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - fail_if_missing_exec: bool = True, - **kwargs: t.Any, - ) -> None: - """Settings to format run job with an MPI-standard binary - - Note that environment variables can be passed with a None - value to signify that they should be exported from the current - environment - - Any arguments passed in the ``run_args`` dict will be converted - command line arguments and prefixed 
with ``--``. Values of - None can be provided for arguments that do not have values. - - :param exe: executable - :param exe_args: executable arguments - :param run_args: arguments for run command - :param env_vars: environment vars to launch job with - :param fail_if_missing_exec: Throw an exception of the MPI command - is missing. Otherwise, throw a warning - """ - super().__init__( - exe, - exe_args, - run_command=run_command, - run_args=run_args, - env_vars=env_vars, - **kwargs, - ) - self.mpmd: t.List[RunSettings] = [] - self.affinity_script: t.List[str] = [] - - if not shutil.which(self._run_command): - msg = ( - f"Cannot find {self._run_command}. Try passing the " - "full path via run_command." - ) - if fail_if_missing_exec: - raise LauncherError(msg) - logger.warning(msg) - - reserved_run_args = {"wd", "wdir"} - - def make_mpmd(self, settings: RunSettings) -> None: - """Make a mpmd workload by combining two ``mpirun`` commands - - This connects the two settings to be executed with a single - Model instance - - :param settings: MpirunSettings instance - """ - if self.colocated_db_settings: - raise SSUnsupportedError( - "Colocated models cannot be run as a mpmd workload" - ) - self.mpmd.append(settings) - - def set_task_map(self, task_mapping: str) -> None: - """Set ``mpirun`` task mapping - - this sets ``--map-by `` - - For examples, see the man page for ``mpirun`` - - :param task_mapping: task mapping - """ - self.run_args["map-by"] = task_mapping - - def set_cpus_per_task(self, cpus_per_task: int) -> None: - """Set the number of tasks for this job - - This sets ``--cpus-per-proc`` for MPI compliant implementations - - note: this option has been deprecated in openMPI 4.0+ - and will soon be replaced. 
- - :param cpus_per_task: number of tasks - """ - self.run_args["cpus-per-proc"] = int(cpus_per_task) - - def set_cpu_binding_type(self, bind_type: str) -> None: - """Specifies the cores to which MPI processes are bound - - This sets ``--bind-to`` for MPI compliant implementations - - :param bind_type: binding type - """ - self.run_args["bind-to"] = bind_type - - def set_tasks_per_node(self, tasks_per_node: int) -> None: - """Set the number of tasks per node - - :param tasks_per_node: number of tasks to launch per node - """ - self.run_args["npernode"] = int(tasks_per_node) - - def set_tasks(self, tasks: int) -> None: - """Set the number of tasks for this job - - This sets ``-n`` for MPI compliant implementations - - :param tasks: number of tasks - """ - self.run_args["n"] = int(tasks) - - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: - """Set the hostlist for the ``mpirun`` command - - This sets ``--host`` - - :param host_list: list of host names - :raises TypeError: if not str or list of str - """ - if isinstance(host_list, str): - host_list = [host_list.strip()] - if not isinstance(host_list, list): - raise TypeError("host_list argument must be a list of strings") - if not all(isinstance(host, str) for host in host_list): - raise TypeError("host_list argument must be list of strings") - self.run_args["host"] = ",".join(host_list) - - def set_hostlist_from_file(self, file_path: str) -> None: - """Use the contents of a file to set the hostlist - - This sets ``--hostfile`` - - :param file_path: Path to the hostlist file - """ - self.run_args["hostfile"] = file_path - - def set_verbose_launch(self, verbose: bool) -> None: - """Set the job to run in verbose mode - - This sets ``--verbose`` - - :param verbose: Whether the job should be run verbosely - """ - if verbose: - self.run_args["verbose"] = None - else: - self.run_args.pop("verbose", None) - - def set_quiet_launch(self, quiet: bool) -> None: - """Set the job to run in quiet mode - - This 
sets ``--quiet`` - - :param quiet: Whether the job should be run quietly - """ - if quiet: - self.run_args["quiet"] = None - else: - self.run_args.pop("quiet", None) - - def set_broadcast(self, dest_path: t.Optional[str] = None) -> None: - """Copy the specified executable(s) to remote machines - - This sets ``--preload-binary`` - - :param dest_path: Destination path (Ignored) - """ - if dest_path is not None and isinstance(dest_path, str): - logger.warning( - ( - f"{type(self)} cannot set a destination path during broadcast. " - "Using session directory instead" - ) - ) - self.run_args["preload-binary"] = None - - def set_walltime(self, walltime: str) -> None: - """Set the maximum number of seconds that a job will run - - This sets ``--timeout`` - - :param walltime: number like string of seconds that a job will run in secs - """ - self.run_args["timeout"] = walltime - - def format_run_args(self) -> t.List[str]: - """Return a list of MPI-standard formatted run arguments - - :return: list of MPI-standard arguments for these settings - """ - # args launcher uses - args = [] - restricted = ["wdir", "wd"] - - for opt, value in self.run_args.items(): - if opt not in restricted: - prefix = "--" - if not value: - args += [prefix + opt] - else: - args += [prefix + opt, str(value)] - return args - - def format_env_vars(self) -> t.List[str]: - """Format the environment variables for mpirun - - :return: list of env vars - """ - formatted = [] - env_string = "-x" - - if self.env_vars: - for name, value in self.env_vars.items(): - if value: - formatted += [env_string, "=".join((name, str(value)))] - else: - formatted += [env_string, name] - return formatted - - -class MpirunSettings(_BaseMPISettings): - def __init__( - self, - exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - **kwargs: t.Any, - ) -> None: - """Settings to 
run job with ``mpirun`` command (MPI-standard) - - Note that environment variables can be passed with a None - value to signify that they should be exported from the current - environment - - Any arguments passed in the ``run_args`` dict will be converted - into ``mpirun`` arguments and prefixed with ``--``. Values of - None can be provided for arguments that do not have values. - - :param exe: executable - :param exe_args: executable arguments - :param run_args: arguments for run command - :param env_vars: environment vars to launch job with - """ - super().__init__(exe, exe_args, "mpirun", run_args, env_vars, **kwargs) - - -class MpiexecSettings(_BaseMPISettings): - def __init__( - self, - exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - **kwargs: t.Any, - ) -> None: - """Settings to run job with ``mpiexec`` command (MPI-standard) - - Note that environment variables can be passed with a None - value to signify that they should be exported from the current - environment - - Any arguments passed in the ``run_args`` dict will be converted - into ``mpiexec`` arguments and prefixed with ``--``. Values of - None can be provided for arguments that do not have values. - - :param exe: executable - :param exe_args: executable arguments - :param run_args: arguments for run command - :param env_vars: environment vars to launch job with - """ - super().__init__(exe, exe_args, "mpiexec", run_args, env_vars, **kwargs) - - completed_process = subprocess.run( - [self._run_command, "--help"], capture_output=True, check=False - ) - help_statement = completed_process.stdout.decode() - if "mpiexec.slurm" in help_statement: - raise SSUnsupportedError( - "Slurm's wrapper for mpiexec is unsupported. 
Use slurmSettings instead" - ) - - -class OrterunSettings(_BaseMPISettings): - def __init__( - self, - exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - **kwargs: t.Any, - ) -> None: - """Settings to run job with ``orterun`` command (MPI-standard) - - Note that environment variables can be passed with a None - value to signify that they should be exported from the current - environment - - Any arguments passed in the ``run_args`` dict will be converted - into ``orterun`` arguments and prefixed with ``--``. Values of - None can be provided for arguments that do not have values. - - :param exe: executable - :param exe_args: executable arguments - :param run_args: arguments for run command - :param env_vars: environment vars to launch job with - """ - super().__init__(exe, exe_args, "orterun", run_args, env_vars, **kwargs) diff --git a/smartsim/settings/palsSettings.py b/smartsim/settings/palsSettings.py deleted file mode 100644 index 4100e8efeb..0000000000 --- a/smartsim/settings/palsSettings.py +++ /dev/null @@ -1,233 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. 
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import typing as t - -from ..log import get_logger -from .mpiSettings import _BaseMPISettings - -logger = get_logger(__name__) - - -class PalsMpiexecSettings(_BaseMPISettings): - """Settings to run job with ``mpiexec`` under the HPE Cray - Parallel Application Launch Service (PALS) - - Note that environment variables can be passed with a None - value to signify that they should be exported from the current - environment - - Any arguments passed in the ``run_args`` dict will be converted - into ``mpiexec`` arguments and prefixed with ``--``. Values of - None can be provided for arguments that do not have values. 
- - :param exe: executable - :param exe_args: executable arguments - :param run_args: arguments for run command - :param env_vars: environment vars to launch job with - """ - - def __init__( - self, - exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - fail_if_missing_exec: bool = True, - **kwargs: t.Any, - ) -> None: - """Settings to format run job with an MPI-standard binary - - Note that environment variables can be passed with a None - value to signify that they should be exported from the current - environment - - Any arguments passed in the ``run_args`` dict will be converted - command line arguments and prefixed with ``--``. Values of - None can be provided for arguments that do not have values. - - :param exe: executable - :param exe_args: executable arguments - :param run_args: arguments for run command - :param env_vars: environment vars to launch job with - :param fail_if_missing_exec: Throw an exception of the MPI command - is missing. Otherwise, throw a warning - """ - super().__init__( - exe, - exe_args, - run_command="mpiexec", - run_args=run_args, - env_vars=env_vars, - fail_if_missing_exec=fail_if_missing_exec, - **kwargs, - ) - - def set_task_map(self, task_mapping: str) -> None: - """Set ``mpirun`` task mapping - - this sets ``--map-by `` - - For examples, see the man page for ``mpirun`` - - :param task_mapping: task mapping - """ - logger.warning("set_task_map not supported under PALS") - - def set_cpus_per_task(self, cpus_per_task: int) -> None: - """Set the number of tasks for this job - - This sets ``--cpus-per-proc`` for MPI compliant implementations - - note: this option has been deprecated in openMPI 4.0+ - and will soon be replaced. 
- - :param cpus_per_task: number of tasks - """ - logger.warning("set_cpus_per_task not supported under PALS") - - def set_cpu_binding_type(self, bind_type: str) -> None: - """Specifies the cores to which MPI processes are bound - - This sets ``--bind-to`` for MPI compliant implementations - - :param bind_type: binding type - """ - self.run_args["cpu-bind"] = bind_type - - def set_tasks(self, tasks: int) -> None: - """Set the number of tasks - - :param tasks: number of total tasks to launch - """ - self.run_args["np"] = int(tasks) - - def set_tasks_per_node(self, tasks_per_node: int) -> None: - """Set the number of tasks per node - - :param tasks_per_node: number of tasks to launch per node - """ - self.run_args["ppn"] = int(tasks_per_node) - - def set_quiet_launch(self, quiet: bool) -> None: - """Set the job to run in quiet mode - - This sets ``--quiet`` - - :param quiet: Whether the job should be run quietly - """ - - logger.warning("set_quiet_launch not supported under PALS") - - def set_broadcast(self, dest_path: t.Optional[str] = None) -> None: - """Copy the specified executable(s) to remote machines - - This sets ``--preload-binary`` - - :param dest_path: Destination path (Ignored) - """ - if dest_path is not None and isinstance(dest_path, str): - logger.warning( - ( - f"{type(self)} cannot set a destination path during broadcast. 
" - "Using session directory instead" - ) - ) - self.run_args["transfer"] = None - - def set_walltime(self, walltime: str) -> None: - """Set the maximum number of seconds that a job will run - - :param walltime: number like string of seconds that a job will run in secs - """ - logger.warning("set_walltime not supported under PALS") - - def set_gpu_affinity_script(self, affinity: str, *args: t.Any) -> None: - """Set the GPU affinity through a bash script - - :param affinity: path to the affinity script - """ - self.affinity_script.append(str(affinity)) - for arg in args: - self.affinity_script.append(str(arg)) - - def format_run_args(self) -> t.List[str]: - """Return a list of MPI-standard formatted run arguments - - :return: list of MPI-standard arguments for these settings - """ - # args launcher uses - args = [] - restricted = ["wdir", "wd"] - - for opt, value in self.run_args.items(): - if opt not in restricted: - prefix = "--" - if not value: - args += [prefix + opt] - else: - args += [prefix + opt, str(value)] - - if self.affinity_script: - args += self.affinity_script - - return args - - def format_env_vars(self) -> t.List[str]: - """Format the environment variables for mpirun - - :return: list of env vars - """ - formatted = [] - - export_vars = [] - if self.env_vars: - for name, value in self.env_vars.items(): - if value: - formatted += ["--env", "=".join((name, str(value)))] - else: - export_vars.append(name) - - if export_vars: - formatted += ["--envlist", ",".join(export_vars)] - - return formatted - - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: - """Set the hostlist for the PALS ``mpiexec`` command - - This sets ``--hosts`` - - :param host_list: list of host names - :raises TypeError: if not str or list of str - """ - if isinstance(host_list, str): - host_list = [host_list.strip()] - if not isinstance(host_list, list): - raise TypeError("host_list argument must be a list of strings") - if not all(isinstance(host, str) for host 
in host_list): - raise TypeError("host_list argument must be list of strings") - self.run_args["hosts"] = ",".join(host_list) diff --git a/smartsim/settings/pbsSettings.py b/smartsim/settings/pbsSettings.py deleted file mode 100644 index 09d48181a2..0000000000 --- a/smartsim/settings/pbsSettings.py +++ /dev/null @@ -1,263 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -import typing as t - -from ..error import SSConfigError -from ..log import get_logger -from .base import BatchSettings - -logger = get_logger(__name__) - - -class QsubBatchSettings(BatchSettings): - def __init__( - self, - nodes: t.Optional[int] = None, - ncpus: t.Optional[int] = None, - time: t.Optional[str] = None, - queue: t.Optional[str] = None, - account: t.Optional[str] = None, - resources: t.Optional[t.Dict[str, t.Union[str, int]]] = None, - batch_args: t.Optional[t.Dict[str, t.Optional[str]]] = None, - **kwargs: t.Any, - ): - """Specify ``qsub`` batch parameters for a job - - ``nodes``, and ``ncpus`` are used to create the - select statement for PBS if a select statement is not - included in the ``resources``. If both are supplied - the value for select statement supplied in ``resources`` - will override. - - :param nodes: number of nodes for batch - :param ncpus: number of cpus per node - :param time: walltime for batch job - :param queue: queue to run batch in - :param account: account for batch launch - :param resources: overrides for resource arguments - :param batch_args: overrides for PBS batch arguments - """ - - self._ncpus = ncpus - - self.resources = resources or {} - resource_nodes = self.resources.get("nodes", None) - - if nodes and resource_nodes: - raise ValueError( - "nodes was incorrectly specified as a constructor parameter and also " - "as a key in the resource mapping" - ) - - # time, queue, nodes, and account set in parent class init - super().__init__( - "qsub", - batch_args=batch_args, - nodes=nodes, - account=account, - queue=queue, - time=time, - **kwargs, - ) - - self._hosts: t.List[str] = [] - - @property - def resources(self) -> t.Dict[str, t.Union[str, int]]: - return self._resources.copy() - - @resources.setter - def resources(self, resources: t.Dict[str, t.Union[str, int]]) -> None: - self._sanity_check_resources(resources) - self._resources = resources.copy() - - def set_nodes(self, num_nodes: int) -> None: - """Set the 
number of nodes for this batch job - - In PBS, 'select' is the more primitive way of describing how - many nodes to allocate for the job. 'nodes' is equivalent to - 'select' with a 'place' statement. Assuming that only advanced - users would use 'set_resource' instead, defining the number of - nodes here is sets the 'nodes' resource. - - :param num_nodes: number of nodes - """ - - if num_nodes: - self.set_resource("nodes", num_nodes) - - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: - """Specify the hostlist for this job - - :param host_list: hosts to launch on - :raises TypeError: if not str or list of str - """ - if isinstance(host_list, str): - host_list = [host_list.strip()] - if not isinstance(host_list, list): - raise TypeError("host_list argument must be a list of strings") - if not all(isinstance(host, str) for host in host_list): - raise TypeError("host_list argument must be a list of strings") - self._hosts = host_list - - def set_walltime(self, walltime: str) -> None: - """Set the walltime of the job - - format = "HH:MM:SS" - - If a walltime argument is provided in - ``QsubBatchSettings.resources``, then - this value will be overridden - - :param walltime: wall time - """ - if walltime: - self.set_resource("walltime", walltime) - - def set_queue(self, queue: str) -> None: - """Set the queue for the batch job - - :param queue: queue name - """ - if queue: - self.batch_args["q"] = str(queue) - - def set_ncpus(self, num_cpus: t.Union[int, str]) -> None: - """Set the number of cpus obtained in each node. 
- - If a select argument is provided in - ``QsubBatchSettings.resources``, then - this value will be overridden - - :param num_cpus: number of cpus per node in select - """ - self._ncpus = int(num_cpus) - - def set_account(self, account: str) -> None: - """Set the account for this batch job - - :param acct: account id - """ - if account: - self.batch_args["A"] = str(account) - - def set_resource(self, resource_name: str, value: t.Union[str, int]) -> None: - """Set a resource value for the Qsub batch - - If a select statement is provided, the nodes and ncpus - arguments will be overridden. Likewise for Walltime - - :param resource_name: name of resource, e.g. walltime - :param value: value - """ - # TODO add error checking here - # TODO include option to overwrite place (warning for orchestrator?) - updated_dict = self.resources - updated_dict.update({resource_name: value}) - self._sanity_check_resources(updated_dict) - self.resources = updated_dict - - def format_batch_args(self) -> t.List[str]: - """Get the formatted batch arguments for a preview - - :return: batch arguments for Qsub - :raises ValueError: if options are supplied without values - """ - opts = self._create_resource_list() - for opt, value in self.batch_args.items(): - prefix = "-" - if not value: - raise ValueError("PBS options without values are not allowed") - opts += [" ".join((prefix + opt, str(value)))] - return opts - - def _sanity_check_resources( - self, resources: t.Optional[t.Dict[str, t.Union[str, int]]] = None - ) -> None: - """Check that only select or nodes was specified in resources - - Note: For PBS Pro, nodes is equivalent to 'select' and 'place' so - they are not quite synonyms. 
Here we assume that - """ - # Note: isinstance check here to avoid collision with default - checked_resources = resources if isinstance(resources, dict) else self.resources - - has_select = checked_resources.get("select", None) - has_nodes = checked_resources.get("nodes", None) - - if has_select and has_nodes: - raise SSConfigError( - "'select' and 'nodes' cannot both be specified. This can happen " - "if nodes were specified using the 'set_nodes' method and " - "'select' was set using 'set_resource'. Please only specify one." - ) - - if has_select and not isinstance(has_select, int): - raise TypeError("The value for 'select' must be an integer") - if has_nodes and not isinstance(has_nodes, int): - raise TypeError("The value for 'nodes' must be an integer") - - for key, value in checked_resources.items(): - if not isinstance(key, str): - raise TypeError( - f"The type of {key=} is {type(key)}. Only int and str " - "are allowed." - ) - if not isinstance(value, (str, int)): - raise TypeError( - f"The value associated with {key=} is {type(value)}. Only int " - "and str are allowed." 
- ) - - def _create_resource_list(self) -> t.List[str]: - self._sanity_check_resources() - res = [] - - # Pop off some specific keywords that need to be treated separately - resources = self.resources # Note this is a copy so not modifying original - - # Construct the basic select/nodes statement - if select := resources.pop("select", None): - select_command = f"-l select={select}" - elif nodes := resources.pop("nodes", None): - select_command = f"-l nodes={nodes}" - else: - raise SSConfigError( - "Insufficient resource specification: no nodes or select statement" - ) - if self._ncpus: - select_command += f":ncpus={self._ncpus}" - if self._hosts: - hosts = ["=".join(("host", str(host))) for host in self._hosts] - select_command += f":{'+'.join(hosts)}" - res += [select_command] - - # All other "standard" resource specs - for resource, value in resources.items(): - res += [f"-l {resource}={value}"] - - return res diff --git a/smartsim/settings/settings.py b/smartsim/settings/settings.py deleted file mode 100644 index 5afd0e1929..0000000000 --- a/smartsim/settings/settings.py +++ /dev/null @@ -1,205 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import typing as t - -from .._core.utils.helpers import is_valid_cmd -from ..error import SmartSimError -from ..settings import ( - AprunSettings, - BsubBatchSettings, - Container, - DragonRunSettings, - JsrunSettings, - MpiexecSettings, - MpirunSettings, - OrterunSettings, - PalsMpiexecSettings, - QsubBatchSettings, - RunSettings, - SbatchSettings, - SgeQsubBatchSettings, - SrunSettings, - base, -) -from ..wlm import detect_launcher - -_TRunSettingsSelector = t.Callable[[str], t.Callable[..., RunSettings]] - - -def create_batch_settings( - launcher: str, - nodes: t.Optional[int] = None, - time: str = "", - queue: t.Optional[str] = None, - account: t.Optional[str] = None, - batch_args: t.Optional[t.Dict[str, str]] = None, - **kwargs: t.Any, -) -> base.BatchSettings: - """Create a ``BatchSettings`` instance - - See Experiment.create_batch_settings for details - - :param launcher: launcher for this experiment, if set to 'auto', - an attempt will be made to find an available launcher on the system - :param nodes: number of nodes for batch job - :param time: length of batch job - :param queue: queue or partition (if slurm) - :param account: user account name for batch system - :param batch_args: additional batch arguments - :return: a newly created BatchSettings instance - :raises SmartSimError: if batch creation fails - """ - # all supported batch class implementations - by_launcher: t.Dict[str, t.Callable[..., base.BatchSettings]] = { - 
"pbs": QsubBatchSettings, - "slurm": SbatchSettings, - "lsf": BsubBatchSettings, - "pals": QsubBatchSettings, - "sge": SgeQsubBatchSettings, - } - - if launcher in ["auto", "dragon"]: - launcher = detect_launcher() - if launcher == "dragon": - by_launcher["dragon"] = by_launcher[launcher] - - if launcher == "local": - raise SmartSimError("Local launcher does not support batch workloads") - - # detect the batch class to use based on the launcher provided by - # the user - try: - batch_class = by_launcher[launcher] - batch_settings = batch_class( - nodes=nodes, - time=time, - batch_args=batch_args, - queue=queue, - account=account, - **kwargs, - ) - return batch_settings - - except KeyError: - raise SmartSimError( - f"User attempted to make batch settings for unsupported launcher {launcher}" - ) from None - - -def create_run_settings( - launcher: str, - exe: str, - exe_args: t.Optional[t.List[str]] = None, - run_command: str = "auto", - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - container: t.Optional[Container] = None, - **kwargs: t.Any, -) -> RunSettings: - """Create a ``RunSettings`` instance. - - See Experiment.create_run_settings docstring for more details - - :param launcher: launcher to create settings for, if set to 'auto', - an attempt will be made to find an available launcher on the system - :param run_command: command to run the executable - :param exe: executable to run - :param exe_args: arguments to pass to the executable - :param run_args: arguments to pass to the ``run_command`` - :param env_vars: environment variables to pass to the executable - :param container: container type for workload (e.g. 
"singularity") - :return: the created ``RunSettings`` - :raises SmartSimError: if run_command=="auto" and detection fails - """ - # all supported RunSettings child classes - supported: t.Dict[str, _TRunSettingsSelector] = { - "aprun": lambda launcher: AprunSettings, - "srun": lambda launcher: SrunSettings, - "mpirun": lambda launcher: MpirunSettings, - "mpiexec": lambda launcher: ( - MpiexecSettings if launcher != "pals" else PalsMpiexecSettings - ), - "orterun": lambda launcher: OrterunSettings, - "jsrun": lambda launcher: JsrunSettings, - } - - # run commands supported by each launcher - # in order of suspected user preference - by_launcher = { - "dragon": [""], - "slurm": ["srun", "mpirun", "mpiexec"], - "pbs": ["aprun", "mpirun", "mpiexec"], - "pals": ["mpiexec"], - "lsf": ["jsrun", "mpirun", "mpiexec"], - "sge": ["mpirun", "mpiexec"], - "local": [""], - } - - if launcher == "auto": - launcher = detect_launcher() - - def _detect_command(launcher: str) -> str: - if launcher in by_launcher: - if launcher in ["local", "dragon"]: - return "" - - for cmd in by_launcher[launcher]: - if is_valid_cmd(cmd): - return cmd - msg = ( - "Could not automatically detect a run command to use for launcher " - f"{launcher}\nSearched for and could not find the following " - f"commands: {by_launcher[launcher]}" - ) - raise SmartSimError(msg) - - if run_command: - run_command = run_command.lower() - launcher = launcher.lower() - - # detect run_command automatically for all but local launcher - if run_command == "auto": - # no auto detection for local, revert to false - run_command = _detect_command(launcher) - - if launcher == "dragon": - return DragonRunSettings( - exe=exe, exe_args=exe_args, env_vars=env_vars, container=container, **kwargs - ) - - # if user specified and supported or auto detection worked - if run_command and run_command in supported: - return supported[run_command](launcher)( - exe, exe_args, run_args, env_vars, container=container, **kwargs - ) - - # 1) user 
specified and not implementation in SmartSim - # 2) user supplied run_command=None - # 3) local launcher being used and default of "auto" was passed. - return RunSettings( - exe, exe_args, run_command, run_args, env_vars, container=container - ) diff --git a/smartsim/settings/sgeSettings.py b/smartsim/settings/sge_settings.py similarity index 100% rename from smartsim/settings/sgeSettings.py rename to smartsim/settings/sge_settings.py diff --git a/smartsim/settings/slurmSettings.py b/smartsim/settings/slurmSettings.py deleted file mode 100644 index 64f73fa9c5..0000000000 --- a/smartsim/settings/slurmSettings.py +++ /dev/null @@ -1,513 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import annotations - -import datetime -import os -import typing as t - -from ..error import SSUnsupportedError -from ..log import get_logger -from .base import BatchSettings, RunSettings - -logger = get_logger(__name__) - - -class SrunSettings(RunSettings): - def __init__( - self, - exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - alloc: t.Optional[str] = None, - **kwargs: t.Any, - ) -> None: - """Initialize run parameters for a slurm job with ``srun`` - - ``SrunSettings`` should only be used on Slurm based systems. - - If an allocation is specified, the instance receiving these run - parameters will launch on that allocation. 
- - :param exe: executable to run - :param exe_args: executable arguments - :param run_args: srun arguments without dashes - :param env_vars: environment variables for job - :param alloc: allocation ID if running on existing alloc - """ - super().__init__( - exe, - exe_args, - run_command="srun", - run_args=run_args, - env_vars=env_vars, - **kwargs, - ) - self.alloc = alloc - self.mpmd: t.List[RunSettings] = [] - - reserved_run_args = {"chdir", "D"} - - def set_nodes(self, nodes: int) -> None: - """Set the number of nodes - - Effectively this is setting: ``srun --nodes `` - - :param nodes: number of nodes to run with - """ - self.run_args["nodes"] = int(nodes) - - def make_mpmd(self, settings: RunSettings) -> None: - """Make a mpmd workload by combining two ``srun`` commands - - This connects the two settings to be executed with a single - Model instance - - :param settings: SrunSettings instance - """ - if self.colocated_db_settings: - raise SSUnsupportedError( - "Colocated models cannot be run as a mpmd workload" - ) - if self.container: - raise SSUnsupportedError( - "Containerized MPMD workloads are not yet supported." - ) - if os.getenv("SLURM_HET_SIZE") is not None: - raise ValueError( - "Slurm does not support MPMD workloads in heterogeneous jobs." 
- ) - self.mpmd.append(settings) - - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: - """Specify the hostlist for this job - - This sets ``--nodelist`` - - :param host_list: hosts to launch on - :raises TypeError: if not str or list of str - """ - if isinstance(host_list, str): - host_list = [host_list.strip()] - if not isinstance(host_list, list): - raise TypeError("host_list argument must be a list of strings") - if not all(isinstance(host, str) for host in host_list): - raise TypeError("host_list argument must be list of strings") - self.run_args["nodelist"] = ",".join(host_list) - - def set_hostlist_from_file(self, file_path: str) -> None: - """Use the contents of a file to set the node list - - This sets ``--nodefile`` - - :param file_path: Path to the hostlist file - """ - self.run_args["nodefile"] = file_path - - def set_excluded_hosts(self, host_list: t.Union[str, t.List[str]]) -> None: - """Specify a list of hosts to exclude for launching this job - - :param host_list: hosts to exclude - :raises TypeError: - """ - if isinstance(host_list, str): - host_list = [host_list.strip()] - if not isinstance(host_list, list): - raise TypeError("host_list argument must be a list of strings") - if not all(isinstance(host, str) for host in host_list): - raise TypeError("host_list argument must be list of strings") - self.run_args["exclude"] = ",".join(host_list) - - def set_cpus_per_task(self, cpus_per_task: int) -> None: - """Set the number of cpus to use per task - - This sets ``--cpus-per-task`` - - :param num_cpus: number of cpus to use per task - """ - self.run_args["cpus-per-task"] = int(cpus_per_task) - - def set_tasks(self, tasks: int) -> None: - """Set the number of tasks for this job - - This sets ``--ntasks`` - - :param tasks: number of tasks - """ - self.run_args["ntasks"] = int(tasks) - - def set_tasks_per_node(self, tasks_per_node: int) -> None: - """Set the number of tasks for this job - - This sets ``--ntasks-per-node`` - - :param 
tasks_per_node: number of tasks per node - """ - self.run_args["ntasks-per-node"] = int(tasks_per_node) - - def set_cpu_bindings(self, bindings: t.Union[int, t.List[int]]) -> None: - """Bind by setting CPU masks on tasks - - This sets ``--cpu-bind`` using the ``map_cpu:`` option - - :param bindings: List specifing the cores to which MPI processes are bound - """ - if isinstance(bindings, int): - bindings = [bindings] - self.run_args["cpu_bind"] = "map_cpu:" + ",".join( - str(int(num)) for num in bindings - ) - - def set_memory_per_node(self, memory_per_node: int) -> None: - """Specify the real memory required per node - - This sets ``--mem`` in megabytes - - :param memory_per_node: Amount of memory per node in megabytes - """ - self.run_args["mem"] = f"{int(memory_per_node)}M" - - def set_verbose_launch(self, verbose: bool) -> None: - """Set the job to run in verbose mode - - This sets ``--verbose`` - - :param verbose: Whether the job should be run verbosely - """ - if verbose: - self.run_args["verbose"] = None - else: - self.run_args.pop("verbose", None) - - def set_quiet_launch(self, quiet: bool) -> None: - """Set the job to run in quiet mode - - This sets ``--quiet`` - - :param quiet: Whether the job should be run quietly - """ - if quiet: - self.run_args["quiet"] = None - else: - self.run_args.pop("quiet", None) - - def set_broadcast(self, dest_path: t.Optional[str] = None) -> None: - """Copy executable file to allocated compute nodes - - This sets ``--bcast`` - - :param dest_path: Path to copy an executable file - """ - self.run_args["bcast"] = dest_path - - def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: - """Specify the node feature for this job - - This sets ``-C`` - - :param feature_list: node feature to launch on - :raises TypeError: if not str or list of str - """ - if isinstance(feature_list, str): - feature_list = [feature_list.strip()] - elif not all(isinstance(feature, str) for feature in feature_list): - raise 
TypeError("node_feature argument must be string or list of strings") - self.run_args["C"] = ",".join(feature_list) - - @staticmethod - def _fmt_walltime(hours: int, minutes: int, seconds: int) -> str: - """Convert hours, minutes, and seconds into valid walltime format - - Converts time to format HH:MM:SS - - :param hours: number of hours to run job - :param minutes: number of minutes to run job - :param seconds: number of seconds to run job - :returns: Formatted walltime - """ - return fmt_walltime(hours, minutes, seconds) - - def set_walltime(self, walltime: str) -> None: - """Set the walltime of the job - - format = "HH:MM:SS" - - :param walltime: wall time - """ - self.run_args["time"] = str(walltime) - - def set_het_group(self, het_group: t.Iterable[int]) -> None: - """Set the heterogeneous group for this job - - this sets `--het-group` - - :param het_group: list of heterogeneous groups - """ - het_size_env = os.getenv("SLURM_HET_SIZE") - if het_size_env is None: - msg = "Requested to set het group, but the allocation is not a het job" - raise ValueError(msg) - - het_size = int(het_size_env) - if self.mpmd: - msg = "Slurm does not support MPMD workloads in heterogeneous jobs\n" - raise ValueError(msg) - msg = ( - "Support for heterogeneous groups is an experimental feature, " - "please report any unexpected behavior to SmartSim developers " - "by opening an issue on https://github.com/CrayLabs/SmartSim/issues" - ) - if any(group >= het_size for group in het_group): - msg = ( - f"Het group {max(het_group)} requested, " - f"but max het group in allocation is {het_size-1}" - ) - raise ValueError(msg) - logger.warning(msg) - self.run_args["het-group"] = ",".join(str(group) for group in het_group) - - def format_run_args(self) -> t.List[str]: - """Return a list of slurm formatted run arguments - - :return: list of slurm arguments for these settings - """ - # add additional slurm arguments based on key length - opts = [] - for opt, value in self.run_args.items(): - 
short_arg = bool(len(str(opt)) == 1) - prefix = "-" if short_arg else "--" - if not value: - opts += [prefix + opt] - else: - if short_arg: - opts += [prefix + opt, str(value)] - else: - opts += ["=".join((prefix + opt, str(value)))] - return opts - - def check_env_vars(self) -> None: - """Warn a user trying to set a variable which is set in the environment - - Given Slurm's env var precedence, trying to export a variable which is already - present in the environment will not work. - """ - for k, v in self.env_vars.items(): - if "," not in str(v): - # If a variable is defined, it will take precedence over --export - # we warn the user - preexisting_var = os.environ.get(k, None) - if preexisting_var is not None and preexisting_var != v: - msg = ( - f"Variable {k} is set to {preexisting_var} in current " - "environment. If the job is running in an interactive " - f"allocation, the value {v} will not be set. Please " - "consider removing the variable from the environment " - "and re-running the experiment." - ) - logger.warning(msg) - - def format_env_vars(self) -> t.List[str]: - """Build bash compatible environment variable string for Slurm - - :returns: the formatted string of environment variables - """ - self.check_env_vars() - return [f"{k}={v}" for k, v in self.env_vars.items() if "," not in str(v)] - - def format_comma_sep_env_vars(self) -> t.Tuple[str, t.List[str]]: - """Build environment variable string for Slurm - - Slurm takes exports in comma separated lists - the list starts with all as to not disturb the rest of the environment - for more information on this, see the slurm documentation for srun - - :returns: the formatted string of environment variables - """ - self.check_env_vars() - exportable_env, compound_env, key_only = [], [], [] - - for k, v in self.env_vars.items(): - kvp = f"{k}={v}" - - if "," in str(v): - key_only.append(k) - compound_env.append(kvp) - else: - exportable_env.append(kvp) - - # Append keys to exportable KVPs, e.g. 
`--export x1=v1,KO1,KO2` - fmt_exported_env = ",".join(v for v in exportable_env + key_only) - - for mpmd in self.mpmd: - compound_mpmd_env = { - k: v for k, v in mpmd.env_vars.items() if "," in str(v) - } - compound_mpmd_fmt = {f"{k}={v}" for k, v in compound_mpmd_env.items()} - compound_env.extend(compound_mpmd_fmt) - - return fmt_exported_env, compound_env - - -def fmt_walltime(hours: int, minutes: int, seconds: int) -> str: - """Helper function walltime format conversion - - Converts time to format HH:MM:SS - - :param hours: number of hours to run job - :param minutes: number of minutes to run job - :param seconds: number of seconds to run job - :returns: Formatted walltime - """ - delta = datetime.timedelta(hours=hours, minutes=minutes, seconds=seconds) - fmt_str = str(delta) - if delta.seconds // 3600 < 10: - fmt_str = "0" + fmt_str - return fmt_str - - -class SbatchSettings(BatchSettings): - def __init__( - self, - nodes: t.Optional[int] = None, - time: str = "", - account: t.Optional[str] = None, - batch_args: t.Optional[t.Dict[str, t.Optional[str]]] = None, - **kwargs: t.Any, - ) -> None: - """Specify run parameters for a Slurm batch job - - Slurm `sbatch` arguments can be written into ``batch_args`` - as a dictionary. e.g. {'ntasks': 1} - - If the argument doesn't have a parameter, put `None` - as the value. e.g. {'exclusive': None} - - Initialization values provided (nodes, time, account) - will overwrite the same arguments in ``batch_args`` if present - - :param nodes: number of nodes - :param time: walltime for job, e.g. 
"10:00:00" for 10 hours - :param account: account for job - :param batch_args: extra batch arguments - """ - super().__init__( - "sbatch", - batch_args=batch_args, - nodes=nodes, - account=account, - time=time, - **kwargs, - ) - - def set_walltime(self, walltime: str) -> None: - """Set the walltime of the job - - format = "HH:MM:SS" - - :param walltime: wall time - """ - # TODO check for formatting here - if walltime: - self.batch_args["time"] = walltime - - def set_nodes(self, num_nodes: int) -> None: - """Set the number of nodes for this batch job - - :param num_nodes: number of nodes - """ - if num_nodes: - self.batch_args["nodes"] = str(int(num_nodes)) - - def set_account(self, account: str) -> None: - """Set the account for this batch job - - :param account: account id - """ - if account: - self.batch_args["account"] = account - - def set_partition(self, partition: str) -> None: - """Set the partition for the batch job - - :param partition: partition name - """ - self.batch_args["partition"] = str(partition) - - def set_queue(self, queue: str) -> None: - """alias for set_partition - - Sets the partition for the slurm batch job - - :param queue: the partition to run the batch job on - """ - if queue: - self.set_partition(queue) - - def set_cpus_per_task(self, cpus_per_task: int) -> None: - """Set the number of cpus to use per task - - This sets ``--cpus-per-task`` - - :param num_cpus: number of cpus to use per task - """ - self.batch_args["cpus-per-task"] = str(int(cpus_per_task)) - - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: - """Specify the hostlist for this job - - :param host_list: hosts to launch on - :raises TypeError: if not str or list of str - """ - if isinstance(host_list, str): - host_list = [host_list.strip()] - if not isinstance(host_list, list): - raise TypeError("host_list argument must be a list of strings") - if not all(isinstance(host, str) for host in host_list): - raise TypeError("host_list argument must be list 
of strings") - self.batch_args["nodelist"] = ",".join(host_list) - - def format_batch_args(self) -> t.List[str]: - """Get the formatted batch arguments for a preview - - :return: batch arguments for Sbatch - """ - opts = [] - # TODO add restricted here - for opt, value in self.batch_args.items(): - # attach "-" prefix if argument is 1 character otherwise "--" - short_arg = bool(len(str(opt)) == 1) - prefix = "-" if short_arg else "--" - - if not value: - opts += [prefix + opt] - else: - if short_arg: - opts += [prefix + opt, str(value)] - else: - opts += ["=".join((prefix + opt, str(value)))] - return opts diff --git a/smartsim/status.py b/smartsim/status.py index e0d950619c..e631a454d1 100644 --- a/smartsim/status.py +++ b/smartsim/status.py @@ -27,19 +27,23 @@ from enum import Enum -class SmartSimStatus(Enum): - STATUS_RUNNING = "Running" - STATUS_COMPLETED = "Completed" - STATUS_CANCELLED = "Cancelled" - STATUS_FAILED = "Failed" - STATUS_NEW = "New" - STATUS_PAUSED = "Paused" - STATUS_NEVER_STARTED = "NeverStarted" - STATUS_QUEUED = "Queued" +class JobStatus(Enum): + UNKNOWN = "Unknown" + RUNNING = "Running" + COMPLETED = "Completed" + CANCELLED = "Cancelled" + FAILED = "Failed" + NEW = "New" + PAUSED = "Paused" + QUEUED = "Queued" + + +class InvalidJobStatus(Enum): + NEVER_STARTED = "Never Started" TERMINAL_STATUSES = { - SmartSimStatus.STATUS_CANCELLED, - SmartSimStatus.STATUS_COMPLETED, - SmartSimStatus.STATUS_FAILED, + JobStatus.CANCELLED, + JobStatus.COMPLETED, + JobStatus.FAILED, } diff --git a/smartsim/templates/templates/preview/plain_text/activeinfra.template b/smartsim/templates/templates/preview/plain_text/activeinfra.template index 8f403fbc07..3e9ed6a2eb 100644 --- a/smartsim/templates/templates/preview/plain_text/activeinfra.template +++ b/smartsim/templates/templates/preview/plain_text/activeinfra.template @@ -1,9 +1,9 @@ - = Database Identifier: {{ db.entity.db_identifier }} = - Shards: {{ db.entity.num_shards }} + = Feature Store Identifier: {{ 
fs.entity.fs_identifier }} = + Shards: {{ fs.entity.num_shards }} TCP/IP Port(s): - {%- for port in db.entity.ports %} + {%- for port in fs.entity.ports %} {{ port }} {%- endfor %} - Network Interface: {{ db.entity.run_settings.exe_args | get_ifname }} - Type: {{ config.database_cli | get_dbtype }} + Network Interface: {{ fs.entity.run_settings.exe_args | get_ifname }} + Type: {{ config.database_cli | get_fstype }} diff --git a/smartsim/templates/templates/preview/plain_text/base.template b/smartsim/templates/templates/preview/plain_text/base.template index 5117125543..5686b86768 100644 --- a/smartsim/templates/templates/preview/plain_text/base.template +++ b/smartsim/templates/templates/preview/plain_text/base.template @@ -1,22 +1,22 @@ {% include "experiment.template" %} -{%- if manifest.has_deployable or active_dbjobs %} +{%- if manifest.has_deployable or active_fsjobs %} === Entity Preview === - {%- if active_dbjobs %} + {%- if active_fsjobs %} == Active Infrastructure == - {%- for name, db in active_dbjobs.items() %} + {%- for name, fs in active_fsjobs.items() %} {% include "activeinfra.template" %} {%- endfor %} {%- endif %} - {%- if manifest.dbs %} + {%- if manifest.fss %} - == Orchestrators == - {%- for db in manifest.dbs %} - {%- if db.is_active() %} - WARNING: Cannot preview {{ db.name }}, because it is already started. + == Feature Stores == + {%- for fs in manifest.fss %} + {%- if fs.is_active() %} + WARNING: Cannot preview {{ fs.name }}, because it is already started. 
{%- else %} {% include "orchestrator.template" %} {%- endif %} @@ -29,12 +29,12 @@ = Model Name: {{ model.name }} = {%- include "model.template" %} - {%- if model.run_settings.colocated_db_settings or manifest.dbs %} + {%- if model.run_settings.colocated_fs_settings or manifest.fss %} Client Configuration: - {%- if model.run_settings.colocated_db_settings %} + {%- if model.run_settings.colocated_fs_settings %} {%- include "clientconfigcolo.template" %} {%- endif %} - {%- if manifest.dbs %} + {%- if manifest.fss %} {%- include "clientconfig.template" %} {%- endif %} {%- endif %} diff --git a/smartsim/templates/templates/preview/plain_text/clientconfig_debug.template b/smartsim/templates/templates/preview/plain_text/clientconfig_debug.template index 51dafd0d18..12e647cdc4 100644 --- a/smartsim/templates/templates/preview/plain_text/clientconfig_debug.template +++ b/smartsim/templates/templates/preview/plain_text/clientconfig_debug.template @@ -1,12 +1,12 @@ - {%- for db in manifest.dbs %} - {%- if db.name %} - Database Identifier: {{ db.name }} + {%- for fs in manifest.fss %} + {%- if fs.name %} + Feature StoreIdentifier: {{ fs.name }} {%- endif %} {%- if verbosity_level == Verbosity.DEBUG or verbosity_level == Verbosity.DEVELOPER %} - Database Backend: {{ config.database_cli | get_dbtype }} + Feature Store Backend: {{ config.database_cli | get_fstype }} TCP/IP Port(s): - {%- for port in db.ports %} + {%- for port in fs.ports %} {{ port }} {%- endfor %} Type: Standalone diff --git a/smartsim/templates/templates/preview/plain_text/clientconfig_info.template b/smartsim/templates/templates/preview/plain_text/clientconfig_info.template index 164f4bd4a8..998b687073 100644 --- a/smartsim/templates/templates/preview/plain_text/clientconfig_info.template +++ b/smartsim/templates/templates/preview/plain_text/clientconfig_info.template @@ -1,11 +1,11 @@ - {%- for db in manifest.dbs %} - {%- if db.name %} - Database Identifier: {{ db.name }} + {%- for fs in manifest.fss %} + 
{%- if fs.name %} + Feature Store Identifier: {{ fs.name }} {%- endif %} - Database Backend: {{ config.database_cli | get_dbtype }} + Feature Store Backend: {{ config.database_cli | get_fstype }} TCP/IP Port(s): - {%- for port in db.ports %} + {%- for port in fs.ports %} {{ port }} {%- endfor %} Type: Standalone diff --git a/smartsim/templates/templates/preview/plain_text/clientconfigcolo_debug.template b/smartsim/templates/templates/preview/plain_text/clientconfigcolo_debug.template index 303fd0dcaf..93ad8aa7bc 100644 --- a/smartsim/templates/templates/preview/plain_text/clientconfigcolo_debug.template +++ b/smartsim/templates/templates/preview/plain_text/clientconfigcolo_debug.template @@ -1,25 +1,25 @@ - {%- if model.run_settings.colocated_db_settings.db_identifier %} - Database Identifier: {{ model.run_settings.colocated_db_settings.db_identifier }} + {%- if model.run_settings.colocated_fs_settings.fs_identifier %} + Feature Store Identifier: {{ model.run_settings.colocated_fs_settings.fs_identifier }} {%- else %} - Database Identifier: N/A + Feature Store Identifier: N/A {%- endif %} - Database Backend: {{ config.database_cli | get_dbtype }} - {%- if model.run_settings.colocated_db_settings %} - {%- if model.run_settings.colocated_db_settings.port %} + Feature Store Backend: {{ config.database_cli | get_fstype }} + {%- if model.run_settings.colocated_fs_settings %} + {%- if model.run_settings.colocated_fs_settings.port %} Connection Type: TCP TCP/IP Port(s): - {{ model.run_settings.colocated_db_settings.port }} + {{ model.run_settings.colocated_fs_settings.port }} {%- endif %} - {%- if model.run_settings.colocated_db_settings.unix_socket %} + {%- if model.run_settings.colocated_fs_settings.unix_socket %} Connection Type: UDS - Unix Socket: {{ model.run_settings.colocated_db_settings.unix_socket }} + Unix Socket: {{ model.run_settings.colocated_fs_settings.unix_socket }} {%- endif %} - {%- if model.run_settings.colocated_db_settings.ifname %} - {%- if 
model.run_settings.colocated_db_settings.ifname | is_list %} - Network Interface Name: {{ model.run_settings.colocated_db_settings.ifname[0] }} + {%- if model.run_settings.colocated_fs_settings.ifname %} + {%- if model.run_settings.colocated_fs_settings.ifname | is_list %} + Network Interface Name: {{ model.run_settings.colocated_fs_settings.ifname[0] }} {%- else %} - Network Interface Name: {{ model.run_settings.colocated_db_settings.ifname }} + Network Interface Name: {{ model.run_settings.colocated_fs_settings.ifname }} {%- endif %} {%- endif %} Type: Colocated diff --git a/smartsim/templates/templates/preview/plain_text/clientconfigcolo_info.template b/smartsim/templates/templates/preview/plain_text/clientconfigcolo_info.template index e03d7ce3bd..3b630f85a9 100644 --- a/smartsim/templates/templates/preview/plain_text/clientconfigcolo_info.template +++ b/smartsim/templates/templates/preview/plain_text/clientconfigcolo_info.template @@ -1,16 +1,16 @@ - {%- if model.run_settings.colocated_db_settings.db_identifier %} - Database Identifier: {{ model.run_settings.colocated_db_settings.db_identifier }} + {%- if model.run_settings.colocated_fs_settings.fs_identifier %} + Feature Store Identifier: {{ model.run_settings.colocated_fs_settings.fs_identifier }} {%- endif %} - Database Backend: {{ config.database_cli | get_dbtype }} - {%- if model.run_settings.colocated_db_settings.port %} + Feature Store Backend: {{ config.database_cli | get_fstype }} + {%- if model.run_settings.colocated_fs_settings.port %} Connection Type: TCP TCP/IP Port(s): - {{ model.run_settings.colocated_db_settings.port }} + {{ model.run_settings.colocated_fs_settings.port }} {%- endif %} - {%- if model.run_settings.colocated_db_settings.unix_socket %} + {%- if model.run_settings.colocated_fs_settings.unix_socket %} Connection Type: UDS - Unix Socket: {{ model.run_settings.colocated_db_settings.unix_socket }} + Unix Socket: {{ model.run_settings.colocated_fs_settings.unix_socket }} {%- endif %} 
Type: Colocated {%- if model.query_key_prefixing() %} diff --git a/smartsim/templates/templates/preview/plain_text/ensemble_debug.template b/smartsim/templates/templates/preview/plain_text/ensemble_debug.template index 862db60328..c458813cae 100644 --- a/smartsim/templates/templates/preview/plain_text/ensemble_debug.template +++ b/smartsim/templates/templates/preview/plain_text/ensemble_debug.template @@ -32,12 +32,12 @@ - Model Name: {{ model.name }} - {%- include 'model.template' %} - {%- if model.run_settings.colocated_db_settings or manifest.dbs %} + {%- if model.run_settings.colocated_fs_settings or manifest.fss %} Client Configuration: - {%- if model.run_settings.colocated_db_settings %} + {%- if model.run_settings.colocated_fs_settings %} {%- include "clientconfigcolo.template" %} {%- endif %} - {%- if manifest.dbs %} + {%- if manifest.fss %} {%- include "clientconfig.template" %} {%- endif %} {%- endif %} @@ -48,12 +48,12 @@ - Model Name: {{ model.name }} - {%- include 'model_debug.template' %} - {%- if model.run_settings.colocated_db_settings or manifest.dbs %} + {%- if model.run_settings.colocated_fs_settings or manifest.fss %} Client Configuration: - {%- if model.run_settings.colocated_db_settings %} + {%- if model.run_settings.colocated_fs_settings %} {%- include "clientconfigcolo.template" %} {%- endif %} - {%- if manifest.dbs %} + {%- if manifest.fss %} {%- include "clientconfig.template" %} {%- endif %} {%- endif %} diff --git a/smartsim/templates/templates/preview/plain_text/ensemble_info.template b/smartsim/templates/templates/preview/plain_text/ensemble_info.template index 17d1a40547..a7b9c22968 100644 --- a/smartsim/templates/templates/preview/plain_text/ensemble_info.template +++ b/smartsim/templates/templates/preview/plain_text/ensemble_info.template @@ -12,12 +12,12 @@ {% set model = ensemble.models[0] %} - Model Name: {{ model.name }} - {%- include 'model.template' %} - {%- if model.run_settings.colocated_db_settings or manifest.dbs %} + {%- 
if model.run_settings.colocated_fs_settings or manifest.fss %} Client Configuration: - {%- if model.run_settings.colocated_db_settings %} + {%- if model.run_settings.colocated_fs_settings %} {%- include "clientconfigcolo.template" %} {%- endif %} - {%- if manifest.dbs %} + {%- if manifest.fss %} {%- include "clientconfig.template" %} {%- endif %} {%- endif %} @@ -25,12 +25,12 @@ {% set model = ensemble.models[(ensemble.models | length)-1] %} - Model Name: {{ model.name }} - {%- include 'model.template' %} - {% if model.run_settings.colocated_db_settings or manifest.dbs %} + {% if model.run_settings.colocated_fs_settings or manifest.fss %} Client Configuration: - {%- if model.run_settings.colocated_db_settings %} + {%- if model.run_settings.colocated_fs_settings %} {%- include "clientconfigcolo.template" %} {%- endif %} - {%- if manifest.dbs %} + {%- if manifest.fss %} {%- include "clientconfig.template" %} {%- endif %} {%- endif %} @@ -38,12 +38,12 @@ {% for model in ensemble %} - Model Name: {{ model.name }} - {%- include 'model.template' %} - {% if model.run_settings.colocated_db_settings or manifest.dbs %} + {% if model.run_settings.colocated_fs_settings or manifest.fss %} Client Configuration: - {%- if model.run_settings.colocated_db_settings %} + {%- if model.run_settings.colocated_fs_settings %} {%- include "clientconfigcolo.template" %} {%- endif %} - {%- if manifest.dbs %} + {%- if manifest.fss %} {%- include "clientconfig.template" %} {%- endif %} {%- endif %} diff --git a/smartsim/templates/templates/preview/plain_text/model_debug.template b/smartsim/templates/templates/preview/plain_text/model_debug.template index 186746186a..6605d50ab7 100644 --- a/smartsim/templates/templates/preview/plain_text/model_debug.template +++ b/smartsim/templates/templates/preview/plain_text/model_debug.template @@ -54,42 +54,42 @@ {%- endfor %} {%- endif %} {%- endif %} - {%- if model.run_settings.colocated_db_settings %} + {%- if model.run_settings.colocated_fs_settings %} 
Colocated: - {%- if model.run_settings.colocated_db_settings.db_identifier %} - Database Identifier: {{ model.run_settings.colocated_db_settings.db_identifier }} + {%- if model.run_settings.colocated_fs_settings.fs_identifier %} + Feature Store Identifier: {{ model.run_settings.colocated_fs_settings.fs_identifier }} {%- endif %} - {%- if model.run_settings.colocated_db_settings.port %} + {%- if model.run_settings.colocated_fs_settings.port %} Connection Type: TCP TCP/IP Port(s): - {{ model.run_settings.colocated_db_settings.port }} + {{ model.run_settings.colocated_fs_settings.port }} {%- endif %} - {%- if model.run_settings.colocated_db_settings.unix_socket %} + {%- if model.run_settings.colocated_fs_settings.unix_socket %} Connection Type: UDS - Unix Socket: {{ model.run_settings.colocated_db_settings.unix_socket }} + Unix Socket: {{ model.run_settings.colocated_fs_settings.unix_socket }} {%- endif %} - {%- if model.run_settings.colocated_db_settings.ifname %} - {%- if model.run_settings.colocated_db_settings.ifname | is_list %} - Network Interface Name: {{ model.run_settings.colocated_db_settings.ifname[0] }} + {%- if model.run_settings.colocated_fs_settings.ifname %} + {%- if model.run_settings.colocated_fs_settings.ifname | is_list %} + Network Interface Name: {{ model.run_settings.colocated_fs_settings.ifname[0] }} {%- else %} - Network Interface Name: {{ model.run_settings.colocated_db_settings.ifname }} + Network Interface Name: {{ model.run_settings.colocated_fs_settings.ifname }} {%- endif %} {%- endif %} - CPUs: {{ model.run_settings.colocated_db_settings.cpus }} - Custom Pinning: {{ model.run_settings.colocated_db_settings.custom_pinning }} + CPUs: {{ model.run_settings.colocated_fs_settings.cpus }} + Custom Pinning: {{ model.run_settings.colocated_fs_settings.custom_pinning }} {%- endif %} - {%- if model._db_scripts %} + {%- if model._fs_scripts %} Torch Scripts: - {%- for script in model._db_scripts%} + {%- for script in model._fs_scripts%} Name: {{ 
script.name }} Path: {{ script.file }} Backend: {{ script.device }} Devices Per Node: {{ script.devices_per_node }} {%- endfor %} {%- endif %} - {%- if model._db_models %} + {%- if model._fs_models %} ML Models: - {%- for mlmodel in model._db_models %} + {%- for mlmodel in model._fs_models %} Name: {{ mlmodel.name }} Path: {{ mlmodel.file }} Backend: {{ mlmodel.backend }} diff --git a/smartsim/templates/templates/preview/plain_text/model_info.template b/smartsim/templates/templates/preview/plain_text/model_info.template index f746208e53..dc961ae95e 100644 --- a/smartsim/templates/templates/preview/plain_text/model_info.template +++ b/smartsim/templates/templates/preview/plain_text/model_info.template @@ -10,32 +10,32 @@ {%- endfor %} {%- endif %} - {%- if model.run_settings.colocated_db_settings %} + {%- if model.run_settings.colocated_fs_settings %} Colocated: - {%- if model.run_settings.colocated_db_settings.db_identifier %} - Database Identifier: {{ model.run_settings.colocated_db_settings.db_identifier }} + {%- if model.run_settings.colocated_fs_settings.fs_identifier %} + Feature Store Identifier: {{ model.run_settings.colocated_fs_settings.fs_identifier }} {%- endif %} - {%- if model.run_settings.colocated_db_settings.port %} + {%- if model.run_settings.colocated_fs_settings.port %} Connection Type: TCP TCP/IP Port(s): - {{ model.run_settings.colocated_db_settings.port }} + {{ model.run_settings.colocated_fs_settings.port }} {%- endif %} - {%- if model.run_settings.colocated_db_settings.unix_socket %} + {%- if model.run_settings.colocated_fs_settings.unix_socket %} Connection Type: UDS - Unix Socket: {{ model.run_settings.colocated_db_settings.unix_socket }} + Unix Socket: {{ model.run_settings.colocated_fs_settings.unix_socket }} {%- endif %} {%- endif %} - {%- if model.run_settings.colocated_db_settings['db_scripts'] %} + {%- if model.run_settings.colocated_fs_settings['fs_scripts'] %} Torch Scripts: - {%- for script in 
model.run_settings.colocated_db_settings['db_scripts'] %} + {%- for script in model.run_settings.colocated_fs_settings['fs_scripts'] %} Name: {{ script.name }} Path: {{ script.script_path }} {%- endfor %} {%- endif %} - {%- if model.run_settings.colocated_db_settings['db_models'] %} + {%- if model.run_settings.colocated_fs_settings['fs_models'] %} ML Models: - {%- for mlmodel in model.run_settings.colocated_db_settings['db_models'] %} + {%- for mlmodel in model.run_settings.colocated_fs_settings['fs_models'] %} Name: {{ mlmodel.name }} Path: {{ mlmodel.model_file }} Backend: {{ mlmodel.backend }} diff --git a/smartsim/templates/templates/preview/plain_text/orchestrator_debug.template b/smartsim/templates/templates/preview/plain_text/orchestrator_debug.template index 127a4949e4..8dfa6ae9a8 100644 --- a/smartsim/templates/templates/preview/plain_text/orchestrator_debug.template +++ b/smartsim/templates/templates/preview/plain_text/orchestrator_debug.template @@ -1,33 +1,33 @@ - = Database Identifier: {{ db.name }} = - {%- if db.path %} - Path: {{ db.path }} + = Feature Store Identifier: {{ fs.name }} = + {%- if fs.path %} + Path: {{ fs.path }} {%- endif %} - Shards: {{ db.num_shards }} + Shards: {{ fs.num_shards }} TCP/IP Port(s): - {%- for port in db.ports %} + {%- for port in fs.ports %} {{ port }} {%- endfor %} - Network Interface: {{ db._interfaces[0] }} - Type: {{ config.database_cli | get_dbtype }} + Network Interface: {{ fs._interfaces[0] }} + Type: {{ config.database_cli | get_fstype }} Executable: {{ config.database_exe }} - {%- if db.run_settings %} - Run Command: {{ db.run_settings.run_command }} - {%- if db.run_settings.run_args %} + {%- if fs.run_settings %} + Run Command: {{ fs.run_settings.run_command }} + {%- if fs.run_settings.run_args %} Run Arguments: - {%- for key, value in db.run_settings.run_args.items() %} + {%- for key, value in fs.run_settings.run_args.items() %} {{ key }}: {{ value }} {%- endfor %} {%- endif %} {%- endif %} - {%- if 
db.run_command %} - Run Command: {{ db.run_command }} + {%- if fs.run_command %} + Run Command: {{ fs.run_command }} {%- endif %} - {%- if db.batch_settings %} + {%- if fs.batch_settings %} Batch Launch: True - Batch Command: {{ db.batch_settings.batch_cmd }} + Batch Command: {{ fs.batch_settings.batch_cmd }} Batch Arguments: - {%- for key, value in db.batch_settings.batch_args.items() %} + {%- for key, value in fs.batch_settings.batch_args.items() %} {{ key }}: {{ value }} {%- endfor %} {%- endif %} diff --git a/smartsim/templates/templates/preview/plain_text/orchestrator_info.template b/smartsim/templates/templates/preview/plain_text/orchestrator_info.template index 11608d6c51..7964d126e3 100644 --- a/smartsim/templates/templates/preview/plain_text/orchestrator_info.template +++ b/smartsim/templates/templates/preview/plain_text/orchestrator_info.template @@ -1,11 +1,11 @@ - = Database Identifier: {{ db.name }} = + = Feature Store Identifier: {{ fs.name }} = TCP/IP Port(s): - {%- for port in db.ports %} + {%- for port in fs.ports %} {{ port }} {%- endfor %} - Network Interface: {{ db._interfaces[0] }} - Type: {{ config.database_cli | get_dbtype }} - {%- if db.batch %} - Batch Launch: {{ db.batch }} + Network Interface: {{ fs._interfaces[0] }} + Type: {{ config.database_cli | get_fstype }} + {%- if fs.batch %} + Batch Launch: {{ fs.batch }} {%- endif %} diff --git a/smartsim/types.py b/smartsim/types.py new file mode 100644 index 0000000000..f756fc6fe2 --- /dev/null +++ b/smartsim/types.py @@ -0,0 +1,32 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import typing as t + +TODO = t.Any # TODO: remove this after refactor +LaunchedJobID = t.NewType("LaunchedJobID", str) diff --git a/smartsim/wlm/pbs.py b/smartsim/wlm/pbs.py index 5b559c1e6b..62f5a69a08 100644 --- a/smartsim/wlm/pbs.py +++ b/smartsim/wlm/pbs.py @@ -31,7 +31,7 @@ from smartsim.error.errors import LauncherError, SmartSimError -from .._core.launcher.pbs.pbsCommands import qstat +from .._core.launcher.pbs.pbs_commands import qstat def get_hosts() -> t.List[str]: diff --git a/smartsim/wlm/slurm.py b/smartsim/wlm/slurm.py index ae7299f28b..e1b24b906d 100644 --- a/smartsim/wlm/slurm.py +++ b/smartsim/wlm/slurm.py @@ -24,13 +24,14 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import datetime import os import typing as t from shutil import which -from .._core.launcher.slurm.slurmCommands import salloc, scancel, scontrol, sinfo -from .._core.launcher.slurm.slurmParser import parse_salloc, parse_salloc_error -from .._core.launcher.util.launcherUtil import ComputeNode, Partition +from .._core.launcher.slurm.slurm_commands import salloc, scancel, scontrol, sinfo +from .._core.launcher.slurm.slurm_parser import parse_salloc, parse_salloc_error +from .._core.launcher.util.launcher_util import ComputeNode, Partition from ..error import ( AllocationError, LauncherError, @@ -38,7 +39,21 @@ SSReservedKeywordError, ) from ..log import get_logger -from ..settings.slurmSettings import fmt_walltime + + +def fmt_walltime(hours: int, minutes: int, seconds: int) -> str: + """Helper function walltime format conversion + + Converts time to format HH:MM:SS + + :param hours: number of hours to run job + :param minutes: number of minutes to run job + :param seconds: number of seconds to run job + :returns: Formatted walltime + """ + delta = datetime.timedelta(hours=hours, minutes=minutes, seconds=seconds) + return f"0{delta}" if delta.seconds // 3600 < 10 else str(delta) + logger = get_logger(__name__) diff --git a/tests/_legacy/__init__.py b/tests/_legacy/__init__.py new file mode 100644 index 0000000000..efe03908e0 --- /dev/null +++ b/tests/_legacy/__init__.py @@ -0,0 +1,25 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/tests/backends/run_sklearn_onnx.py b/tests/_legacy/backends/run_sklearn_onnx.py similarity index 99% rename from tests/backends/run_sklearn_onnx.py rename to tests/_legacy/backends/run_sklearn_onnx.py index f10c8c7fb1..77683ee902 100644 --- a/tests/backends/run_sklearn_onnx.py +++ b/tests/_legacy/backends/run_sklearn_onnx.py @@ -75,7 +75,7 @@ def run_model(client, model_name, device, model, model_input, in_name, out_names def run(device): - # connect a client to the database + # connect a client to the feature store client = Client(cluster=False) # linreg test diff --git a/tests/backends/run_tf.py b/tests/_legacy/backends/run_tf.py similarity index 100% rename from tests/backends/run_tf.py rename to tests/_legacy/backends/run_tf.py diff --git a/tests/backends/run_torch.py b/tests/_legacy/backends/run_torch.py similarity index 97% rename from tests/backends/run_torch.py rename to tests/_legacy/backends/run_torch.py index b3c0fc9649..1071e740ef 100644 --- 
a/tests/backends/run_torch.py +++ b/tests/_legacy/backends/run_torch.py @@ -25,6 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import io +import typing as t import numpy as np import torch @@ -74,8 +75,8 @@ def calc_svd(input_tensor): return input_tensor.svd() -def run(device, num_devices): - # connect a client to the database +def run(device: str, num_devices: int) -> t.Any: + # connect a client to the feature store client = Client(cluster=False) # test the SVD function diff --git a/tests/backends/test_cli_mini_exp.py b/tests/_legacy/backends/test_cli_mini_exp.py similarity index 87% rename from tests/backends/test_cli_mini_exp.py rename to tests/_legacy/backends/test_cli_mini_exp.py index 3379bf2eec..83ecfc5b07 100644 --- a/tests/backends/test_cli_mini_exp.py +++ b/tests/_legacy/backends/test_cli_mini_exp.py @@ -33,7 +33,6 @@ import smartsim._core._cli.validate import smartsim._core._install.builder as build from smartsim._core._install.platform import Device -from smartsim._core.utils.helpers import installed_redisai_backends sklearn_available = True try: @@ -49,8 +48,8 @@ def test_cli_mini_exp_doesnt_error_out_with_dev_build( - prepare_db, - local_db, + prepare_fs, + local_fs, test_dir, monkeypatch, ): @@ -59,26 +58,26 @@ def test_cli_mini_exp_doesnt_error_out_with_dev_build( to ensure that it does not accidentally report false positive/negatives """ - db = prepare_db(local_db).orchestrator + fs = prepare_fs(local_fs).featurestore @contextmanager - def _mock_make_managed_local_orc(*a, **kw): - (client_addr,) = db.get_address() + def _mock_make_managed_local_feature_store(*a, **kw): + (client_addr,) = fs.get_address() yield smartredis.Client(False, address=client_addr) monkeypatch.setattr( smartsim._core._cli.validate, - "_make_managed_local_orc", - _mock_make_managed_local_orc, + "_make_managed_local_feature_store", + _mock_make_managed_local_feature_store, ) - backends = installed_redisai_backends() - (db_port,) = db.ports + 
backends = [] # todo: update test to replace installed_redisai_backends() + (fs_port,) = fs.ports smartsim._core._cli.validate.test_install( # Shouldn't matter bc we are stubbing creation of orc # but best to give it "correct" vals for safety location=test_dir, - port=db_port, + port=fs_port, # Always test on CPU, heads don't always have GPU device=Device.CPU, # Test the backends the dev has installed diff --git a/tests/backends/test_dataloader.py b/tests/_legacy/backends/test_dataloader.py similarity index 90% rename from tests/backends/test_dataloader.py rename to tests/_legacy/backends/test_dataloader.py index de4bf6d8e3..4774841eaa 100644 --- a/tests/backends/test_dataloader.py +++ b/tests/_legacy/backends/test_dataloader.py @@ -30,12 +30,12 @@ import numpy as np import pytest -from smartsim.database import Orchestrator +from smartsim.database import FeatureStore from smartsim.error.errors import SSInternalError from smartsim.experiment import Experiment from smartsim.log import get_logger from smartsim.ml.data import DataInfo, TrainingDataUploader -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus logger = get_logger(__name__) @@ -167,11 +167,11 @@ def train_tf(generator): @pytest.mark.skipif(not shouldrun_tf, reason="Test needs TensorFlow to run") -def test_tf_dataloaders(wlm_experiment, prepare_db, single_db, monkeypatch): +def test_tf_dataloaders(wlm_experiment, prepare_fs, single_fs, monkeypatch): - db = prepare_db(single_db).orchestrator - orc = wlm_experiment.reconnect_orchestrator(db.checkpoint_file) - monkeypatch.setenv("SSDB", orc.get_address()[0]) + fs = prepare_fs(single_fs).featurestore + feature_store = wlm_experiment.reconnect_feature_store(fs.checkpoint_file) + monkeypatch.setenv("SSDB", feature_store.get_address()[0]) monkeypatch.setenv("SSKEYIN", "test_uploader_0,test_uploader_1") try: @@ -218,7 +218,7 @@ def create_trainer_torch(experiment: Experiment, filedir, wlmutils): args=["training_service_torch.py"], ) - 
trainer = experiment.create_model("trainer", run_settings=run_settings) + trainer = experiment.create_application("trainer", run_settings=run_settings) trainer.attach_generator_files( to_copy=[osp.join(filedir, "training_service_torch.py")] @@ -229,12 +229,12 @@ def create_trainer_torch(experiment: Experiment, filedir, wlmutils): @pytest.mark.skipif(not shouldrun_torch, reason="Test needs Torch to run") def test_torch_dataloaders( - wlm_experiment, prepare_db, single_db, fileutils, test_dir, wlmutils, monkeypatch + wlm_experiment, prepare_fs, single_fs, fileutils, test_dir, wlmutils, monkeypatch ): config_dir = fileutils.get_test_dir_path("ml") - db = prepare_db(single_db).orchestrator - orc = wlm_experiment.reconnect_orchestrator(db.checkpoint_file) - monkeypatch.setenv("SSDB", orc.get_address()[0]) + fs = prepare_fs(single_fs).featurestore + feature_store = wlm_experiment.reconnect_feature_store(fs.checkpoint_file) + monkeypatch.setenv("SSDB", feature_store.get_address()[0]) monkeypatch.setenv("SSKEYIN", "test_uploader_0,test_uploader_1") try: @@ -283,7 +283,7 @@ def test_torch_dataloaders( trainer = create_trainer_torch(wlm_experiment, config_dir, wlmutils) wlm_experiment.start(trainer, block=True) - assert wlm_experiment.get_status(trainer)[0] == SmartSimStatus.STATUS_COMPLETED + assert wlm_experiment.get_status(trainer)[0] == JobStatus.COMPLETED except Exception as e: raise e @@ -320,22 +320,22 @@ def test_data_info_repr(): @pytest.mark.skipif( not (shouldrun_torch or shouldrun_tf), reason="Requires TF or PyTorch" ) -def test_wrong_dataloaders(wlm_experiment, prepare_db, single_db): - db = prepare_db(single_db).orchestrator - orc = wlm_experiment.reconnect_orchestrator(db.checkpoint_file) +def test_wrong_dataloaders(wlm_experiment, prepare_fs, single_fs): + fs = prepare_fs(single_fs).featurestore + feature_store = wlm_experiment.reconnect_feature_store(fs.checkpoint_file) if shouldrun_tf: with pytest.raises(SSInternalError): _ = TFDataGenerator( 
data_info_or_list_name="test_data_list", - address=orc.get_address()[0], + address=feature_store.get_address()[0], cluster=False, max_fetch_trials=1, ) with pytest.raises(TypeError): _ = TFStaticDataGenerator( test_data_info_repr=1, - address=orc.get_address()[0], + address=feature_store.get_address()[0], cluster=False, max_fetch_trials=1, ) @@ -344,7 +344,7 @@ def test_wrong_dataloaders(wlm_experiment, prepare_db, single_db): with pytest.raises(SSInternalError): torch_data_gen = TorchDataGenerator( data_info_or_list_name="test_data_list", - address=orc.get_address()[0], + address=feature_store.get_address()[0], cluster=False, ) torch_data_gen.init_samples(init_trials=1) diff --git a/tests/backends/test_dbmodel.py b/tests/_legacy/backends/test_dbmodel.py similarity index 81% rename from tests/backends/test_dbmodel.py rename to tests/_legacy/backends/test_dbmodel.py index 6155b6884c..da495004fa 100644 --- a/tests/backends/test_dbmodel.py +++ b/tests/_legacy/backends/test_dbmodel.py @@ -30,12 +30,11 @@ import pytest from smartsim import Experiment -from smartsim._core.utils import installed_redisai_backends from smartsim.entity import Ensemble -from smartsim.entity.dbobject import DBModel +from smartsim.entity.dbobject import FSModel from smartsim.error.errors import SSUnsupportedError from smartsim.log import get_logger -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus logger = get_logger(__name__) @@ -70,7 +69,9 @@ def call(self, x): except: logger.warning("Could not set TF max memory limit for GPU") -should_run_tf &= "tensorflow" in installed_redisai_backends() +should_run_tf &= ( + "tensorflow" in [] +) # todo: update test to replace installed_redisai_backends() # Check if PyTorch is available for tests try: @@ -107,7 +108,9 @@ def forward(self, x): return output -should_run_pt &= "torch" in installed_redisai_backends() +should_run_pt &= ( + "torch" in [] +) # todo: update test to replace installed_redisai_backends() def 
save_tf_cnn(path, file_name): @@ -146,10 +149,10 @@ def save_torch_cnn(path, file_name): @pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -def test_tf_db_model( - wlm_experiment, prepare_db, single_db, fileutils, test_dir, mlutils +def test_tf_fs_model( + wlm_experiment, prepare_fs, single_fs, fileutils, test_dir, mlutils ): - """Test TensorFlow DB Models on remote DB""" + """Test TensorFlow FS Models on remote FS""" # Retrieve parameters from testing environment test_device = mlutils.get_test_device() @@ -165,11 +168,11 @@ def test_tf_db_model( run_settings.set_tasks(1) # Create Model - smartsim_model = wlm_experiment.create_model("smartsim_model", run_settings) + smartsim_model = wlm_experiment.create_application("smartsim_model", run_settings) - # Create database - db = prepare_db(single_db).orchestrator - wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + # Create feature store + fs = prepare_fs(single_fs).featurestore + wlm_experiment.reconnect_feature_store(fs.checkpoint_file) # Create and save ML model to filesystem model, inputs, outputs = create_tf_cnn() @@ -200,11 +203,11 @@ def test_tf_db_model( ) logger.debug("The following ML models have been added:") - for db_model in smartsim_model._db_models: - logger.debug(db_model) + for fs_model in smartsim_model._fs_models: + logger.debug(fs_model) # Assert we have added both models - assert len(smartsim_model._db_models) == 2 + assert len(smartsim_model._fs_models) == 2 wlm_experiment.generate(smartsim_model) @@ -212,15 +215,15 @@ def test_tf_db_model( wlm_experiment.start(smartsim_model, block=True) statuses = wlm_experiment.get_status(smartsim_model) assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses + stat == JobStatus.COMPLETED for stat in statuses ), f"Statuses: {statuses}" @pytest.mark.skipif(not should_run_pt, reason="Test needs PyTorch to run") -def test_pt_db_model( - wlm_experiment, prepare_db, single_db, fileutils, test_dir, mlutils +def 
test_pt_fs_model( + wlm_experiment, prepare_fs, single_fs, fileutils, test_dir, mlutils ): - """Test PyTorch DB Models on remote DB""" + """Test PyTorch FS Models on remote FS""" # Retrieve parameters from testing environment test_device = mlutils.get_test_device() @@ -236,11 +239,11 @@ def test_pt_db_model( run_settings.set_tasks(1) # Create Model - smartsim_model = wlm_experiment.create_model("smartsim_model", run_settings) + smartsim_model = wlm_experiment.create_application("smartsim_model", run_settings) - # Create database - db = prepare_db(single_db).orchestrator - wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + # Create feature store + fs = prepare_fs(single_fs).featurestore + wlm_experiment.reconnect_feature_store(fs.checkpoint_file) # Create and save ML model to filesystem save_torch_cnn(test_dir, "model1.pt") @@ -258,11 +261,11 @@ def test_pt_db_model( ) logger.debug("The following ML models have been added:") - for db_model in smartsim_model._db_models: - logger.debug(db_model) + for fs_model in smartsim_model._fs_models: + logger.debug(fs_model) # Assert we have added both models - assert len(smartsim_model._db_models) == 1 + assert len(smartsim_model._fs_models) == 1 wlm_experiment.generate(smartsim_model) @@ -270,15 +273,15 @@ def test_pt_db_model( wlm_experiment.start(smartsim_model, block=True) statuses = wlm_experiment.get_status(smartsim_model) assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses + stat == JobStatus.COMPLETED for stat in statuses ), f"Statuses: {statuses}" @pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -def test_db_model_ensemble( - wlm_experiment, prepare_db, single_db, fileutils, test_dir, wlmutils, mlutils +def test_fs_model_ensemble( + wlm_experiment, prepare_fs, single_fs, fileutils, test_dir, wlmutils, mlutils ): - """Test DBModels on remote DB, with an ensemble""" + """Test FSModels on remote FS, with an ensemble""" # Retrieve parameters from testing environment 
test_device = mlutils.get_test_device() @@ -299,11 +302,11 @@ def test_db_model_ensemble( ) # Create Model - smartsim_model = wlm_experiment.create_model("smartsim_model", run_settings) + smartsim_model = wlm_experiment.create_application("smartsim_model", run_settings) - # Create database - db = prepare_db(single_db).orchestrator - wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + # Create feature store + fs = prepare_fs(single_fs).featurestore + wlm_experiment.reconnect_feature_store(fs.checkpoint_file) # Create and save ML model to filesystem model, inputs, outputs = create_tf_cnn() @@ -336,7 +339,7 @@ def test_db_model_ensemble( ) # Add new ensemble member - smartsim_ensemble.add_model(smartsim_model) + smartsim_ensemble.add_application(smartsim_model) # Add the second ML model to the newly added entity. This is # because the test script runs both ML models for all entities. @@ -352,9 +355,9 @@ def test_db_model_ensemble( ) # Assert we have added one model to the ensemble - assert len(smartsim_ensemble._db_models) == 1 + assert len(smartsim_ensemble._fs_models) == 1 # Assert we have added two models to each entity - assert all([len(entity._db_models) == 2 for entity in smartsim_ensemble]) + assert all([len(entity._fs_models) == 2 for entity in smartsim_ensemble]) wlm_experiment.generate(smartsim_ensemble) @@ -362,16 +365,16 @@ def test_db_model_ensemble( wlm_experiment.start(smartsim_ensemble, block=True) statuses = wlm_experiment.get_status(smartsim_ensemble) assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses + stat == JobStatus.COMPLETED for stat in statuses ), f"Statuses: {statuses}" @pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -def test_colocated_db_model_tf(fileutils, test_dir, wlmutils, mlutils): - """Test DB Models on colocated DB (TensorFlow backend)""" +def test_colocated_fs_model_tf(fileutils, test_dir, wlmutils, mlutils): + """Test fs Models on colocated fs (TensorFlow backend)""" # Set 
experiment name - exp_name = "test-colocated-db-model-tf" + exp_name = "test-colocated-fs-model-tf" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -391,9 +394,9 @@ def test_colocated_db_model_tf(fileutils, test_dir, wlmutils, mlutils): colo_settings.set_tasks(1) # Create colocated Model - colo_model = exp.create_model("colocated_model", colo_settings) - colo_model.colocate_db_tcp( - port=test_port, db_cpus=1, debug=True, ifname=test_interface + colo_model = exp.create_application("colocated_model", colo_settings) + colo_model.colocate_fs_tcp( + port=test_port, fs_cpus=1, debug=True, ifname=test_interface ) # Create and save ML model to filesystem @@ -423,7 +426,7 @@ def test_colocated_db_model_tf(fileutils, test_dir, wlmutils, mlutils): ) # Assert we have added both models - assert len(colo_model._db_models) == 2 + assert len(colo_model._fs_models) == 2 exp.generate(colo_model) @@ -432,18 +435,18 @@ def test_colocated_db_model_tf(fileutils, test_dir, wlmutils, mlutils): exp.start(colo_model, block=True) statuses = exp.get_status(colo_model) assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses + stat == JobStatus.COMPLETED for stat in statuses ), f"Statuses: {statuses}" finally: exp.stop(colo_model) @pytest.mark.skipif(not should_run_pt, reason="Test needs PyTorch to run") -def test_colocated_db_model_pytorch(fileutils, test_dir, wlmutils, mlutils): - """Test DB Models on colocated DB (PyTorch backend)""" +def test_colocated_fs_model_pytorch(fileutils, test_dir, wlmutils, mlutils): + """Test fs Models on colocated fs (PyTorch backend)""" # Set experiment name - exp_name = "test-colocated-db-model-pytorch" + exp_name = "test-colocated-fs-model-pytorch" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -463,9 +466,9 @@ def test_colocated_db_model_pytorch(fileutils, test_dir, wlmutils, mlutils): colo_settings.set_tasks(1) # Create colocated SmartSim Model 
- colo_model = exp.create_model("colocated_model", colo_settings) - colo_model.colocate_db_tcp( - port=test_port, db_cpus=1, debug=True, ifname=test_interface + colo_model = exp.create_application("colocated_model", colo_settings) + colo_model.colocate_fs_tcp( + port=test_port, fs_cpus=1, debug=True, ifname=test_interface ) # Create and save ML model to filesystem @@ -483,7 +486,7 @@ def test_colocated_db_model_pytorch(fileutils, test_dir, wlmutils, mlutils): ) # Assert we have added both models - assert len(colo_model._db_models) == 1 + assert len(colo_model._fs_models) == 1 exp.generate(colo_model) @@ -492,20 +495,20 @@ def test_colocated_db_model_pytorch(fileutils, test_dir, wlmutils, mlutils): exp.start(colo_model, block=True) statuses = exp.get_status(colo_model) assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses + stat == JobStatus.COMPLETED for stat in statuses ), f"Statuses: {statuses}" finally: exp.stop(colo_model) @pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -def test_colocated_db_model_ensemble(fileutils, test_dir, wlmutils, mlutils): - """Test DBModel on colocated ensembles, first colocating DB, - then adding DBModel. +def test_colocated_fs_model_ensemble(fileutils, test_dir, wlmutils, mlutils): + """Test fsModel on colocated ensembles, first colocating fs, + then adding fsModel. 
""" # Set experiment name - exp_name = "test-colocated-db-model-ensemble" + exp_name = "test-colocated-fs-model-ensemble" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -529,20 +532,20 @@ def test_colocated_db_model_ensemble(fileutils, test_dir, wlmutils, mlutils): "colocated_ens", run_settings=colo_settings, replicas=2 ) - # Create a third model with a colocated database - colo_model = exp.create_model("colocated_model", colo_settings) - colo_model.colocate_db_tcp( - port=test_port, db_cpus=1, debug=True, ifname=test_interface + # Create a third model with a colocated feature store + colo_model = exp.create_application("colocated_model", colo_settings) + colo_model.colocate_fs_tcp( + port=test_port, fs_cpus=1, debug=True, ifname=test_interface ) # Create and save the ML models to the filesystem model_file, inputs, outputs = save_tf_cnn(test_dir, "model1.pb") model_file2, inputs2, outputs2 = save_tf_cnn(test_dir, "model2.pb") - # Colocate a database with the ensemble with two ensemble members + # Colocate a feature store with the ensemble with two ensemble members for i, entity in enumerate(colo_ensemble): - entity.colocate_db_tcp( - port=test_port + i + 1, db_cpus=1, debug=True, ifname=test_interface + entity.colocate_fs_tcp( + port=test_port + i + 1, fs_cpus=1, debug=True, ifname=test_interface ) # Add ML model to each ensemble member individual to test that they # do not conflict with models add to the Ensemble object @@ -572,7 +575,7 @@ def test_colocated_db_model_ensemble(fileutils, test_dir, wlmutils, mlutils): ) # Add a new model to the ensemble - colo_ensemble.add_model(colo_model) + colo_ensemble.add_application(colo_model) # Add the ML model to SmartSim Model just added to the ensemble colo_model.add_ml_model( @@ -593,20 +596,20 @@ def test_colocated_db_model_ensemble(fileutils, test_dir, wlmutils, mlutils): exp.start(colo_ensemble, block=True) statuses = exp.get_status(colo_ensemble) assert all( - stat == 
SmartSimStatus.STATUS_COMPLETED for stat in statuses + stat == JobStatus.COMPLETED for stat in statuses ), f"Statuses: {statuses}" finally: exp.stop(colo_ensemble) @pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -def test_colocated_db_model_ensemble_reordered(fileutils, test_dir, wlmutils, mlutils): - """Test DBModel on colocated ensembles, first adding the DBModel to the - ensemble, then colocating DB. +def test_colocated_fs_model_ensemble_reordered(fileutils, test_dir, wlmutils, mlutils): + """Test fsModel on colocated ensembles, first adding the fsModel to the + ensemble, then colocating fs. """ # Set experiment name - exp_name = "test-colocated-db-model-ensemble-reordered" + exp_name = "test-colocated-fs-model-ensemble-reordered" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -631,7 +634,7 @@ def test_colocated_db_model_ensemble_reordered(fileutils, test_dir, wlmutils, ml ) # Create colocated SmartSim Model - colo_model = exp.create_model("colocated_model", colo_settings) + colo_model = exp.create_application("colocated_model", colo_settings) # Create and save ML model to filesystem model_file, inputs, outputs = save_tf_cnn(test_dir, "model1.pb") @@ -649,10 +652,10 @@ def test_colocated_db_model_ensemble_reordered(fileutils, test_dir, wlmutils, ml outputs=outputs, ) - # Colocate a database with the first ensemble members + # Colocate a feature store with the first ensemble members for i, entity in enumerate(colo_ensemble): - entity.colocate_db_tcp( - port=test_port + i, db_cpus=1, debug=True, ifname=test_interface + entity.colocate_fs_tcp( + port=test_port + i, fs_cpus=1, debug=True, ifname=test_interface ) # Add ML models to each ensemble member to make sure they # do not conflict with other ML models @@ -669,12 +672,12 @@ def test_colocated_db_model_ensemble_reordered(fileutils, test_dir, wlmutils, ml entity.disable_key_prefixing() # Add another ensemble member - 
colo_ensemble.add_model(colo_model) + colo_ensemble.add_application(colo_model) - # Colocate a database with the new ensemble member - colo_model.colocate_db_tcp( + # Colocate a feature store with the new ensemble member + colo_model.colocate_fs_tcp( port=test_port + len(colo_ensemble) - 1, - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) @@ -697,18 +700,18 @@ def test_colocated_db_model_ensemble_reordered(fileutils, test_dir, wlmutils, ml exp.start(colo_ensemble, block=True) statuses = exp.get_status(colo_ensemble) assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses + stat == JobStatus.COMPLETED for stat in statuses ), f"Statuses: {statuses}" finally: exp.stop(colo_ensemble) @pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -def test_colocated_db_model_errors(fileutils, test_dir, wlmutils, mlutils): - """Test error when colocated db model has no file.""" +def test_colocated_fs_model_errors(fileutils, test_dir, wlmutils, mlutils): + """Test error when colocated fs model has no file.""" # Set experiment name - exp_name = "test-colocated-db-model-error" + exp_name = "test-colocated-fs-model-error" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -728,9 +731,9 @@ def test_colocated_db_model_errors(fileutils, test_dir, wlmutils, mlutils): colo_settings.set_tasks(1) # Create colocated SmartSim Model - colo_model = exp.create_model("colocated_model", colo_settings) - colo_model.colocate_db_tcp( - port=test_port, db_cpus=1, debug=True, ifname=test_interface + colo_model = exp.create_application("colocated_model", colo_settings) + colo_model.colocate_fs_tcp( + port=test_port, fs_cpus=1, debug=True, ifname=test_interface ) # Get and save TF model @@ -755,10 +758,10 @@ def test_colocated_db_model_errors(fileutils, test_dir, wlmutils, mlutils): "colocated_ens", run_settings=colo_settings, replicas=2 ) - # Colocate a db with each ensemble member + # Colocate a fs with each 
ensemble member for i, entity in enumerate(colo_ensemble): - entity.colocate_db_tcp( - port=test_port + i, db_cpus=1, debug=True, ifname=test_interface + entity.colocate_fs_tcp( + port=test_port + i, fs_cpus=1, debug=True, ifname=test_interface ) # Check that an error is raised because in-memory models @@ -777,11 +780,11 @@ def test_colocated_db_model_errors(fileutils, test_dir, wlmutils, mlutils): # Check error is still thrown if an in-memory model is used # with a colocated deployment. This test varies by adding - # the SmartSIm model with a colocated database to the ensemble + # the SmartSIm model with a colocated feature store to the ensemble # after the ML model was been added to the ensemble. colo_settings2 = exp.create_run_settings(exe=sys.executable, exe_args=test_script) - # Reverse order of DBModel and model + # Reverse order of fsModel and model colo_ensemble2 = exp.create_ensemble( "colocated_ens", run_settings=colo_settings2, replicas=2 ) @@ -797,25 +800,25 @@ def test_colocated_db_model_errors(fileutils, test_dir, wlmutils, mlutils): ) for i, entity in enumerate(colo_ensemble2): with pytest.raises(SSUnsupportedError): - entity.colocate_db_tcp( + entity.colocate_fs_tcp( port=test_port + i, - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) with pytest.raises(SSUnsupportedError): - colo_ensemble.add_model(colo_model) + colo_ensemble.add_application(colo_model) @pytest.mark.skipif(not should_run_tf, reason="Test needs TensorFlow to run") -def test_inconsistent_params_db_model(): - """Test error when devices_per_node parameter>1 when devices is set to CPU in DBModel""" +def test_inconsistent_params_fs_model(): + """Test error when devices_per_node parameter>1 when devices is set to CPU in fsModel""" # Create and save ML model to filesystem model, inputs, outputs = create_tf_cnn() with pytest.raises(SSUnsupportedError) as ex: - DBModel( + FSModel( "cnn", "TF", model=model, @@ -833,11 +836,11 @@ def test_inconsistent_params_db_model(): 
@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -def test_db_model_ensemble_duplicate(fileutils, test_dir, wlmutils, mlutils): - """Test DBModels on remote DB, with an ensemble""" +def test_fs_model_ensemble_duplicate(fileutils, test_dir, wlmutils, mlutils): + """Test fsModels on remote fs, with an ensemble""" # Set experiment name - exp_name = "test-db-model-ensemble-duplicate" + exp_name = "test-fs-model-ensemble-duplicate" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -862,7 +865,7 @@ def test_db_model_ensemble_duplicate(fileutils, test_dir, wlmutils, mlutils): ) # Create Model - smartsim_model = exp.create_model("smartsim_model", run_settings) + smartsim_model = exp.create_application("smartsim_model", run_settings) # Create and save ML model to filesystem model, inputs, outputs = create_tf_cnn() @@ -906,7 +909,7 @@ def test_db_model_ensemble_duplicate(fileutils, test_dir, wlmutils, mlutils): outputs=outputs2, ) - # Attempt to add a duplicate ML model to Ensemble via Ensemble.add_model() + # Attempt to add a duplicate ML model to Ensemble via Ensemble.add_application() with pytest.raises(SSUnsupportedError) as ex: - smartsim_ensemble.add_model(smartsim_model) + smartsim_ensemble.add_application(smartsim_model) assert ex.value.args[0] == 'An ML Model with name "cnn" already exists' diff --git a/tests/backends/test_dbscript.py b/tests/_legacy/backends/test_dbscript.py similarity index 71% rename from tests/backends/test_dbscript.py rename to tests/_legacy/backends/test_dbscript.py index 2c04bf5db0..ec6e2f861c 100644 --- a/tests/backends/test_dbscript.py +++ b/tests/_legacy/backends/test_dbscript.py @@ -24,32 +24,29 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-import os import sys import pytest from smartredis import * from smartsim import Experiment -from smartsim._core.utils import installed_redisai_backends -from smartsim.entity.dbobject import DBScript +from smartsim.entity.dbobject import FSScript from smartsim.error.errors import SSUnsupportedError from smartsim.log import get_logger -from smartsim.settings import MpiexecSettings, MpirunSettings -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus logger = get_logger(__name__) should_run = True -supported_dbs = ["uds", "tcp"] +supported_fss = ["uds", "tcp"] try: import torch except ImportError: should_run = False -should_run &= "torch" in installed_redisai_backends() +should_run &= "torch" in [] # todo: update test to replace installed_redisai_backends() def timestwo(x): @@ -57,8 +54,8 @@ def timestwo(x): @pytest.mark.skipif(not should_run, reason="Test needs Torch to run") -def test_db_script(wlm_experiment, prepare_db, single_db, fileutils, mlutils): - """Test DB scripts on remote DB""" +def test_fs_script(wlm_experiment, prepare_fs, single_fs, fileutils, mlutils): + """Test FS scripts on remote Fs""" test_device = mlutils.get_test_device() test_num_gpus = mlutils.get_test_num_gpus() if pytest.test_device == "GPU" else 1 @@ -73,19 +70,21 @@ def test_db_script(wlm_experiment, prepare_db, single_db, fileutils, mlutils): run_settings.set_nodes(1) run_settings.set_tasks(1) - # Create the SmartSim Model - smartsim_model = wlm_experiment.create_model("smartsim_model", run_settings) + # Create the SmartSim Application + smartsim_application = wlm_experiment.create_application( + "smartsim_application", run_settings + ) - # Create the SmartSim database - db = prepare_db(single_db).orchestrator - wlm_experiment.reconnect_orchestrator(db.checkpoint_file) - wlm_experiment.generate(smartsim_model) + # Create the SmartSim feature store + fs = prepare_fs(single_fs).featurestore + wlm_experiment.reconnect_orchestrator(fs.checkpoint_file) + 
wlm_experiment.generate(smartsim_application) # Define the torch script string torch_script_str = "def negate(x):\n\treturn torch.neg(x)\n" # Add the script via file - smartsim_model.add_script( + smartsim_application.add_script( "test_script1", script_path=torch_script, device=test_device, @@ -94,7 +93,7 @@ def test_db_script(wlm_experiment, prepare_db, single_db, fileutils, mlutils): ) # Add script via string - smartsim_model.add_script( + smartsim_application.add_script( "test_script2", script=torch_script_str, device=test_device, @@ -103,7 +102,7 @@ def test_db_script(wlm_experiment, prepare_db, single_db, fileutils, mlutils): ) # Add script function - smartsim_model.add_function( + smartsim_application.add_function( "test_func", function=timestwo, device=test_device, @@ -112,20 +111,20 @@ def test_db_script(wlm_experiment, prepare_db, single_db, fileutils, mlutils): ) # Assert we have all three scripts - assert len(smartsim_model._db_scripts) == 3 + assert len(smartsim_application._fs_scripts) == 3 # Launch and check successful completion - wlm_experiment.start(smartsim_model, block=True) - statuses = wlm_experiment.get_status(smartsim_model) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + wlm_experiment.start(smartsim_application, block=True) + statuses = wlm_experiment.get_status(smartsim_application) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) @pytest.mark.skipif(not should_run, reason="Test needs Torch to run") -def test_db_script_ensemble(wlm_experiment, prepare_db, single_db, fileutils, mlutils): - """Test DB scripts on remote DB""" +def test_fs_script_ensemble(wlm_experiment, prepare_fs, single_fs, fileutils, mlutils): + """Test FS scripts on remote FS""" # Set wlm_experimenteriment name - wlm_experiment_name = "test-db-script" + wlm_experiment_name = "test-fs-script" # Retrieve parameters from testing environment test_device = mlutils.get_test_device() @@ -141,16 +140,18 @@ def 
test_db_script_ensemble(wlm_experiment, prepare_db, single_db, fileutils, ml run_settings.set_nodes(1) run_settings.set_tasks(1) - db = prepare_db(single_db).orchestrator - wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + fs = prepare_fs(single_fs).featurestore + wlm_experiment.reconnect_feature_store(fs.checkpoint_file) - # Create Ensemble with two identical models + # Create Ensemble with two identical applications ensemble = wlm_experiment.create_ensemble( - "dbscript_ensemble", run_settings=run_settings, replicas=2 + "fsscript_ensemble", run_settings=run_settings, replicas=2 ) - # Create SmartSim model - smartsim_model = wlm_experiment.create_model("smartsim_model", run_settings) + # Create SmartSim application + smartsim_application = wlm_experiment.create_application( + "smartsim_application", run_settings + ) # Create the script string torch_script_str = "def negate(x):\n\treturn torch.neg(x)\n" @@ -185,8 +186,8 @@ def test_db_script_ensemble(wlm_experiment, prepare_db, single_db, fileutils, ml ) # Add an additional ensemble member and attach a script to the new member - ensemble.add_model(smartsim_model) - smartsim_model.add_script( + ensemble.add_application(smartsim_application) + smartsim_application.add_script( "test_script2", script=torch_script_str, device=test_device, @@ -195,24 +196,24 @@ def test_db_script_ensemble(wlm_experiment, prepare_db, single_db, fileutils, ml ) # Assert we have added both models to the ensemble - assert len(ensemble._db_scripts) == 2 + assert len(ensemble._fs_scripts) == 2 # Assert we have added all three models to entities in ensemble - assert all([len(entity._db_scripts) == 3 for entity in ensemble]) + assert all([len(entity._fs_scripts) == 3 for entity in ensemble]) wlm_experiment.generate(ensemble) wlm_experiment.start(ensemble, block=True) statuses = wlm_experiment.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for 
stat in statuses]) @pytest.mark.skipif(not should_run, reason="Test needs Torch to run") -def test_colocated_db_script(fileutils, test_dir, wlmutils, mlutils): - """Test DB Scripts on colocated DB""" +def test_colocated_fs_script(fileutils, test_dir, wlmutils, mlutils): + """Test fs Scripts on colocated fs""" # Set the experiment name - exp_name = "test-colocated-db-script" + exp_name = "test-colocated-fs-script" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -232,17 +233,17 @@ def test_colocated_db_script(fileutils, test_dir, wlmutils, mlutils): colo_settings.set_nodes(1) colo_settings.set_tasks(1) - # Create model with colocated database - colo_model = exp.create_model("colocated_model", colo_settings) - colo_model.colocate_db_tcp( - port=test_port, db_cpus=1, debug=True, ifname=test_interface + # Create application with colocated feature store + colo_application = exp.create_application("colocated_application", colo_settings) + colo_application.colocate_fs_tcp( + port=test_port, fs_cpus=1, debug=True, ifname=test_interface ) # Create string for script creation torch_script_str = "def negate(x):\n\treturn torch.neg(x)\n" # Add script via file - colo_model.add_script( + colo_application.add_script( "test_script1", script_path=torch_script, device=test_device, @@ -250,7 +251,7 @@ def test_colocated_db_script(fileutils, test_dir, wlmutils, mlutils): first_device=0, ) # Add script via string - colo_model.add_script( + colo_application.add_script( "test_script2", script=torch_script_str, device=test_device, @@ -259,29 +260,29 @@ def test_colocated_db_script(fileutils, test_dir, wlmutils, mlutils): ) # Assert we have added both models - assert len(colo_model._db_scripts) == 2 + assert len(colo_application._fs_scripts) == 2 - exp.generate(colo_model) + exp.generate(colo_application) - for db_script in colo_model._db_scripts: - logger.debug(db_script) + for fs_script in colo_application._fs_scripts: + 
logger.debug(fs_script) try: - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) finally: - exp.stop(colo_model) + exp.stop(colo_application) @pytest.mark.skipif(not should_run, reason="Test needs Torch to run") -def test_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlutils): - """Test DB Scripts on colocated DB from ensemble, first colocating DB, +def test_colocated_fs_script_ensemble(fileutils, test_dir, wlmutils, mlutils): + """Test fs Scripts on colocated fs from ensemble, first colocating fs, then adding script. """ # Set experiment name - exp_name = "test-colocated-db-script" + exp_name = "test-colocated-fs-script" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -301,21 +302,21 @@ def test_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlutils): colo_settings.set_nodes(1) colo_settings.set_tasks(1) - # Create SmartSim Ensemble with two identical models + # Create SmartSim Ensemble with two identical applications colo_ensemble = exp.create_ensemble( "colocated_ensemble", run_settings=colo_settings, replicas=2 ) - # Create a SmartSim model - colo_model = exp.create_model("colocated_model", colo_settings) + # Create a SmartSim application + colo_application = exp.create_application("colocated_application", colo_settings) - # Colocate a db with each ensemble entity and add a script + # Colocate a fs with each ensemble entity and add a script # to each entity via file for i, entity in enumerate(colo_ensemble): entity.disable_key_prefixing() - entity.colocate_db_tcp( + entity.colocate_fs_tcp( port=test_port + i, - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) @@ -328,15 +329,15 @@ def 
test_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlutils): first_device=0, ) - # Colocate a db with the non-ensemble Model - colo_model.colocate_db_tcp( + # Colocate a feature store with the non-ensemble Application + colo_application.colocate_fs_tcp( port=test_port + len(colo_ensemble), - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) - # Add a script to the non-ensemble model + # Add a script to the non-ensemble application torch_script_str = "def negate(x):\n\treturn torch.neg(x)\n" colo_ensemble.add_script( "test_script2", @@ -346,11 +347,11 @@ def test_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlutils): first_device=0, ) - # Add the third SmartSim model to the ensemble - colo_ensemble.add_model(colo_model) + # Add the third SmartSim application to the ensemble + colo_ensemble.add_application(colo_application) # Add another script via file to the entire ensemble - colo_model.add_script( + colo_application.add_script( "test_script1", script_path=torch_script, device=test_device, @@ -358,10 +359,10 @@ def test_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlutils): first_device=0, ) - # Assert we have added one model to the ensemble - assert len(colo_ensemble._db_scripts) == 1 - # Assert we have added both models to each entity - assert all([len(entity._db_scripts) == 2 for entity in colo_ensemble]) + # Assert we have added one application to the ensemble + assert len(colo_ensemble._fs_scripts) == 1 + # Assert we have added both applications to each entity + assert all([len(entity._fs_scripts) == 2 for entity in colo_ensemble]) exp.generate(colo_ensemble) @@ -369,18 +370,18 @@ def test_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlutils): try: exp.start(colo_ensemble, block=True) statuses = exp.get_status(colo_ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) finally: 
exp.stop(colo_ensemble) @pytest.mark.skipif(not should_run, reason="Test needs Torch to run") -def test_colocated_db_script_ensemble_reordered(fileutils, test_dir, wlmutils, mlutils): - """Test DB Scripts on colocated DB from ensemble, first adding the - script to the ensemble, then colocating the DB""" +def test_colocated_fs_script_ensemble_reordered(fileutils, test_dir, wlmutils, mlutils): + """Test fs Scripts on colocated fs from ensemble, first adding the + script to the ensemble, then colocating the fs""" # Set Experiment name - exp_name = "test-colocated-db-script-reord" + exp_name = "test-colocated-fs-script-reord" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -400,13 +401,13 @@ def test_colocated_db_script_ensemble_reordered(fileutils, test_dir, wlmutils, m colo_settings.set_nodes(1) colo_settings.set_tasks(1) - # Create Ensemble with two identical SmartSim Model + # Create Ensemble with two identical SmartSim Application colo_ensemble = exp.create_ensemble( "colocated_ensemble", run_settings=colo_settings, replicas=2 ) - # Create an additional SmartSim Model entity - colo_model = exp.create_model("colocated_model", colo_settings) + # Create an additional SmartSim Application entity + colo_application = exp.create_application("colocated_application", colo_settings) # Add a script via string to the ensemble members torch_script_str = "def negate(x):\n\treturn torch.neg(x)\n" @@ -418,13 +419,13 @@ def test_colocated_db_script_ensemble_reordered(fileutils, test_dir, wlmutils, m first_device=0, ) - # Add a colocated database to the ensemble members + # Add a colocated feature store to the ensemble members # and then add a script via file for i, entity in enumerate(colo_ensemble): entity.disable_key_prefixing() - entity.colocate_db_tcp( + entity.colocate_fs_tcp( port=test_port + i, - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) @@ -437,18 +438,18 @@ def 
test_colocated_db_script_ensemble_reordered(fileutils, test_dir, wlmutils, m first_device=0, ) - # Add a colocated database to the non-ensemble SmartSim Model - colo_model.colocate_db_tcp( + # Add a colocated feature store to the non-ensemble SmartSim Application + colo_application.colocate_fs_tcp( port=test_port + len(colo_ensemble), - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) - # Add the non-ensemble SmartSim Model to the Ensemble + # Add the non-ensemble SmartSim Application to the Ensemble # and then add a script via file - colo_ensemble.add_model(colo_model) - colo_model.add_script( + colo_ensemble.add_application(colo_application) + colo_application.add_script( "test_script1", script_path=torch_script, device=test_device, @@ -456,10 +457,10 @@ def test_colocated_db_script_ensemble_reordered(fileutils, test_dir, wlmutils, m first_device=0, ) - # Assert we have added one model to the ensemble - assert len(colo_ensemble._db_scripts) == 1 - # Assert we have added both models to each entity - assert all([len(entity._db_scripts) == 2 for entity in colo_ensemble]) + # Assert we have added one application to the ensemble + assert len(colo_ensemble._fs_scripts) == 1 + # Assert we have added both applications to each entity + assert all([len(entity._fs_scripts) == 2 for entity in colo_ensemble]) exp.generate(colo_ensemble) @@ -467,17 +468,17 @@ def test_colocated_db_script_ensemble_reordered(fileutils, test_dir, wlmutils, m try: exp.start(colo_ensemble, block=True) statuses = exp.get_status(colo_ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) finally: exp.stop(colo_ensemble) @pytest.mark.skipif(not should_run, reason="Test needs Torch to run") -def test_db_script_errors(fileutils, test_dir, wlmutils, mlutils): - """Test DB Scripts error when setting a serialized function on colocated DB""" +def test_fs_script_errors(fileutils, test_dir, wlmutils, 
mlutils): + """Test fs Scripts error when setting a serialized function on colocated fs""" # Set Experiment name - exp_name = "test-colocated-db-script" + exp_name = "test-colocated-fs-script" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -496,11 +497,11 @@ def test_db_script_errors(fileutils, test_dir, wlmutils, mlutils): colo_settings.set_nodes(1) colo_settings.set_tasks(1) - # Create a SmartSim model with a colocated database - colo_model = exp.create_model("colocated_model", colo_settings) - colo_model.colocate_db_tcp( + # Create a SmartSim application with a colocated feature store + colo_application = exp.create_application("colocated_application", colo_settings) + colo_application.colocate_fs_tcp( port=test_port, - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) @@ -508,7 +509,7 @@ def test_db_script_errors(fileutils, test_dir, wlmutils, mlutils): # Check that an error is raised for adding in-memory # function when using colocated deployment with pytest.raises(SSUnsupportedError): - colo_model.add_function( + colo_application.add_function( "test_func", function=timestwo, device=test_device, @@ -516,23 +517,23 @@ def test_db_script_errors(fileutils, test_dir, wlmutils, mlutils): first_device=0, ) - # Create ensemble with two identical SmartSim Model entities + # Create ensemble with two identical SmartSim Application entities colo_settings = exp.create_run_settings(exe=sys.executable, exe_args=test_script) colo_ensemble = exp.create_ensemble( "colocated_ensemble", run_settings=colo_settings, replicas=2 ) - # Add a colocated database for each ensemble member + # Add a colocated feature store for each ensemble member for i, entity in enumerate(colo_ensemble): - entity.colocate_db_tcp( + entity.colocate_fs_tcp( port=test_port + i, - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) # Check that an exception is raised when adding an in-memory - # function to the ensemble with colocated 
databases + # function to the ensemble with colocated feature stores with pytest.raises(SSUnsupportedError): colo_ensemble.add_function( "test_func", @@ -542,7 +543,7 @@ def test_db_script_errors(fileutils, test_dir, wlmutils, mlutils): first_device=0, ) - # Create an ensemble with two identical SmartSim Model entities + # Create an ensemble with two identical SmartSim Application entities colo_settings = exp.create_run_settings(exe=sys.executable, exe_args=test_script) colo_ensemble = exp.create_ensemble( "colocated_ensemble", run_settings=colo_settings, replicas=2 @@ -558,31 +559,31 @@ def test_db_script_errors(fileutils, test_dir, wlmutils, mlutils): ) # Check that an error is raised when trying to add - # a colocated database to ensemble members that have + # a colocated feature store to ensemble members that have # an in-memory script for i, entity in enumerate(colo_ensemble): with pytest.raises(SSUnsupportedError): - entity.colocate_db_tcp( + entity.colocate_fs_tcp( port=test_port + i, - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) # Check that an error is raised when trying to add - # a colocated database to an Ensemble that has + # a colocated feature store to an Ensemble that has # an in-memory script with pytest.raises(SSUnsupportedError): - colo_ensemble.add_model(colo_model) + colo_ensemble.add_application(colo_application) -def test_inconsistent_params_db_script(fileutils): - """Test error when devices_per_node>1 and when devices is set to CPU in DBScript constructor""" +def test_inconsistent_params_fs_script(fileutils): + """Test error when devices_per_node>1 and when devices is set to CPU in FSScript constructor""" torch_script = fileutils.get_test_conf_path("torchscript.py") with pytest.raises(SSUnsupportedError) as ex: - _ = DBScript( - name="test_script_db", + _ = FSScript( + name="test_script_fs", script_path=torch_script, device="CPU", devices_per_node=2, @@ -593,8 +594,8 @@ def test_inconsistent_params_db_script(fileutils): == 
"Cannot set devices_per_node>1 if CPU is specified under devices" ) with pytest.raises(SSUnsupportedError) as ex: - _ = DBScript( - name="test_script_db", + _ = FSScript( + name="test_script_fs", script_path=torch_script, device="CPU", devices_per_node=1, @@ -607,11 +608,11 @@ def test_inconsistent_params_db_script(fileutils): @pytest.mark.skipif(not should_run, reason="Test needs Torch to run") -def test_db_script_ensemble_duplicate(fileutils, test_dir, wlmutils, mlutils): - """Test DB scripts on remote DB""" +def test_fs_script_ensemble_duplicate(fileutils, test_dir, wlmutils, mlutils): + """Test fs scripts on remote fs""" # Set experiment name - exp_name = "test-db-script" + exp_name = "test-fs-script" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -631,15 +632,17 @@ def test_db_script_ensemble_duplicate(fileutils, test_dir, wlmutils, mlutils): run_settings.set_nodes(1) run_settings.set_tasks(1) - # Create Ensemble with two identical models + # Create Ensemble with two identical applications ensemble = exp.create_ensemble( - "dbscript_ensemble", run_settings=run_settings, replicas=2 + "fsscript_ensemble", run_settings=run_settings, replicas=2 ) - # Create SmartSim model - smartsim_model = exp.create_model("smartsim_model", run_settings) - # Create 2nd SmartSim model - smartsim_model_2 = exp.create_model("smartsim_model_2", run_settings) + # Create SmartSim application + smartsim_application = exp.create_application("smartsim_application", run_settings) + # Create 2nd SmartSim application + smartsim_application_2 = exp.create_application( + "smartsim_application_2", run_settings + ) # Create the script string torch_script_str = "def negate(x):\n\treturn torch.neg(x)\n" @@ -683,8 +686,8 @@ def test_db_script_ensemble_duplicate(fileutils, test_dir, wlmutils, mlutils): ) assert ex.value.args[0] == 'A Script with name "test_func" already exists' - # Add a script with a non-unique name to a SmartSim Model - 
smartsim_model.add_script( + # Add a script with a non-unique name to a SmartSim application + smartsim_application.add_script( "test_script1", script_path=torch_script, device=test_device, @@ -693,11 +696,11 @@ def test_db_script_ensemble_duplicate(fileutils, test_dir, wlmutils, mlutils): ) with pytest.raises(SSUnsupportedError) as ex: - ensemble.add_model(smartsim_model) + ensemble.add_application(smartsim_application) assert ex.value.args[0] == 'A Script with name "test_script1" already exists' - # Add a function with a non-unique name to a SmartSim Model - smartsim_model_2.add_function( + # Add a function with a non-unique name to a SmartSim Application + smartsim_application_2.add_function( "test_func", function=timestwo, device=test_device, @@ -706,5 +709,5 @@ def test_db_script_ensemble_duplicate(fileutils, test_dir, wlmutils, mlutils): ) with pytest.raises(SSUnsupportedError) as ex: - ensemble.add_model(smartsim_model_2) + ensemble.add_application(smartsim_application_2) assert ex.value.args[0] == 'A Script with name "test_func" already exists' diff --git a/tests/backends/test_onnx.py b/tests/_legacy/backends/test_onnx.py similarity index 84% rename from tests/backends/test_onnx.py rename to tests/_legacy/backends/test_onnx.py index 29771bb1ca..67c9775aa3 100644 --- a/tests/backends/test_onnx.py +++ b/tests/_legacy/backends/test_onnx.py @@ -30,9 +30,7 @@ import pytest -from smartsim import Experiment -from smartsim._core.utils import installed_redisai_backends -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus sklearn_available = True try: @@ -47,7 +45,9 @@ sklearn_available = False -onnx_backend_available = "onnxruntime" in installed_redisai_backends() +onnx_backend_available = ( + "onnxruntime" in [] +) # todo: update test to replace installed_redisai_backends() should_run = sklearn_available and onnx_backend_available @@ -57,8 +57,8 @@ ) -def test_sklearn_onnx(wlm_experiment, prepare_db, single_db, mlutils, wlmutils): - 
"""This test needs two free nodes, 1 for the db and 1 some sklearn models +def test_sklearn_onnx(wlm_experiment, prepare_fs, single_fs, mlutils, wlmutils): + """This test needs two free nodes, 1 for the fs and 1 some sklearn models here we test the following sklearn models: - LinearRegression @@ -75,15 +75,15 @@ def test_sklearn_onnx(wlm_experiment, prepare_db, single_db, mlutils, wlmutils): You may need to put CUDNN in your LD_LIBRARY_PATH if running on GPU """ test_device = mlutils.get_test_device() - db = prepare_db(single_db).orchestrator - wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + fs = prepare_fs(single_fs).featurestore + wlm_experiment.reconnect_feature_store(fs.checkpoint_file) run_settings = wlm_experiment.create_run_settings( sys.executable, f"run_sklearn_onnx.py --device={test_device}" ) if wlmutils.get_test_launcher() != "local": run_settings.set_tasks(1) - model = wlm_experiment.create_model("onnx_models", run_settings) + model = wlm_experiment.create_application("onnx_models", run_settings) script_dir = os.path.dirname(os.path.abspath(__file__)) script_path = Path(script_dir, "run_sklearn_onnx.py").resolve() @@ -94,4 +94,4 @@ def test_sklearn_onnx(wlm_experiment, prepare_db, single_db, mlutils, wlmutils): # if model failed, test will fail model_status = wlm_experiment.get_status(model) - assert model_status[0] != SmartSimStatus.STATUS_FAILED + assert model_status[0] != JobStatus.FAILED diff --git a/tests/backends/test_tf.py b/tests/_legacy/backends/test_tf.py similarity index 88% rename from tests/backends/test_tf.py rename to tests/_legacy/backends/test_tf.py index adf0e9daaf..526c08e29e 100644 --- a/tests/backends/test_tf.py +++ b/tests/_legacy/backends/test_tf.py @@ -29,10 +29,8 @@ import pytest -from smartsim import Experiment -from smartsim._core.utils import installed_redisai_backends from smartsim.error import SmartSimError -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus tf_available = True 
try: @@ -43,14 +41,16 @@ print(e) tf_available = False -tf_backend_available = "tensorflow" in installed_redisai_backends() +tf_backend_available = ( + "tensorflow" in [] +) # todo: update test to replace installed_redisai_backends() @pytest.mark.skipif( (not tf_backend_available) or (not tf_available), reason="Requires RedisAI TF backend", ) -def test_keras_model(wlm_experiment, prepare_db, single_db, mlutils, wlmutils): +def test_keras_model(wlm_experiment, prepare_fs, single_fs, mlutils, wlmutils): """This test needs two free nodes, 1 for the db and 1 for a keras model script this test can run on CPU/GPU by setting SMARTSIM_TEST_DEVICE=GPU @@ -61,8 +61,8 @@ def test_keras_model(wlm_experiment, prepare_db, single_db, mlutils, wlmutils): """ test_device = mlutils.get_test_device() - db = prepare_db(single_db).orchestrator - wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + fs = prepare_fs(single_fs).featurestore + wlm_experiment.reconnect_feature_store(fs.checkpoint_file) run_settings = wlm_experiment.create_run_settings( "python", f"run_tf.py --device={test_device}" @@ -70,7 +70,7 @@ def test_keras_model(wlm_experiment, prepare_db, single_db, mlutils, wlmutils): if wlmutils.get_test_launcher() != "local": run_settings.set_tasks(1) - model = wlm_experiment.create_model("tf_script", run_settings) + model = wlm_experiment.create_application("tf_script", run_settings) script_dir = os.path.dirname(os.path.abspath(__file__)) script_path = Path(script_dir, "run_tf.py").resolve() @@ -81,7 +81,7 @@ def test_keras_model(wlm_experiment, prepare_db, single_db, mlutils, wlmutils): # if model failed, test will fail model_status = wlm_experiment.get_status(model)[0] - assert model_status != SmartSimStatus.STATUS_FAILED + assert model_status != JobStatus.FAILED def create_tf_model(): diff --git a/tests/backends/test_torch.py b/tests/_legacy/backends/test_torch.py similarity index 83% rename from tests/backends/test_torch.py rename to 
tests/_legacy/backends/test_torch.py index 6aff6b0baf..2606d08837 100644 --- a/tests/backends/test_torch.py +++ b/tests/_legacy/backends/test_torch.py @@ -29,9 +29,7 @@ import pytest -from smartsim import Experiment -from smartsim._core.utils import installed_redisai_backends -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus torch_available = True try: @@ -40,7 +38,9 @@ except ImportError: torch_available = False -torch_backend_available = "torch" in installed_redisai_backends() +torch_backend_available = ( + "torch" in [] +) # todo: update test to replace installed_redisai_backends() should_run = torch_available and torch_backend_available pytestmark = pytest.mark.skipif( @@ -49,9 +49,9 @@ def test_torch_model_and_script( - wlm_experiment, prepare_db, single_db, mlutils, wlmutils + wlm_experiment, prepare_fs, single_fs, mlutils, wlmutils ): - """This test needs two free nodes, 1 for the db and 1 for a torch model script + """This test needs two free nodes, 1 for the fs and 1 for a torch model script Here we test both the torchscipt API and the NN API from torch @@ -62,8 +62,8 @@ def test_torch_model_and_script( You may need to put CUDNN in your LD_LIBRARY_PATH if running on GPU """ - db = prepare_db(single_db).orchestrator - wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + fs = prepare_fs(single_fs).featurestore + wlm_experiment.reconnect_feature_store(fs.checkpoint_file) test_device = mlutils.get_test_device() test_num_gpus = mlutils.get_test_num_gpus() if pytest.test_device == "GPU" else 1 @@ -73,7 +73,7 @@ def test_torch_model_and_script( ) if wlmutils.get_test_launcher() != "local": run_settings.set_tasks(1) - model = wlm_experiment.create_model("torch_script", run_settings) + model = wlm_experiment.create_application("torch_script", run_settings) script_dir = os.path.dirname(os.path.abspath(__file__)) script_path = Path(script_dir, "run_torch.py").resolve() @@ -84,4 +84,4 @@ def test_torch_model_and_script( # if 
model failed, test will fail model_status = wlm_experiment.get_status(model)[0] - assert model_status != SmartSimStatus.STATUS_FAILED + assert model_status != JobStatus.FAILED diff --git a/tests/full_wlm/test_generic_batch_launch.py b/tests/_legacy/full_wlm/test_generic_batch_launch.py similarity index 72% rename from tests/full_wlm/test_generic_batch_launch.py rename to tests/_legacy/full_wlm/test_generic_batch_launch.py index fd8017c7c8..9e87ce70b3 100644 --- a/tests/full_wlm/test_generic_batch_launch.py +++ b/tests/_legacy/full_wlm/test_generic_batch_launch.py @@ -30,7 +30,7 @@ from smartsim import Experiment from smartsim.settings import QsubBatchSettings -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # retrieved from pytest fixtures if pytest.test_launcher not in pytest.wlm_options: @@ -51,10 +51,10 @@ def add_batch_resources(wlmutils, batch_settings): batch_settings.set_resource(key, value) -def test_batch_model(fileutils, test_dir, wlmutils): - """Test the launch of a manually construced batch model""" +def test_batch_application(fileutils, test_dir, wlmutils): + """Test the launch of a manually constructed batch application""" - exp_name = "test-batch-model" + exp_name = "test-batch-application" exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) script = fileutils.get_test_conf_path("sleep.py") @@ -63,15 +63,18 @@ def test_batch_model(fileutils, test_dir, wlmutils): batch_settings.set_account(wlmutils.get_test_account()) add_batch_resources(wlmutils, batch_settings) run_settings = wlmutils.get_run_settings("python", f"{script} --time=5") - model = exp.create_model( - "model", path=test_dir, run_settings=run_settings, batch_settings=batch_settings + application = exp.create_application( + "application", + path=test_dir, + run_settings=run_settings, + batch_settings=batch_settings, ) - exp.generate(model) - exp.start(model, block=True) - statuses = exp.get_status(model) + 
exp.generate(application) + exp.start(application, block=True) + statuses = exp.get_status(application) assert len(statuses) == 1 - assert statuses[0] == SmartSimStatus.STATUS_COMPLETED + assert statuses[0] == JobStatus.COMPLETED def test_batch_ensemble(fileutils, test_dir, wlmutils): @@ -82,21 +85,21 @@ def test_batch_ensemble(fileutils, test_dir, wlmutils): script = fileutils.get_test_conf_path("sleep.py") settings = wlmutils.get_run_settings("python", f"{script} --time=5") - M1 = exp.create_model("m1", path=test_dir, run_settings=settings) - M2 = exp.create_model("m2", path=test_dir, run_settings=settings) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings) + M2 = exp.create_application("m2", path=test_dir, run_settings=settings) batch = exp.create_batch_settings(nodes=1, time="00:01:00") add_batch_resources(wlmutils, batch) batch.set_account(wlmutils.get_test_account()) ensemble = exp.create_ensemble("batch-ens", batch_settings=batch) - ensemble.add_model(M1) - ensemble.add_model(M2) + ensemble.add_application(M1) + ensemble.add_application(M2) exp.generate(ensemble) exp.start(ensemble, block=True) statuses = exp.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) def test_batch_ensemble_replicas(fileutils, test_dir, wlmutils): @@ -116,4 +119,27 @@ def test_batch_ensemble_replicas(fileutils, test_dir, wlmutils): exp.start(ensemble, block=True) statuses = exp.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) + + +def test_batch_run_args_leading_dashes(fileutils, test_dir, wlmutils): + """ + Test that batch args strip leading `-` + """ + exp_name = "test-batch-run-args-leading-dashes" + exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) + + script = 
fileutils.get_test_conf_path("sleep.py") + batch_args = {"--nodes": 1} + batch_settings = exp.create_batch_settings(time="00:01:00", batch_args=batch_args) + + batch_settings.set_account(wlmutils.get_test_account()) + add_batch_resources(wlmutils, batch_settings) + run_settings = wlmutils.get_run_settings("python", f"{script} --time=5") + model = exp.create_application( + "model", path=test_dir, run_settings=run_settings, batch_settings=batch_settings + ) + + exp.start(model, block=True) + statuses = exp.get_status(model) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) diff --git a/tests/full_wlm/test_generic_orc_launch_batch.py b/tests/_legacy/full_wlm/test_generic_orc_launch_batch.py similarity index 51% rename from tests/full_wlm/test_generic_orc_launch_batch.py rename to tests/_legacy/full_wlm/test_generic_orc_launch_batch.py index 2a5627d6df..eef250e715 100644 --- a/tests/full_wlm/test_generic_orc_launch_batch.py +++ b/tests/_legacy/full_wlm/test_generic_orc_launch_batch.py @@ -32,7 +32,7 @@ from smartsim import Experiment from smartsim.settings.pbsSettings import QsubBatchSettings -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # retrieved from pytest fixtures if pytest.test_launcher not in pytest.wlm_options: @@ -40,7 +40,7 @@ if (pytest.test_launcher == "pbs") and (not pytest.has_aprun): pytestmark = pytest.mark.skip( - reason="Launching orchestrators in a batch job is not supported on PBS without ALPS" + reason="Launching feature stores in a batch job is not supported on PBS without ALPS" ) @@ -53,179 +53,179 @@ def add_batch_resources(wlmutils, batch_settings): batch_settings.set_resource(key, value) -def test_launch_orc_auto_batch(test_dir, wlmutils): - """test single node orchestrator""" +def test_launch_feature_store_auto_batch(test_dir, wlmutils): + """test single node feature store""" launcher = wlmutils.get_test_launcher() - exp_name = "test-launch-auto-orc-batch" + exp_name = 
"test-launch-auto-feature-store-batch" exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) # batch = False to launch on existing allocation network_interface = wlmutils.get_test_interface() - orc = exp.create_database( + feature_store = exp.create_feature_store( wlmutils.get_test_port(), batch=True, interface=network_interface, single_cmd=False, ) - orc.batch_settings.set_account(wlmutils.get_test_account()) - add_batch_resources(wlmutils, orc.batch_settings) + feature_store.batch_settings.set_account(wlmutils.get_test_account()) + add_batch_resources(wlmutils, feature_store.batch_settings) - orc.batch_settings.set_walltime("00:05:00") - orc.set_path(test_dir) + feature_store.batch_settings.set_walltime("00:05:00") + feature_store.set_path(test_dir) - exp.start(orc, block=True) - statuses = exp.get_status(orc) + exp.start(feature_store, block=True) + statuses = exp.get_status(feature_store) # don't use assert so that we don't leave an orphan process - if SmartSimStatus.STATUS_FAILED in statuses: - exp.stop(orc) + if JobStatus.FAILED in statuses: + exp.stop(feature_store) assert False - exp.stop(orc) - statuses = exp.get_status(orc) - assert all([stat == SmartSimStatus.STATUS_CANCELLED for stat in statuses]) + exp.stop(feature_store) + statuses = exp.get_status(feature_store) + assert all([stat == JobStatus.CANCELLED for stat in statuses]) -def test_launch_cluster_orc_batch_single(test_dir, wlmutils): - """test clustered 3-node orchestrator with single command""" +def test_launch_cluster_feature_store_batch_single(test_dir, wlmutils): + """test clustered 3-node feature store with single command""" # TODO detect number of nodes in allocation and skip if not sufficent launcher = wlmutils.get_test_launcher() - exp_name = "test-launch-auto-cluster-orc-batch-single" + exp_name = "test-launch-auto-cluster-feature-store-batch-single" exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) # batch = False to launch on existing allocation 
network_interface = wlmutils.get_test_interface() - orc = exp.create_database( + feature_store = exp.create_feature_store( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=True, interface=network_interface, single_cmd=True, ) - orc.batch_settings.set_account(wlmutils.get_test_account()) - add_batch_resources(wlmutils, orc.batch_settings) + feature_store.batch_settings.set_account(wlmutils.get_test_account()) + add_batch_resources(wlmutils, feature_store.batch_settings) - orc.batch_settings.set_walltime("00:05:00") - orc.set_path(test_dir) + feature_store.batch_settings.set_walltime("00:05:00") + feature_store.set_path(test_dir) - exp.start(orc, block=True) - statuses = exp.get_status(orc) + exp.start(feature_store, block=True) + statuses = exp.get_status(feature_store) - # don't use assert so that orc we don't leave an orphan process - if SmartSimStatus.STATUS_FAILED in statuses: - exp.stop(orc) + # don't use assert so that feature_store we don't leave an orphan process + if JobStatus.FAILED in statuses: + exp.stop(feature_store) assert False - exp.stop(orc) - statuses = exp.get_status(orc) - assert all([stat == SmartSimStatus.STATUS_CANCELLED for stat in statuses]) + exp.stop(feature_store) + statuses = exp.get_status(feature_store) + assert all([stat == JobStatus.CANCELLED for stat in statuses]) -def test_launch_cluster_orc_batch_multi(test_dir, wlmutils): - """test clustered 3-node orchestrator""" +def test_launch_cluster_feature_store_batch_multi(test_dir, wlmutils): + """test clustered 3-node feature store""" # TODO detect number of nodes in allocation and skip if not sufficent launcher = wlmutils.get_test_launcher() - exp_name = "test-launch-auto-cluster-orc-batch-multi" + exp_name = "test-launch-auto-cluster-feature-store-batch-multi" exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) # batch = False to launch on existing allocation network_interface = wlmutils.get_test_interface() - orc = exp.create_database( + feature_store = 
exp.create_feature_store( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=True, interface=network_interface, single_cmd=False, ) - orc.batch_settings.set_account(wlmutils.get_test_account()) - add_batch_resources(wlmutils, orc.batch_settings) + feature_store.batch_settings.set_account(wlmutils.get_test_account()) + add_batch_resources(wlmutils, feature_store.batch_settings) - orc.batch_settings.set_walltime("00:05:00") - orc.set_path(test_dir) + feature_store.batch_settings.set_walltime("00:05:00") + feature_store.set_path(test_dir) - exp.start(orc, block=True) - statuses = exp.get_status(orc) + exp.start(feature_store, block=True) + statuses = exp.get_status(feature_store) - # don't use assert so that orc we don't leave an orphan process - if SmartSimStatus.STATUS_FAILED in statuses: - exp.stop(orc) + # don't use assert so that feature_store we don't leave an orphan process + if JobStatus.FAILED in statuses: + exp.stop(feature_store) assert False - exp.stop(orc) - statuses = exp.get_status(orc) - assert all([stat == SmartSimStatus.STATUS_CANCELLED for stat in statuses]) + exp.stop(feature_store) + statuses = exp.get_status(feature_store) + assert all([stat == JobStatus.CANCELLED for stat in statuses]) -def test_launch_cluster_orc_reconnect(test_dir, wlmutils): - """test reconnecting to clustered 3-node orchestrator""" +def test_launch_cluster_feature_store_reconnect(test_dir, wlmutils): + """test reconnecting to clustered 3-node feature store""" p_test_dir = pathlib.Path(test_dir) launcher = wlmutils.get_test_launcher() - exp_name = "test-launch-cluster-orc-batch-reconect" + exp_name = "test-launch-cluster-feature-store-batch-reconect" exp_1_dir = p_test_dir / exp_name exp_1_dir.mkdir() exp = Experiment(exp_name, launcher=launcher, exp_path=str(exp_1_dir)) # batch = False to launch on existing allocation network_interface = wlmutils.get_test_interface() - orc = exp.create_database( - wlmutils.get_test_port(), db_nodes=3, batch=True, 
interface=network_interface + feature_store = exp.create_feature_store( + wlmutils.get_test_port(), fs_nodes=3, batch=True, interface=network_interface ) - orc.batch_settings.set_account(wlmutils.get_test_account()) - add_batch_resources(wlmutils, orc.batch_settings) + feature_store.batch_settings.set_account(wlmutils.get_test_account()) + add_batch_resources(wlmutils, feature_store.batch_settings) - orc.batch_settings.set_walltime("00:05:00") + feature_store.batch_settings.set_walltime("00:05:00") - exp.start(orc, block=True) + exp.start(feature_store, block=True) - statuses = exp.get_status(orc) + statuses = exp.get_status(feature_store) try: - assert all(stat == SmartSimStatus.STATUS_RUNNING for stat in statuses) + assert all(stat == JobStatus.RUNNING for stat in statuses) except Exception: - exp.stop(orc) + exp.stop(feature_store) raise - exp_name = "test-orc-cluster-orc-batch-reconnect-2nd" + exp_name = "test-feature_store-cluster-feature-store-batch-reconnect-2nd" exp_2_dir = p_test_dir / exp_name exp_2_dir.mkdir() exp_2 = Experiment(exp_name, launcher=launcher, exp_path=str(exp_2_dir)) try: - checkpoint = osp.join(orc.path, "smartsim_db.dat") - reloaded_orc = exp_2.reconnect_orchestrator(checkpoint) + checkpoint = osp.join(feature_store.path, "smartsim_db.dat") + reloaded_feature_store = exp_2.reconnect_feature_store(checkpoint) # let statuses update once time.sleep(5) - statuses = exp_2.get_status(reloaded_orc) - assert all(stat == SmartSimStatus.STATUS_RUNNING for stat in statuses) + statuses = exp_2.get_status(reloaded_feature_store) + assert all(stat == JobStatus.RUNNING for stat in statuses) except Exception: - # Something went wrong! Let the experiment that started the DB - # clean up the DB - exp.stop(orc) + # Something went wrong! 
Let the experiment that started the FS + # clean up the FS + exp.stop(feature_store) raise try: - # Test experiment 2 can stop the DB - exp_2.stop(reloaded_orc) + # Test experiment 2 can stop the FS + exp_2.stop(reloaded_feature_store) assert all( - stat == SmartSimStatus.STATUS_CANCELLED - for stat in exp_2.get_status(reloaded_orc) + stat == JobStatus.CANCELLED + for stat in exp_2.get_status(reloaded_feature_store) ) except Exception: - # Something went wrong! Let the experiment that started the DB - # clean up the DB - exp.stop(orc) + # Something went wrong! Let the experiment that started the FS + # clean up the FS + exp.stop(feature_store) raise else: - # Ensure it is the same DB that Experiment 1 was tracking + # Ensure it is the same FS that Experiment 1 was tracking time.sleep(5) assert not any( - stat == SmartSimStatus.STATUS_RUNNING for stat in exp.get_status(orc) + stat == JobStatus.RUNNING for stat in exp.get_status(feature_store) ) diff --git a/tests/full_wlm/test_mpmd.py b/tests/_legacy/full_wlm/test_mpmd.py similarity index 87% rename from tests/full_wlm/test_mpmd.py rename to tests/_legacy/full_wlm/test_mpmd.py index 0167a8f083..e2280308e7 100644 --- a/tests/full_wlm/test_mpmd.py +++ b/tests/_legacy/full_wlm/test_mpmd.py @@ -30,7 +30,7 @@ from smartsim import Experiment from smartsim._core.utils.helpers import is_valid_cmd -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # retrieved from pytest fixtures if pytest.test_launcher not in pytest.wlm_options: @@ -38,7 +38,7 @@ def test_mpmd(fileutils, test_dir, wlmutils): - """Run an MPMD model twice + """Run an MPMD application twice and check that it always gets executed the same way. 
All MPMD-compatible run commands which do not @@ -87,13 +87,13 @@ def prune_commands(launcher): settings.make_mpmd(deepcopy(settings)) settings.make_mpmd(deepcopy(settings)) - mpmd_model = exp.create_model( + mpmd_application = exp.create_application( f"mpmd-{run_command}", path=test_dir, run_settings=settings ) - exp.start(mpmd_model, block=True) - statuses = exp.get_status(mpmd_model) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + exp.start(mpmd_application, block=True) + statuses = exp.get_status(mpmd_application) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) - exp.start(mpmd_model, block=True) - statuses = exp.get_status(mpmd_model) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + exp.start(mpmd_application, block=True) + statuses = exp.get_status(mpmd_application) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) diff --git a/tests/full_wlm/test_slurm_allocation.py b/tests/_legacy/full_wlm/test_slurm_allocation.py similarity index 100% rename from tests/full_wlm/test_slurm_allocation.py rename to tests/_legacy/full_wlm/test_slurm_allocation.py diff --git a/tests/full_wlm/test_symlinking.py b/tests/_legacy/full_wlm/test_symlinking.py similarity index 77% rename from tests/full_wlm/test_symlinking.py rename to tests/_legacy/full_wlm/test_symlinking.py index c5b5b90bab..feb5f25f36 100644 --- a/tests/full_wlm/test_symlinking.py +++ b/tests/_legacy/full_wlm/test_symlinking.py @@ -36,23 +36,29 @@ pytestmark = pytest.mark.skip(reason="Not testing WLM integrations") -def test_batch_model_and_ensemble(test_dir, wlmutils): +def test_batch_application_and_ensemble(test_dir, wlmutils): exp_name = "test-batch" launcher = wlmutils.get_test_launcher() exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) rs = exp.create_run_settings("echo", ["spam", "eggs"]) bs = exp.create_batch_settings() - test_model = exp.create_model( - "test_model", path=test_dir, run_settings=rs, 
batch_settings=bs + test_application = exp.create_application( + "test_application", path=test_dir, run_settings=rs, batch_settings=bs ) - exp.generate(test_model) - exp.start(test_model, block=True) + exp.generate(test_application) + exp.start(test_application, block=True) - assert pathlib.Path(test_model.path).exists() - _should_be_symlinked(pathlib.Path(test_model.path, f"{test_model.name}.out"), True) - _should_be_symlinked(pathlib.Path(test_model.path, f"{test_model.name}.err"), False) - _should_not_be_symlinked(pathlib.Path(test_model.path, f"{test_model.name}.sh")) + assert pathlib.Path(test_application.path).exists() + _should_be_symlinked( + pathlib.Path(test_application.path, f"{test_application.name}.out"), True + ) + _should_be_symlinked( + pathlib.Path(test_application.path, f"{test_application.name}.err"), False + ) + _should_not_be_symlinked( + pathlib.Path(test_application.path, f"{test_application.name}.sh") + ) test_ensemble = exp.create_ensemble( "test_ensemble", params={}, batch_settings=bs, run_settings=rs, replicas=3 @@ -61,7 +67,7 @@ def test_batch_model_and_ensemble(test_dir, wlmutils): exp.start(test_ensemble, block=True) assert pathlib.Path(test_ensemble.path).exists() - for i in range(len(test_ensemble.models)): + for i in range(len(test_ensemble.applications)): _should_be_symlinked( pathlib.Path( test_ensemble.path, @@ -94,7 +100,7 @@ def test_batch_ensemble_symlinks(test_dir, wlmutils): exp.generate(test_ensemble) exp.start(test_ensemble, block=True) - for i in range(len(test_ensemble.models)): + for i in range(len(test_ensemble.applications)): _should_be_symlinked( pathlib.Path( test_ensemble.path, @@ -115,32 +121,38 @@ def test_batch_ensemble_symlinks(test_dir, wlmutils): _should_not_be_symlinked(pathlib.Path(exp.exp_path, "smartsim_params.txt")) -def test_batch_model_symlinks(test_dir, wlmutils): - exp_name = "test-batch-model" +def test_batch_application_symlinks(test_dir, wlmutils): + exp_name = "test-batch-application" launcher = 
wlmutils.get_test_launcher() exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) rs = exp.create_run_settings("echo", ["spam", "eggs"]) bs = exp.create_batch_settings() - test_model = exp.create_model( - "test_model", path=test_dir, run_settings=rs, batch_settings=bs + test_application = exp.create_application( + "test_application", path=test_dir, run_settings=rs, batch_settings=bs ) - exp.generate(test_model) - exp.start(test_model, block=True) + exp.generate(test_application) + exp.start(test_application, block=True) - assert pathlib.Path(test_model.path).exists() + assert pathlib.Path(test_application.path).exists() - _should_be_symlinked(pathlib.Path(test_model.path, f"{test_model.name}.out"), True) - _should_be_symlinked(pathlib.Path(test_model.path, f"{test_model.name}.err"), False) - _should_not_be_symlinked(pathlib.Path(test_model.path, f"{test_model.name}.sh")) + _should_be_symlinked( + pathlib.Path(test_application.path, f"{test_application.name}.out"), True + ) + _should_be_symlinked( + pathlib.Path(test_application.path, f"{test_application.name}.err"), False + ) + _should_not_be_symlinked( + pathlib.Path(test_application.path, f"{test_application.name}.sh") + ) -def test_batch_orchestrator_symlinks(test_dir, wlmutils): +def test_batch_feature_store_symlinks(test_dir, wlmutils): exp_name = "test-batch-orc" launcher = wlmutils.get_test_launcher() exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) port = 2424 - db = exp.create_database( - db_nodes=3, + db = exp.create_feature_store( + fs_nodes=3, port=port, batch=True, interface=wlmutils.get_test_interface(), @@ -154,7 +166,7 @@ def test_batch_orchestrator_symlinks(test_dir, wlmutils): _should_be_symlinked(pathlib.Path(db.path, f"{db.name}.out"), False) _should_be_symlinked(pathlib.Path(db.path, f"{db.name}.err"), False) - for i in range(db.db_nodes): + for i in range(db.fs_nodes): _should_be_symlinked(pathlib.Path(db.path, f"{db.name}_{i}.out"), False) 
_should_be_symlinked(pathlib.Path(db.path, f"{db.name}_{i}.err"), False) _should_not_be_symlinked( diff --git a/tests/full_wlm/test_wlm_helper_functions.py b/tests/_legacy/full_wlm/test_wlm_helper_functions.py similarity index 100% rename from tests/full_wlm/test_wlm_helper_functions.py rename to tests/_legacy/full_wlm/test_wlm_helper_functions.py diff --git a/tests/install/test_build.py b/tests/_legacy/install/test_build.py similarity index 100% rename from tests/install/test_build.py rename to tests/_legacy/install/test_build.py diff --git a/tests/install/test_buildenv.py b/tests/_legacy/install/test_buildenv.py similarity index 100% rename from tests/install/test_buildenv.py rename to tests/_legacy/install/test_buildenv.py diff --git a/tests/_legacy/install/test_builder.py b/tests/_legacy/install/test_builder.py new file mode 100644 index 0000000000..e0518a96d8 --- /dev/null +++ b/tests/_legacy/install/test_builder.py @@ -0,0 +1,404 @@ +# # BSD 2-Clause License +# # +# # Copyright (c) 2021-2024, Hewlett Packard Enterprise +# # All rights reserved. +# # +# # Redistribution and use in source and binary forms, with or without +# # modification, are permitted provided that the following conditions are met: +# # +# # 1. Redistributions of source code must retain the above copyright notice, this +# # list of conditions and the following disclaimer. +# # +# # 2. Redistributions in binary form must reproduce the above copyright notice, +# # this list of conditions and the following disclaimer in the documentation +# # and/or other materials provided with the distribution. +# # +# # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +# import functools +# import pathlib +# import textwrap +# import time + +# import pytest + +# import smartsim._core._install.builder as build +# from smartsim._core._install.buildenv import RedisAIVersion + +# # The tests in this file belong to the group_a group +# pytestmark = pytest.mark.group_a + +# RAI_VERSIONS = RedisAIVersion("1.2.7") + +# for_each_device = pytest.mark.parametrize( +# "device", [build.Device.CPU, build.Device.GPU] +# ) + +# _toggle_build_optional_backend = lambda backend: pytest.mark.parametrize( +# f"build_{backend}", +# [ +# pytest.param(switch, id=f"with{'' if switch else 'out'}-{backend}") +# for switch in (True, False) +# ], +# ) +# toggle_build_tf = _toggle_build_optional_backend("tf") +# toggle_build_pt = _toggle_build_optional_backend("pt") +# toggle_build_ort = _toggle_build_optional_backend("ort") + + +# @pytest.mark.parametrize( +# "mock_os", [pytest.param(os_, id=f"os='{os_}'") for os_ in ("Windows", "Java", "")] +# ) +# def test_os_enum_raises_on_unsupported(mock_os): +# with pytest.raises(build.BuildError, match="operating system") as err_info: +# build.OperatingSystem.from_str(mock_os) + + +# @pytest.mark.parametrize( +# "mock_arch", +# [ +# pytest.param(arch_, id=f"arch='{arch_}'") +# for arch_ in ("i386", "i686", "i86pc", "aarch64", "armv7l", "") +# ], +# ) +# def test_arch_enum_raises_on_unsupported(mock_arch): +# with pytest.raises(build.BuildError, match="architecture"): +# 
build.Architecture.from_str(mock_arch) + + +# @pytest.fixture +# def p_test_dir(test_dir): +# yield pathlib.Path(test_dir).resolve() + + +# @for_each_device +# def test_rai_builder_raises_if_attempting_to_place_deps_when_build_dir_dne( +# monkeypatch, p_test_dir, device +# ): +# monkeypatch.setattr(build.RedisAIBuilder, "_validate_platform", lambda a: None) +# monkeypatch.setattr( +# build.RedisAIBuilder, +# "rai_build_path", +# property(lambda self: p_test_dir / "path/to/dir/that/dne"), +# ) +# rai_builder = build.RedisAIBuilder() +# with pytest.raises(build.BuildError, match=r"build directory not found"): +# rai_builder._fetch_deps_for(device) + + +# @for_each_device +# def test_rai_builder_raises_if_attempting_to_place_deps_in_nonempty_dir( +# monkeypatch, p_test_dir, device +# ): +# (p_test_dir / "some_file.txt").touch() +# monkeypatch.setattr(build.RedisAIBuilder, "_validate_platform", lambda a: None) +# monkeypatch.setattr( +# build.RedisAIBuilder, "rai_build_path", property(lambda self: p_test_dir) +# ) +# monkeypatch.setattr( +# build.RedisAIBuilder, "get_deps_dir_path_for", lambda *a, **kw: p_test_dir +# ) +# rai_builder = build.RedisAIBuilder() + +# with pytest.raises(build.BuildError, match=r"is not empty"): +# rai_builder._fetch_deps_for(device) + + +# invalid_build_arm64 = [ +# dict(build_tf=True, build_onnx=True), +# dict(build_tf=False, build_onnx=True), +# dict(build_tf=True, build_onnx=False), +# ] +# invalid_build_ids = [ +# ",".join([f"{key}={value}" for key, value in d.items()]) +# for d in invalid_build_arm64 +# ] + + +# @pytest.mark.parametrize("build_options", invalid_build_arm64, ids=invalid_build_ids) +# def test_rai_builder_raises_if_unsupported_deps_on_arm64(build_options): +# with pytest.raises(build.BuildError, match=r"are not supported on.*ARM64"): +# build.RedisAIBuilder( +# _os=build.OperatingSystem.DARWIN, +# architecture=build.Architecture.ARM64, +# **build_options, +# ) + + +# def _confirm_inst_presence(type_, should_be_present, 
seq): +# expected_num_occurrences = 1 if should_be_present else 0 +# occurrences = filter(lambda item: isinstance(item, type_), seq) +# return expected_num_occurrences == len(tuple(occurrences)) + + +# # Helper functions to check for the presence (or absence) of a +# # ``_RAIBuildDependency`` dependency in a list of dependencies that need to be +# # fetched by a ``RedisAIBuilder`` instance +# dlpack_dep_presence = functools.partial( +# _confirm_inst_presence, build._DLPackRepository, True +# ) +# pt_dep_presence = functools.partial(_confirm_inst_presence, build._PTArchive) +# tf_dep_presence = functools.partial(_confirm_inst_presence, build._TFArchive) +# ort_dep_presence = functools.partial(_confirm_inst_presence, build._ORTArchive) + + +# @for_each_device +# @toggle_build_tf +# @toggle_build_pt +# @toggle_build_ort +# def test_rai_builder_will_add_dep_if_backend_requested_wo_duplicates( +# monkeypatch, device, build_tf, build_pt, build_ort +# ): +# monkeypatch.setattr(build.RedisAIBuilder, "_validate_platform", lambda a: None) + +# rai_builder = build.RedisAIBuilder( +# build_tf=build_tf, build_torch=build_pt, build_onnx=build_ort +# ) +# requested_backends = rai_builder._get_deps_to_fetch_for(build.Device(device)) +# assert dlpack_dep_presence(requested_backends) +# assert tf_dep_presence(build_tf, requested_backends) +# assert pt_dep_presence(build_pt, requested_backends) +# assert ort_dep_presence(build_ort, requested_backends) + + +# @for_each_device +# @toggle_build_tf +# @toggle_build_pt +# def test_rai_builder_will_not_add_dep_if_custom_dep_path_provided( +# monkeypatch, device, p_test_dir, build_tf, build_pt +# ): +# monkeypatch.setattr(build.RedisAIBuilder, "_validate_platform", lambda a: None) +# mock_ml_lib = p_test_dir / "some/ml/lib" +# mock_ml_lib.mkdir(parents=True) +# rai_builder = build.RedisAIBuilder( +# build_tf=build_tf, +# build_torch=build_pt, +# build_onnx=False, +# libtf_dir=str(mock_ml_lib if build_tf else ""), +# 
torch_dir=str(mock_ml_lib if build_pt else ""), +# ) +# requested_backends = rai_builder._get_deps_to_fetch_for(device) +# assert dlpack_dep_presence(requested_backends) +# assert tf_dep_presence(False, requested_backends) +# assert pt_dep_presence(False, requested_backends) +# assert ort_dep_presence(False, requested_backends) +# assert len(requested_backends) == 1 + + +# def test_rai_builder_raises_if_it_fetches_an_unexpected_number_of_ml_deps( +# monkeypatch, p_test_dir +# ): +# monkeypatch.setattr(build.RedisAIBuilder, "_validate_platform", lambda a: None) +# monkeypatch.setattr( +# build.RedisAIBuilder, "rai_build_path", property(lambda self: p_test_dir) +# ) +# monkeypatch.setattr( +# build, +# "_place_rai_dep_at", +# lambda target, verbose: lambda dep: target +# / "whoops_all_ml_deps_extract_to_a_dir_with_this_name", +# ) +# rai_builder = build.RedisAIBuilder(build_tf=True, build_torch=True, build_onnx=True) +# with pytest.raises( +# build.BuildError, +# match=r"Expected to place \d+ dependencies, but only found \d+", +# ): +# rai_builder._fetch_deps_for(build.Device.CPU) + + +# def test_threaded_map(): +# def _some_io_op(x): +# return x * x + +# assert (0, 1, 4, 9, 16) == tuple(build._threaded_map(_some_io_op, range(5))) + + +# def test_threaded_map_returns_early_if_nothing_to_map(): +# sleep_duration = 60 + +# def _some_long_io_op(_): +# time.sleep(sleep_duration) + +# start = time.time() +# build._threaded_map(_some_long_io_op, []) +# end = time.time() +# assert end - start < sleep_duration + + +# def test_correct_pt_variant_os(): +# # Check that all Linux variants return Linux +# for linux_variant in build.OperatingSystem.LINUX.value: +# os_ = build.OperatingSystem.from_str(linux_variant) +# assert build._choose_pt_variant(os_) == build._PTArchiveLinux + +# # Check that ARM64 and X86_64 Mac OSX return the Mac variant +# all_archs = (build.Architecture.ARM64, build.Architecture.X64) +# for arch in all_archs: +# os_ = build.OperatingSystem.DARWIN +# assert 
build._choose_pt_variant(os_) == build._PTArchiveMacOSX + + +# def test_PTArchiveMacOSX_url(): +# arch = build.Architecture.X64 +# pt_version = RAI_VERSIONS.torch + +# pt_linux_cpu = build._PTArchiveLinux( +# build.Architecture.X64, build.Device.CPU, pt_version, False +# ) +# x64_prefix = "https://download.pytorch.org/libtorch/" +# assert x64_prefix in pt_linux_cpu.url + +# pt_macosx_cpu = build._PTArchiveMacOSX( +# build.Architecture.ARM64, build.Device.CPU, pt_version, False +# ) +# arm64_prefix = "https://github.com/CrayLabs/ml_lib_builder/releases/download/" +# assert arm64_prefix in pt_macosx_cpu.url + + +# def test_PTArchiveMacOSX_gpu_error(): +# with pytest.raises(build.BuildError, match="support GPU on Mac OSX"): +# build._PTArchiveMacOSX( +# build.Architecture.ARM64, build.Device.GPU, RAI_VERSIONS.torch, False +# ).url + + +# def test_valid_platforms(): +# assert build.RedisAIBuilder( +# _os=build.OperatingSystem.LINUX, +# architecture=build.Architecture.X64, +# build_tf=True, +# build_torch=True, +# build_onnx=True, +# ) +# assert build.RedisAIBuilder( +# _os=build.OperatingSystem.DARWIN, +# architecture=build.Architecture.X64, +# build_tf=True, +# build_torch=True, +# build_onnx=False, +# ) +# assert build.RedisAIBuilder( +# _os=build.OperatingSystem.DARWIN, +# architecture=build.Architecture.X64, +# build_tf=False, +# build_torch=True, +# build_onnx=False, +# ) + + +# @pytest.mark.parametrize( +# "plat,cmd,expected_cmd", +# [ +# # Bare Word +# pytest.param( +# build.Platform(build.OperatingSystem.LINUX, build.Architecture.X64), +# ["git", "clone", "my-repo"], +# ["git", "clone", "my-repo"], +# id="git-Linux-X64", +# ), +# pytest.param( +# build.Platform(build.OperatingSystem.LINUX, build.Architecture.ARM64), +# ["git", "clone", "my-repo"], +# ["git", "clone", "my-repo"], +# id="git-Linux-Arm64", +# ), +# pytest.param( +# build.Platform(build.OperatingSystem.DARWIN, build.Architecture.X64), +# ["git", "clone", "my-repo"], +# ["git", "clone", "my-repo"], 
+# id="git-Darwin-X64", +# ), +# pytest.param( +# build.Platform(build.OperatingSystem.DARWIN, build.Architecture.ARM64), +# ["git", "clone", "my-repo"], +# [ +# "git", +# "clone", +# "--config", +# "core.autocrlf=false", +# "--config", +# "core.eol=lf", +# "my-repo", +# ], +# id="git-Darwin-Arm64", +# ), +# # Abs path +# pytest.param( +# build.Platform(build.OperatingSystem.LINUX, build.Architecture.X64), +# ["/path/to/git", "clone", "my-repo"], +# ["/path/to/git", "clone", "my-repo"], +# id="Abs-Linux-X64", +# ), +# pytest.param( +# build.Platform(build.OperatingSystem.LINUX, build.Architecture.ARM64), +# ["/path/to/git", "clone", "my-repo"], +# ["/path/to/git", "clone", "my-repo"], +# id="Abs-Linux-Arm64", +# ), +# pytest.param( +# build.Platform(build.OperatingSystem.DARWIN, build.Architecture.X64), +# ["/path/to/git", "clone", "my-repo"], +# ["/path/to/git", "clone", "my-repo"], +# id="Abs-Darwin-X64", +# ), +# pytest.param( +# build.Platform(build.OperatingSystem.DARWIN, build.Architecture.ARM64), +# ["/path/to/git", "clone", "my-repo"], +# [ +# "/path/to/git", +# "clone", +# "--config", +# "core.autocrlf=false", +# "--config", +# "core.eol=lf", +# "my-repo", +# ], +# id="Abs-Darwin-Arm64", +# ), +# ], +# ) +# def test_git_commands_are_configered_correctly_for_platforms(plat, cmd, expected_cmd): +# assert build.config_git_command(plat, cmd) == expected_cmd + + +# def test_modify_source_files(p_test_dir): +# def make_text_blurb(food): +# return textwrap.dedent(f"""\ +# My favorite food is {food} +# {food} is an important part of a healthy breakfast +# {food} {food} {food} {food} +# This line should be unchanged! 
+# --> {food} <-- +# """) + +# original_word = "SPAM" +# mutated_word = "EGGS" + +# source_files = [] +# for i in range(3): +# source_file = p_test_dir / f"test_{i}" +# source_file.touch() +# source_file.write_text(make_text_blurb(original_word)) +# source_files.append(source_file) +# # Modify a single file +# build._modify_source_files(source_files[0], original_word, mutated_word) +# assert source_files[0].read_text() == make_text_blurb(mutated_word) +# assert source_files[1].read_text() == make_text_blurb(original_word) +# assert source_files[2].read_text() == make_text_blurb(original_word) + +# # Modify multiple files +# build._modify_source_files( +# (source_files[1], source_files[2]), original_word, mutated_word +# ) +# assert source_files[1].read_text() == make_text_blurb(mutated_word) +# assert source_files[2].read_text() == make_text_blurb(mutated_word) diff --git a/tests/install/test_mlpackage.py b/tests/_legacy/install/test_mlpackage.py similarity index 100% rename from tests/install/test_mlpackage.py rename to tests/_legacy/install/test_mlpackage.py diff --git a/tests/install/test_package_retriever.py b/tests/_legacy/install/test_package_retriever.py similarity index 100% rename from tests/install/test_package_retriever.py rename to tests/_legacy/install/test_package_retriever.py diff --git a/tests/install/test_platform.py b/tests/_legacy/install/test_platform.py similarity index 100% rename from tests/install/test_platform.py rename to tests/_legacy/install/test_platform.py diff --git a/tests/install/test_redisai_builder.py b/tests/_legacy/install/test_redisai_builder.py similarity index 100% rename from tests/install/test_redisai_builder.py rename to tests/_legacy/install/test_redisai_builder.py diff --git a/tests/on_wlm/test_base_settings_on_wlm.py b/tests/_legacy/on_wlm/test_base_settings_on_wlm.py similarity index 74% rename from tests/on_wlm/test_base_settings_on_wlm.py rename to tests/_legacy/on_wlm/test_base_settings_on_wlm.py index 
77bebd524c..1559b6e5f7 100644 --- a/tests/on_wlm/test_base_settings_on_wlm.py +++ b/tests/_legacy/on_wlm/test_base_settings_on_wlm.py @@ -29,10 +29,10 @@ import pytest from smartsim import Experiment -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus """ -Test the launch and stop of models and ensembles using base +Test the launch and stop of applications and ensembles using base RunSettings while on WLM. """ @@ -41,38 +41,38 @@ pytestmark = pytest.mark.skip(reason="Not testing WLM integrations") -def test_model_on_wlm(fileutils, test_dir, wlmutils): - exp_name = "test-base-settings-model-launch" +def test_application_on_wlm(fileutils, test_dir, wlmutils): + exp_name = "test-base-settings-application-launch" exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) script = fileutils.get_test_conf_path("sleep.py") settings1 = wlmutils.get_base_run_settings("python", f"{script} --time=5") settings2 = wlmutils.get_base_run_settings("python", f"{script} --time=5") - M1 = exp.create_model("m1", path=test_dir, run_settings=settings1) - M2 = exp.create_model("m2", path=test_dir, run_settings=settings2) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings1) + M2 = exp.create_application("m2", path=test_dir, run_settings=settings2) - # launch models twice to show that they can also be restarted + # launch applications twice to show that they can also be restarted for _ in range(2): exp.start(M1, M2, block=True) statuses = exp.get_status(M1, M2) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) -def test_model_stop_on_wlm(fileutils, test_dir, wlmutils): - exp_name = "test-base-settings-model-stop" +def test_application_stop_on_wlm(fileutils, test_dir, wlmutils): + exp_name = "test-base-settings-application-stop" exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) script = 
fileutils.get_test_conf_path("sleep.py") settings1 = wlmutils.get_base_run_settings("python", f"{script} --time=5") settings2 = wlmutils.get_base_run_settings("python", f"{script} --time=5") - M1 = exp.create_model("m1", path=test_dir, run_settings=settings1) - M2 = exp.create_model("m2", path=test_dir, run_settings=settings2) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings1) + M2 = exp.create_application("m2", path=test_dir, run_settings=settings2) - # stop launched models + # stop launched applications exp.start(M1, M2, block=False) time.sleep(2) exp.stop(M1, M2) assert M1.name in exp._control._jobs.completed assert M2.name in exp._control._jobs.completed statuses = exp.get_status(M1, M2) - assert all([stat == SmartSimStatus.STATUS_CANCELLED for stat in statuses]) + assert all([stat == JobStatus.CANCELLED for stat in statuses]) diff --git a/tests/_legacy/on_wlm/test_colocated_model.py b/tests/_legacy/on_wlm/test_colocated_model.py new file mode 100644 index 0000000000..5df3778017 --- /dev/null +++ b/tests/_legacy/on_wlm/test_colocated_model.py @@ -0,0 +1,207 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys + +import pytest + +from smartsim import Experiment +from smartsim.entity import Application +from smartsim.status import JobStatus + +if sys.platform == "darwin": + supported_fss = ["tcp", "deprecated"] +else: + supported_fss = ["uds", "tcp", "deprecated"] + +# Set to true if fs logs should be generated for debugging +DEBUG_fs = False + +# retrieved from pytest fixtures +launcher = pytest.test_launcher +if launcher not in pytest.wlm_options: + pytestmark = pytest.mark.skip(reason="Not testing WLM integrations") + + +@pytest.mark.parametrize("fs_type", supported_fss) +def test_launch_colocated_application_defaults(fileutils, test_dir, coloutils, fs_type): + """Test the launch of a application with a colocated feature store and local launcher""" + + fs_args = {"debug": DEBUG_fs} + + exp = Experiment( + "colocated_application_defaults", launcher=launcher, exp_path=test_dir + ) + colo_application = coloutils.setup_test_colo( + fileutils, fs_type, exp, "send_data_local_smartredis.py", fs_args, on_wlm=True + ) + exp.generate(colo_application) + assert colo_application.run_settings.colocated_fs_settings["custom_pinning"] == "0" + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all( + stat == JobStatus.COMPLETED for stat in statuses + ), f"Statuses: {statuses}" + + # test restarting the colocated application + exp.start(colo_application, block=True) + statuses = 
exp.get_status(colo_application) + assert all( + stat == JobStatus.COMPLETED for stat in statuses + ), f"Statuses: {statuses}" + + +@pytest.mark.parametrize("fs_type", supported_fss) +def test_colocated_application_disable_pinning(fileutils, test_dir, coloutils, fs_type): + exp = Experiment( + "colocated_application_pinning_auto_1cpu", launcher=launcher, exp_path=test_dir + ) + fs_args = { + "fs_cpus": 1, + "custom_pinning": [], + "debug": DEBUG_fs, + } + + # Check to make sure that the CPU mask was correctly generated + colo_application = coloutils.setup_test_colo( + fileutils, fs_type, exp, "send_data_local_smartredis.py", fs_args, on_wlm=True + ) + assert colo_application.run_settings.colocated_fs_settings["custom_pinning"] is None + exp.generate(colo_application) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all( + stat == JobStatus.COMPLETED for stat in statuses + ), f"Statuses: {statuses}" + + +@pytest.mark.parametrize("fs_type", supported_fss) +def test_colocated_application_pinning_auto_2cpu( + fileutils, test_dir, coloutils, fs_type +): + exp = Experiment( + "colocated_application_pinning_auto_2cpu", + launcher=launcher, + exp_path=test_dir, + ) + + fs_args = {"fs_cpus": 2, "debug": DEBUG_fs} + + # Check to make sure that the CPU mask was correctly generated + colo_application = coloutils.setup_test_colo( + fileutils, fs_type, exp, "send_data_local_smartredis.py", fs_args, on_wlm=True + ) + assert ( + colo_application.run_settings.colocated_fs_settings["custom_pinning"] == "0,1" + ) + exp.generate(colo_application) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all( + stat == JobStatus.COMPLETED for stat in statuses + ), f"Statuses: {statuses}" + + +@pytest.mark.parametrize("fs_type", supported_fss) +def test_colocated_application_pinning_range(fileutils, test_dir, coloutils, fs_type): + # Check to make sure that the CPU mask was correctly generated + # 
Assume that there are at least 4 cpus on the node + + exp = Experiment( + "colocated_application_pinning_manual", + launcher=launcher, + exp_path=test_dir, + ) + + fs_args = {"fs_cpus": 4, "custom_pinning": range(4), "debug": DEBUG_fs} + + colo_application = coloutils.setup_test_colo( + fileutils, fs_type, exp, "send_data_local_smartredis.py", fs_args, on_wlm=True + ) + assert ( + colo_application.run_settings.colocated_fs_settings["custom_pinning"] + == "0,1,2,3" + ) + exp.generate(colo_application) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all( + stat == JobStatus.COMPLETED for stat in statuses + ), f"Statuses: {statuses}" + + +@pytest.mark.parametrize("fs_type", supported_fss) +def test_colocated_application_pinning_list(fileutils, test_dir, coloutils, fs_type): + # Check to make sure that the CPU mask was correctly generated + # note we presume that this has more than 2 CPUs on the supercomputer node + + exp = Experiment( + "colocated_application_pinning_manual", + launcher=launcher, + exp_path=test_dir, + ) + + fs_args = {"fs_cpus": 2, "custom_pinning": [0, 2]} + + colo_application = coloutils.setup_test_colo( + fileutils, fs_type, exp, "send_data_local_smartredis.py", fs_args, on_wlm=True + ) + assert ( + colo_application.run_settings.colocated_fs_settings["custom_pinning"] == "0,2" + ) + exp.generate(colo_application) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all( + stat == JobStatus.COMPLETED for stat in statuses + ), f"Statuses: {statuses}" + + +@pytest.mark.parametrize("fs_type", supported_fss) +def test_colocated_application_pinning_mixed(fileutils, test_dir, coloutils, fs_type): + # Check to make sure that the CPU mask was correctly generated + # note we presume that this at least 4 CPUs on the supercomputer node + + exp = Experiment( + "colocated_application_pinning_manual", + launcher=launcher, + exp_path=test_dir, + ) + + fs_args = 
{"fs_cpus": 2, "custom_pinning": [range(2), 3]} + + colo_application = coloutils.setup_test_colo( + fileutils, fs_type, exp, "send_data_local_smartredis.py", fs_args, on_wlm=True + ) + assert ( + colo_application.run_settings.colocated_fs_settings["custom_pinning"] == "0,1,3" + ) + exp.generate(colo_application) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all( + stat == JobStatus.COMPLETED for stat in statuses + ), f"Statuses: {statuses}" diff --git a/tests/on_wlm/test_containers_wlm.py b/tests/_legacy/on_wlm/test_containers_wlm.py similarity index 88% rename from tests/on_wlm/test_containers_wlm.py rename to tests/_legacy/on_wlm/test_containers_wlm.py index 21f1e1c5e1..473c9fac47 100644 --- a/tests/on_wlm/test_containers_wlm.py +++ b/tests/_legacy/on_wlm/test_containers_wlm.py @@ -31,7 +31,7 @@ from smartsim import Experiment from smartsim.entity import Ensemble from smartsim.settings.containers import Singularity -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus """Test SmartRedis container integration on a supercomputer with a WLM.""" @@ -44,7 +44,7 @@ def test_singularity_wlm_smartredis(fileutils, test_dir, wlmutils): """Run two processes, each process puts a tensor on the DB, then accesses the other process's tensor. - Finally, the tensor is used to run a model. + Finally, the tensor is used to run a application. 
Note: This is a containerized port of test_smartredis.py for WLM system """ @@ -59,12 +59,12 @@ def test_singularity_wlm_smartredis(fileutils, test_dir, wlmutils): "smartredis_ensemble_exchange", exp_path=test_dir, launcher=launcher ) - # create and start a database - orc = exp.create_database( + # create and start a feature store + feature_store = exp.create_feature_store( port=wlmutils.get_test_port(), interface=wlmutils.get_test_interface() ) - exp.generate(orc) - exp.start(orc, block=False) + exp.generate(feature_store) + exp.start(feature_store, block=False) container = Singularity(containerURI) rs = exp.create_run_settings( @@ -87,16 +87,16 @@ def test_singularity_wlm_smartredis(fileutils, test_dir, wlmutils): exp.generate(ensemble) - # start the models + # start the applications exp.start(ensemble, summary=False) # get and confirm statuses statuses = exp.get_status(ensemble) - if not all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]): - exp.stop(orc) + if not all([stat == JobStatus.COMPLETED for stat in statuses]): + exp.stop(feature_store) assert False # client ensemble failed - # stop the orchestrator - exp.stop(orc) + # stop the feature store + exp.stop(feature_store) print(exp.summary()) diff --git a/tests/on_wlm/test_dragon.py b/tests/_legacy/on_wlm/test_dragon.py similarity index 86% rename from tests/on_wlm/test_dragon.py rename to tests/_legacy/on_wlm/test_dragon.py index 1bef3cac8d..d835d60ce1 100644 --- a/tests/on_wlm/test_dragon.py +++ b/tests/_legacy/on_wlm/test_dragon.py @@ -26,8 +26,8 @@ import pytest from smartsim import Experiment -from smartsim._core.launcher.dragon.dragonLauncher import DragonLauncher -from smartsim.status import SmartSimStatus +from smartsim._core.launcher.dragon.dragon_launcher import DragonLauncher +from smartsim.status import JobStatus # retrieved from pytest fixtures if pytest.test_launcher != "dragon": @@ -42,13 +42,13 @@ def test_dragon_global_path(global_dragon_teardown, wlmutils, test_dir, monkeypa 
launcher=wlmutils.get_test_launcher(), ) rs = exp.create_run_settings(exe="sleep", exe_args=["1"]) - model = exp.create_model("sleep", run_settings=rs) + model = exp.create_application("sleep", run_settings=rs) exp.generate(model) exp.start(model, block=True) try: - assert exp.get_status(model)[0] == SmartSimStatus.STATUS_COMPLETED + assert exp.get_status(model)[0] == JobStatus.COMPLETED finally: launcher: DragonLauncher = exp._control._launcher launcher.cleanup() @@ -63,12 +63,12 @@ def test_dragon_exp_path(global_dragon_teardown, wlmutils, test_dir, monkeypatch launcher=wlmutils.get_test_launcher(), ) rs = exp.create_run_settings(exe="sleep", exe_args=["1"]) - model = exp.create_model("sleep", run_settings=rs) + model = exp.create_application("sleep", run_settings=rs) exp.generate(model) exp.start(model, block=True) try: - assert exp.get_status(model)[0] == SmartSimStatus.STATUS_COMPLETED + assert exp.get_status(model)[0] == JobStatus.COMPLETED finally: launcher: DragonLauncher = exp._control._launcher launcher.cleanup() @@ -82,13 +82,13 @@ def test_dragon_cannot_honor(wlmutils, test_dir): ) rs = exp.create_run_settings(exe="sleep", exe_args=["1"]) rs.set_nodes(100) - model = exp.create_model("sleep", run_settings=rs) + model = exp.create_application("sleep", run_settings=rs) exp.generate(model) exp.start(model, block=True) try: - assert exp.get_status(model)[0] == SmartSimStatus.STATUS_FAILED + assert exp.get_status(model)[0] == JobStatus.FAILED finally: launcher: DragonLauncher = exp._control._launcher launcher.cleanup() diff --git a/tests/on_wlm/test_dragon_entrypoint.py b/tests/_legacy/on_wlm/test_dragon_entrypoint.py similarity index 100% rename from tests/on_wlm/test_dragon_entrypoint.py rename to tests/_legacy/on_wlm/test_dragon_entrypoint.py diff --git a/tests/on_wlm/test_generic_orc_launch.py b/tests/_legacy/on_wlm/test_generic_orc_launch.py similarity index 63% rename from tests/on_wlm/test_generic_orc_launch.py rename to 
tests/_legacy/on_wlm/test_generic_orc_launch.py index cacdd5be5b..ee34888de6 100644 --- a/tests/on_wlm/test_generic_orc_launch.py +++ b/tests/_legacy/on_wlm/test_generic_orc_launch.py @@ -27,23 +27,23 @@ import pytest from smartsim import Experiment -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # retrieved from pytest fixtures if pytest.test_launcher not in pytest.wlm_options: pytestmark = pytest.mark.skip(reason="Not testing WLM integrations") -def test_launch_orc_auto(test_dir, wlmutils): - """test single node orchestrator""" +def test_launch_feature_store_auto(test_dir, wlmutils): + """test single node feature store""" launcher = wlmutils.get_test_launcher() - exp_name = "test-launch-auto-orc" + exp_name = "test-launch-auto-feature_store" exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) # batch = False to launch on existing allocation network_interface = wlmutils.get_test_interface() - orc = exp.create_database( + feature_store = exp.create_feature_store( wlmutils.get_test_port(), batch=False, interface=network_interface, @@ -51,78 +51,78 @@ def test_launch_orc_auto(test_dir, wlmutils): hosts=wlmutils.get_test_hostlist(), ) - exp.start(orc, block=True) - statuses = exp.get_status(orc) + exp.start(feature_store, block=True) + statuses = exp.get_status(feature_store) # don't use assert so that we don't leave an orphan process - if SmartSimStatus.STATUS_FAILED in statuses: - exp.stop(orc) + if JobStatus.FAILED in statuses: + exp.stop(feature_store) assert False - exp.stop(orc) - statuses = exp.get_status(orc) - assert all([stat == SmartSimStatus.STATUS_CANCELLED for stat in statuses]) + exp.stop(feature_store) + statuses = exp.get_status(feature_store) + assert all([stat == JobStatus.CANCELLED for stat in statuses]) -def test_launch_cluster_orc_single(test_dir, wlmutils): - """test clustered 3-node orchestrator with single command""" +def test_launch_cluster_feature_store_single(test_dir, wlmutils): + """test 
clustered 3-node feature store with single command""" # TODO detect number of nodes in allocation and skip if not sufficent launcher = wlmutils.get_test_launcher() - exp_name = "test-launch-auto-cluster-orc-single" + exp_name = "test-launch-auto-cluster-feature_store-single" exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) # batch = False to launch on existing allocation network_interface = wlmutils.get_test_interface() - orc = exp.create_database( + feature_store = exp.create_feature_store( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=False, interface=network_interface, single_cmd=True, hosts=wlmutils.get_test_hostlist(), ) - exp.start(orc, block=True) - statuses = exp.get_status(orc) + exp.start(feature_store, block=True) + statuses = exp.get_status(feature_store) - # don't use assert so that orc we don't leave an orphan process - if SmartSimStatus.STATUS_FAILED in statuses: - exp.stop(orc) + # don't use assert so that feature_store we don't leave an orphan process + if JobStatus.FAILED in statuses: + exp.stop(feature_store) assert False - exp.stop(orc) - statuses = exp.get_status(orc) - assert all([stat == SmartSimStatus.STATUS_CANCELLED for stat in statuses]) + exp.stop(feature_store) + statuses = exp.get_status(feature_store) + assert all([stat == JobStatus.CANCELLED for stat in statuses]) -def test_launch_cluster_orc_multi(test_dir, wlmutils): - """test clustered 3-node orchestrator with multiple commands""" +def test_launch_cluster_feature_store_multi(test_dir, wlmutils): + """test clustered 3-node feature store with multiple commands""" # TODO detect number of nodes in allocation and skip if not sufficent launcher = wlmutils.get_test_launcher() - exp_name = "test-launch-auto-cluster-orc-multi" + exp_name = "test-launch-auto-cluster-feature-store-multi" exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) # batch = False to launch on existing allocation network_interface = wlmutils.get_test_interface() - orc = 
exp.create_database( + feature_store = exp.create_feature_store( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=False, interface=network_interface, single_cmd=False, hosts=wlmutils.get_test_hostlist(), ) - exp.start(orc, block=True) - statuses = exp.get_status(orc) + exp.start(feature_store, block=True) + statuses = exp.get_status(feature_store) - # don't use assert so that orc we don't leave an orphan process - if SmartSimStatus.STATUS_FAILED in statuses: - exp.stop(orc) + # don't use assert so that feature_store we don't leave an orphan process + if JobStatus.FAILED in statuses: + exp.stop(feature_store) assert False - exp.stop(orc) - statuses = exp.get_status(orc) - assert all([stat == SmartSimStatus.STATUS_CANCELLED for stat in statuses]) + exp.stop(feature_store) + statuses = exp.get_status(feature_store) + assert all([stat == JobStatus.CANCELLED for stat in statuses]) diff --git a/tests/on_wlm/test_het_job.py b/tests/_legacy/on_wlm/test_het_job.py similarity index 93% rename from tests/on_wlm/test_het_job.py rename to tests/_legacy/on_wlm/test_het_job.py index aeea7b474e..459f2a9526 100644 --- a/tests/on_wlm/test_het_job.py +++ b/tests/_legacy/on_wlm/test_het_job.py @@ -63,19 +63,19 @@ def test_set_het_groups(monkeypatch, test_dir): rs.set_het_group([4]) -def test_orch_single_cmd(monkeypatch, wlmutils, test_dir): +def test_feature_store_single_cmd(monkeypatch, wlmutils, test_dir): """Test that single cmd is rejected in a heterogeneous job""" monkeypatch.setenv("SLURM_HET_SIZE", "1") - exp_name = "test-orch-single-cmd" + exp_name = "test-feature-store-single-cmd" exp = Experiment(exp_name, launcher="slurm", exp_path=test_dir) - orc = exp.create_database( + feature_store = exp.create_feature_store( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=False, interface=wlmutils.get_test_interface(), single_cmd=True, hosts=wlmutils.get_test_hostlist(), ) - for node in orc: + for node in feature_store: assert node.is_mpmd == False diff --git 
a/tests/on_wlm/test_launch_errors.py b/tests/_legacy/on_wlm/test_launch_errors.py similarity index 84% rename from tests/on_wlm/test_launch_errors.py rename to tests/_legacy/on_wlm/test_launch_errors.py index 2498a5a91a..2596cd9eec 100644 --- a/tests/on_wlm/test_launch_errors.py +++ b/tests/_legacy/on_wlm/test_launch_errors.py @@ -30,7 +30,7 @@ from smartsim import Experiment from smartsim.error import SmartSimError -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # retrieved from pytest fixtures if pytest.test_launcher not in pytest.wlm_options: @@ -38,7 +38,7 @@ def test_failed_status(fileutils, test_dir, wlmutils): - """Test when a failure occurs deep into model execution""" + """Test when a failure occurs deep into application execution""" exp_name = "test-report-failure" exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) @@ -48,14 +48,16 @@ def test_failed_status(fileutils, test_dir, wlmutils): "python", f"{script} --time=7", run_comamnd="auto" ) - model = exp.create_model("bad-model", path=test_dir, run_settings=settings) + application = exp.create_application( + "bad-application", path=test_dir, run_settings=settings + ) - exp.start(model, block=False) - while not exp.finished(model): + exp.start(application, block=False) + while not exp.finished(application): time.sleep(2) - stat = exp.get_status(model) + stat = exp.get_status(application) assert len(stat) == 1 - assert stat[0] == SmartSimStatus.STATUS_FAILED + assert stat[0] == JobStatus.FAILED def test_bad_run_command_args(fileutils, test_dir, wlmutils): @@ -79,7 +81,9 @@ def test_bad_run_command_args(fileutils, test_dir, wlmutils): "python", f"{script} --time=5", run_args={"badarg": "badvalue"} ) - model = exp.create_model("bad-model", path=test_dir, run_settings=settings) + application = exp.create_application( + "bad-application", path=test_dir, run_settings=settings + ) with pytest.raises(SmartSimError): - exp.start(model) + 
exp.start(application) diff --git a/tests/on_wlm/test_launch_ompi_lsf.py b/tests/_legacy/on_wlm/test_launch_ompi_lsf.py similarity index 87% rename from tests/on_wlm/test_launch_ompi_lsf.py rename to tests/_legacy/on_wlm/test_launch_ompi_lsf.py index 51c82e4184..9545c5634f 100644 --- a/tests/on_wlm/test_launch_ompi_lsf.py +++ b/tests/_legacy/on_wlm/test_launch_ompi_lsf.py @@ -27,7 +27,7 @@ import pytest from smartsim import Experiment -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # retrieved from pytest fixtures if pytest.test_launcher not in pytest.wlm_options: @@ -47,7 +47,9 @@ def test_launch_openmpi_lsf(fileutils, test_dir, wlmutils): settings.set_cpus_per_task(1) settings.set_tasks(1) - model = exp.create_model("ompi-model", path=test_dir, run_settings=settings) - exp.start(model, block=True) - statuses = exp.get_status(model) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + application = exp.create_application( + "ompi-application", path=test_dir, run_settings=settings + ) + exp.start(application, block=True) + statuses = exp.get_status(application) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) diff --git a/tests/on_wlm/test_local_step.py b/tests/_legacy/on_wlm/test_local_step.py similarity index 94% rename from tests/on_wlm/test_local_step.py rename to tests/_legacy/on_wlm/test_local_step.py index 8f7d823b8b..00c76bb331 100644 --- a/tests/on_wlm/test_local_step.py +++ b/tests/_legacy/on_wlm/test_local_step.py @@ -61,9 +61,9 @@ def test_local_env_pass_implicit(fileutils, test_dir) -> None: # NOTE: not passing env_args into run_settings here, relying on --export=ALL default settings = RunSettings(exe_name, exe_args, run_command="srun", run_args=run_args) app_name = "echo_app" - app = exp.create_model(app_name, settings) + app = exp.create_application(app_name, settings) - # generate the experiment structure and start the model + # generate the experiment structure and start 
the application exp.generate(app, overwrite=True) exp.start(app, block=True, summary=False) @@ -100,9 +100,9 @@ def test_local_env_pass_explicit(fileutils, test_dir) -> None: exe_name, exe_args, run_command="srun", run_args=run_args, env_vars=env_vars ) app_name = "echo_app" - app = exp.create_model(app_name, settings) + app = exp.create_application(app_name, settings) - # generate the experiment structure and start the model + # generate the experiment structure and start the application exp.generate(app, overwrite=True) exp.start(app, block=True, summary=False) diff --git a/tests/on_wlm/test_preview_wlm.py b/tests/_legacy/on_wlm/test_preview_wlm.py similarity index 79% rename from tests/on_wlm/test_preview_wlm.py rename to tests/_legacy/on_wlm/test_preview_wlm.py index 78da30c9af..66705669e7 100644 --- a/tests/on_wlm/test_preview_wlm.py +++ b/tests/_legacy/on_wlm/test_preview_wlm.py @@ -31,9 +31,9 @@ from jinja2.filters import FILTERS from smartsim import Experiment -from smartsim._core import Manifest, previewrenderer +from smartsim._core import Manifest, preview_renderer from smartsim._core.config import CONFIG -from smartsim.database import Orchestrator +from smartsim.database import FeatureStore from smartsim.settings import QsubBatchSettings, RunSettings pytestmark = pytest.mark.slow_tests @@ -62,47 +62,47 @@ def add_batch_resources(wlmutils, batch_settings): pytest.test_launcher not in pytest.wlm_options, reason="Not testing WLM integrations", ) -def test_preview_wlm_run_commands_cluster_orc_model( +def test_preview_wlm_run_commands_cluster_feature_store_model( test_dir, coloutils, fileutils, wlmutils ): """ Test preview of wlm run command and run aruguments on a - orchestrator and model + feature store and model """ - exp_name = "test-preview-orc-model" + exp_name = "test-preview-feature-store-model" launcher = wlmutils.get_test_launcher() test_port = wlmutils.get_test_port() test_script = fileutils.get_test_conf_path("smartredis/multidbid.py") exp = 
Experiment(exp_name, launcher=launcher, exp_path=test_dir) network_interface = wlmutils.get_test_interface() - orc = exp.create_database( + feature_store = exp.create_feature_store( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=False, interface=network_interface, single_cmd=True, hosts=wlmutils.get_test_hostlist(), - db_identifier="testdb_reg", + fs_identifier="testfs_reg", ) - db_args = { + fs_args = { "port": test_port, - "db_cpus": 1, + "fs_cpus": 1, "debug": True, - "db_identifier": "testdb_colo", + "fs_identifier": "testfs_colo", } # Create model with colocated database smartsim_model = coloutils.setup_test_colo( - fileutils, "uds", exp, test_script, db_args, on_wlm=on_wlm + fileutils, "uds", exp, test_script, fs_args, on_wlm=on_wlm ) - preview_manifest = Manifest(orc, smartsim_model) + preview_manifest = Manifest(feature_store, smartsim_model) # Execute preview method - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output if pytest.test_launcher != "dragon": @@ -126,13 +126,13 @@ def test_preview_model_on_wlm(fileutils, test_dir, wlmutils): script = fileutils.get_test_conf_path("sleep.py") settings1 = wlmutils.get_base_run_settings("python", f"{script} --time=5") settings2 = wlmutils.get_base_run_settings("python", f"{script} --time=5") - M1 = exp.create_model("m1", path=test_dir, run_settings=settings1) - M2 = exp.create_model("m2", path=test_dir, run_settings=settings2) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings1) + M2 = exp.create_application("m2", path=test_dir, run_settings=settings2) preview_manifest = Manifest(M1, M2) # Execute preview method - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") if pytest.test_launcher != "dragon": assert "Run Command" in output @@ -158,7 
+158,7 @@ def test_preview_batch_model(fileutils, test_dir, wlmutils): batch_settings.set_account(wlmutils.get_test_account()) add_batch_resources(wlmutils, batch_settings) run_settings = wlmutils.get_run_settings("python", f"{script} --time=5") - model = exp.create_model( + model = exp.create_application( "model", path=test_dir, run_settings=run_settings, batch_settings=batch_settings ) model.set_path(test_dir) @@ -166,7 +166,7 @@ def test_preview_batch_model(fileutils, test_dir, wlmutils): preview_manifest = Manifest(model) # Execute preview method - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") assert "Batch Launch: True" in output assert "Batch Command" in output @@ -187,8 +187,8 @@ def test_preview_batch_ensemble(fileutils, test_dir, wlmutils): script = fileutils.get_test_conf_path("sleep.py") settings = wlmutils.get_run_settings("python", f"{script} --time=5") - M1 = exp.create_model("m1", path=test_dir, run_settings=settings) - M2 = exp.create_model("m2", path=test_dir, run_settings=settings) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings) + M2 = exp.create_application("m2", path=test_dir, run_settings=settings) batch = exp.create_batch_settings(nodes=1, time="00:01:00") add_batch_resources(wlmutils, batch) @@ -202,7 +202,7 @@ def test_preview_batch_ensemble(fileutils, test_dir, wlmutils): preview_manifest = Manifest(ensemble) # Execute preview method - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") assert "Batch Launch: True" in output assert "Batch Command" in output @@ -216,7 +216,7 @@ def test_preview_batch_ensemble(fileutils, test_dir, wlmutils): reason="Not testing WLM integrations", ) def test_preview_launch_command(test_dir, wlmutils, choose_host): - """Test preview launch command for 
orchestrator, models, and + """Test preview launch command for feature store, models, and ensembles""" # Prepare entities test_launcher = wlmutils.get_test_launcher() @@ -225,7 +225,7 @@ def test_preview_launch_command(test_dir, wlmutils, choose_host): exp_name = "test_preview_launch_command" exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) # create regular database - orc = exp.create_database( + feature_store = exp.create_feature_store( port=test_port, interface=test_interface, hosts=choose_host(wlmutils), @@ -235,11 +235,11 @@ def test_preview_launch_command(test_dir, wlmutils, choose_host): rs1 = RunSettings("bash", "multi_tags_template.sh") rs2 = exp.create_run_settings("echo", ["spam", "eggs"]) - hello_world_model = exp.create_model( + hello_world_model = exp.create_application( "echo-hello", run_settings=rs1, params=model_params ) - spam_eggs_model = exp.create_model("echo-spam", run_settings=rs2) + spam_eggs_model = exp.create_application("echo-spam", run_settings=rs2) # setup ensemble parameter space learning_rate = list(np.linspace(0.01, 0.5)) @@ -256,12 +256,14 @@ def test_preview_launch_command(test_dir, wlmutils, choose_host): n_models=4, ) - preview_manifest = Manifest(orc, spam_eggs_model, hello_world_model, ensemble) + preview_manifest = Manifest( + feature_store, spam_eggs_model, hello_world_model, ensemble + ) # Execute preview method - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") - assert "orchestrator" in output + assert "feature store" in output assert "echo-spam" in output assert "echo-hello" in output @@ -288,24 +290,24 @@ def test_preview_batch_launch_command(fileutils, test_dir, wlmutils): batch_settings.set_account(wlmutils.get_test_account()) add_batch_resources(wlmutils, batch_settings) run_settings = wlmutils.get_run_settings("python", f"{script} --time=5") - model = exp.create_model( + model = 
exp.create_application( "model", path=test_dir, run_settings=run_settings, batch_settings=batch_settings ) model.set_path(test_dir) - orc = Orchestrator( + feature_store = FeatureStore( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=True, interface="lo", launcher="slurm", run_command="srun", ) - orc.set_batch_arg("account", "ACCOUNT") + feature_store.set_batch_arg("account", "ACCOUNT") - preview_manifest = Manifest(orc, model) + preview_manifest = Manifest(feature_store, model) # Execute preview method - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Batch Launch: True" in output @@ -326,9 +328,9 @@ def test_ensemble_batch(test_dir, wlmutils): exp = Experiment( "test-preview-ensemble-clientconfig", exp_path=test_dir, launcher=test_launcher ) - # Create Orchestrator - db = exp.create_database(port=6780, interface="lo") - exp.generate(db, overwrite=True) + # Create feature store + fs = exp.create_feature_store(port=6780, interface="lo") + exp.generate(fs, overwrite=True) rs1 = exp.create_run_settings("echo", ["hello", "world"]) # Create ensemble batch_settings = exp.create_batch_settings(nodes=1, time="00:01:00") @@ -342,22 +344,22 @@ def test_ensemble_batch(test_dir, wlmutils): exp.generate(ensemble, overwrite=True) rs2 = exp.create_run_settings("echo", ["spam", "eggs"]) # Create model - ml_model = exp.create_model("tf_training", rs2) + ml_model = exp.create_application("tf_training", rs2) for sim in ensemble.entities: ml_model.register_incoming_entity(sim) exp.generate(ml_model, overwrite=True) - preview_manifest = Manifest(db, ml_model, ensemble) + preview_manifest = Manifest(fs, ml_model, ensemble) # Call preview renderer for testing output - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # 
Evaluate output assert "Client Configuration" in output - assert "Database Identifier" in output - assert "Database Backend" in output + assert "Feature Store Identifier" in output + assert "Feature Store Backend" in output assert "Type" in output @@ -365,7 +367,7 @@ def test_ensemble_batch(test_dir, wlmutils): pytest.test_launcher not in pytest.wlm_options, reason="Not testing WLM integrations", ) -def test_preview_ensemble_db_script(wlmutils, test_dir): +def test_preview_ensemble_fs_script(wlmutils, test_dir): """ Test preview of a torch script on a model in an ensemble. """ @@ -373,15 +375,15 @@ def test_preview_ensemble_db_script(wlmutils, test_dir): test_launcher = wlmutils.get_test_launcher() exp = Experiment("getting-started", launcher=test_launcher) - orch = exp.create_database(db_identifier="test_db1") - orch_2 = exp.create_database(db_identifier="test_db2", db_nodes=3) + feature_store = exp.create_feature_store(fs_identifier="test_fs1") + feature_store_2 = exp.create_feature_store(fs_identifier="test_fs2", fs_nodes=3) # Initialize a RunSettings object model_settings = exp.create_run_settings(exe="python", exe_args="params.py") model_settings_2 = exp.create_run_settings(exe="python", exe_args="params.py") model_settings_3 = exp.create_run_settings(exe="python", exe_args="params.py") # Initialize a Model object - model_instance = exp.create_model("model_name", model_settings) - model_instance_2 = exp.create_model("model_name_2", model_settings_2) + model_instance = exp.create_application("model_name", model_settings) + model_instance_2 = exp.create_application("model_name_2", model_settings_2) batch = exp.create_batch_settings(time="24:00:00", account="test") ensemble = exp.create_ensemble( "ensemble", batch_settings=batch, run_settings=model_settings_3, replicas=2 @@ -400,10 +402,10 @@ def test_preview_ensemble_db_script(wlmutils, test_dir): devices_per_node=2, first_device=0, ) - preview_manifest = Manifest(ensemble, orch, orch_2) + preview_manifest = 
Manifest(ensemble, feature_store, feature_store_2) # Call preview renderer for testing output - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Torch Script" in output diff --git a/tests/on_wlm/test_restart.py b/tests/_legacy/on_wlm/test_restart.py similarity index 85% rename from tests/on_wlm/test_restart.py rename to tests/_legacy/on_wlm/test_restart.py index 0116c10d39..8a8c383f2a 100644 --- a/tests/on_wlm/test_restart.py +++ b/tests/_legacy/on_wlm/test_restart.py @@ -29,7 +29,7 @@ import pytest from smartsim import Experiment -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # retrieved from pytest fixtures if pytest.test_launcher not in pytest.wlm_options: @@ -44,15 +44,15 @@ def test_restart(fileutils, test_dir, wlmutils): settings = exp.create_run_settings("python", f"{script} --time=5") settings.set_tasks(1) - M1 = exp.create_model("m1", path=test_dir, run_settings=settings) - M2 = exp.create_model("m2", path=test_dir, run_settings=deepcopy(settings)) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings) + M2 = exp.create_application("m2", path=test_dir, run_settings=deepcopy(settings)) exp.start(M1, M2, block=True) statuses = exp.get_status(M1, M2) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) exp.start(M1, M2, block=True) statuses = exp.get_status(M1, M2) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) # TODO add job history check here. 
diff --git a/tests/on_wlm/test_simple_base_settings_on_wlm.py b/tests/_legacy/on_wlm/test_simple_base_settings_on_wlm.py similarity index 80% rename from tests/on_wlm/test_simple_base_settings_on_wlm.py rename to tests/_legacy/on_wlm/test_simple_base_settings_on_wlm.py index caa55da3ed..80f6dc704e 100644 --- a/tests/on_wlm/test_simple_base_settings_on_wlm.py +++ b/tests/_legacy/on_wlm/test_simple_base_settings_on_wlm.py @@ -30,10 +30,10 @@ from smartsim import Experiment from smartsim.settings.settings import RunSettings -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus """ -Test the launch and stop of simple models and ensembles that use base +Test the launch and stop of simple applications and ensembles that use base RunSettings while on WLM that do not include a run command These tests will execute code (very light scripts) on the head node @@ -49,39 +49,39 @@ pytestmark = pytest.mark.skip(reason="Not testing WLM integrations") -def test_simple_model_on_wlm(fileutils, test_dir, wlmutils): +def test_simple_application_on_wlm(fileutils, test_dir, wlmutils): launcher = wlmutils.get_test_launcher() if launcher not in ["pbs", "slurm", "lsf"]: pytest.skip("Test only runs on systems with LSF, PBSPro, or Slurm as WLM") - exp_name = "test-simplebase-settings-model-launch" + exp_name = "test-simplebase-settings-application-launch" exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) script = fileutils.get_test_conf_path("sleep.py") settings = RunSettings("python", exe_args=f"{script} --time=5") - M = exp.create_model("m", path=test_dir, run_settings=settings) + M = exp.create_application("m", path=test_dir, run_settings=settings) - # launch model twice to show that it can also be restarted + # launch application twice to show that it can also be restarted for _ in range(2): exp.start(M, block=True) - assert exp.get_status(M)[0] == SmartSimStatus.STATUS_COMPLETED + assert exp.get_status(M)[0] == 
JobStatus.COMPLETED -def test_simple_model_stop_on_wlm(fileutils, test_dir, wlmutils): +def test_simple_application_stop_on_wlm(fileutils, test_dir, wlmutils): launcher = wlmutils.get_test_launcher() if launcher not in ["pbs", "slurm", "lsf"]: pytest.skip("Test only runs on systems with LSF, PBSPro, or Slurm as WLM") - exp_name = "test-simplebase-settings-model-stop" + exp_name = "test-simplebase-settings-application-stop" exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) script = fileutils.get_test_conf_path("sleep.py") settings = RunSettings("python", exe_args=f"{script} --time=5") - M = exp.create_model("m", path=test_dir, run_settings=settings) + M = exp.create_application("m", path=test_dir, run_settings=settings) - # stop launched model + # stop launched application exp.start(M, block=False) time.sleep(2) exp.stop(M) assert M.name in exp._control._jobs.completed - assert exp.get_status(M)[0] == SmartSimStatus.STATUS_CANCELLED + assert exp.get_status(M)[0] == JobStatus.CANCELLED diff --git a/tests/on_wlm/test_simple_entity_launch.py b/tests/_legacy/on_wlm/test_simple_entity_launch.py similarity index 85% rename from tests/on_wlm/test_simple_entity_launch.py rename to tests/_legacy/on_wlm/test_simple_entity_launch.py index 28ddf92f74..141aa781a6 100644 --- a/tests/on_wlm/test_simple_entity_launch.py +++ b/tests/_legacy/on_wlm/test_simple_entity_launch.py @@ -31,7 +31,7 @@ import pytest from smartsim import Experiment -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus """ Test the launch of simple entity types on pre-existing allocations. 
@@ -49,20 +49,20 @@ pytestmark = pytest.mark.skip(reason="Not testing WLM integrations") -def test_models(fileutils, test_dir, wlmutils): - exp_name = "test-models-launch" +def test_applications(fileutils, test_dir, wlmutils): + exp_name = "test-applications-launch" exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) script = fileutils.get_test_conf_path("sleep.py") settings = exp.create_run_settings("python", f"{script} --time=5") settings.set_tasks(1) - M1 = exp.create_model("m1", path=test_dir, run_settings=settings) - M2 = exp.create_model("m2", path=test_dir, run_settings=deepcopy(settings)) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings) + M2 = exp.create_application("m2", path=test_dir, run_settings=deepcopy(settings)) exp.start(M1, M2, block=True) statuses = exp.get_status(M1, M2) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) def test_multinode_app(mpi_app_path, test_dir, wlmutils): @@ -76,7 +76,7 @@ def test_multinode_app(mpi_app_path, test_dir, wlmutils): settings = exp.create_run_settings(str(mpi_app_path), []) settings.set_nodes(3) - model = exp.create_model("mpi_app", run_settings=settings) + model = exp.create_application("mpi_app", run_settings=settings) exp.generate(model) exp.start(model, block=True) @@ -108,7 +108,7 @@ def test_ensemble(fileutils, test_dir, wlmutils): exp.start(ensemble, block=True) statuses = exp.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) def test_summary(fileutils, test_dir, wlmutils): @@ -125,13 +125,15 @@ def test_summary(fileutils, test_dir, wlmutils): bad_settings = exp.create_run_settings("python", f"{bad} --time=6") bad_settings.set_tasks(1) - sleep_exp = exp.create_model("sleep", path=test_dir, run_settings=sleep_settings) - bad = 
exp.create_model("bad", path=test_dir, run_settings=bad_settings) + sleep_exp = exp.create_application( + "sleep", path=test_dir, run_settings=sleep_settings + ) + bad = exp.create_application("bad", path=test_dir, run_settings=bad_settings) # start and poll exp.start(sleep_exp, bad) - assert exp.get_status(bad)[0] == SmartSimStatus.STATUS_FAILED - assert exp.get_status(sleep_exp)[0] == SmartSimStatus.STATUS_COMPLETED + assert exp.get_status(bad)[0] == JobStatus.FAILED + assert exp.get_status(sleep_exp)[0] == JobStatus.COMPLETED summary_str = exp.summary(style="plain") print(summary_str) diff --git a/tests/on_wlm/test_slurm_commands.py b/tests/_legacy/on_wlm/test_slurm_commands.py similarity index 97% rename from tests/on_wlm/test_slurm_commands.py rename to tests/_legacy/on_wlm/test_slurm_commands.py index 8411be6e0a..b44d309650 100644 --- a/tests/on_wlm/test_slurm_commands.py +++ b/tests/_legacy/on_wlm/test_slurm_commands.py @@ -25,7 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import pytest -from smartsim._core.launcher.slurm.slurmCommands import * +from smartsim._core.launcher.slurm.slurm_commands import * from smartsim.error.errors import LauncherError # retrieved from pytest fixtures diff --git a/tests/on_wlm/test_stop.py b/tests/_legacy/on_wlm/test_stop.py similarity index 90% rename from tests/on_wlm/test_stop.py rename to tests/_legacy/on_wlm/test_stop.py index abc7441bb2..77d781ccd0 100644 --- a/tests/on_wlm/test_stop.py +++ b/tests/_legacy/on_wlm/test_stop.py @@ -29,7 +29,7 @@ import pytest from smartsim import Experiment -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus """ Test Stopping launched entities. 
@@ -44,19 +44,19 @@ def test_stop_entity(fileutils, test_dir, wlmutils): - exp_name = "test-launch-stop-model" + exp_name = "test-launch-stop-application" exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) script = fileutils.get_test_conf_path("sleep.py") settings = exp.create_run_settings("python", f"{script} --time=10") settings.set_tasks(1) - M1 = exp.create_model("m1", path=test_dir, run_settings=settings) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings) exp.start(M1, block=False) time.sleep(5) exp.stop(M1) assert M1.name in exp._control._jobs.completed - assert exp.get_status(M1)[0] == SmartSimStatus.STATUS_CANCELLED + assert exp.get_status(M1)[0] == JobStatus.CANCELLED def test_stop_entity_list(fileutils, test_dir, wlmutils): @@ -73,5 +73,5 @@ def test_stop_entity_list(fileutils, test_dir, wlmutils): time.sleep(5) exp.stop(ensemble) statuses = exp.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_CANCELLED for stat in statuses]) + assert all([stat == JobStatus.CANCELLED for stat in statuses]) assert all([m.name in exp._control._jobs.completed for m in ensemble]) diff --git a/tests/on_wlm/test_wlm_orc_config_settings.py b/tests/_legacy/on_wlm/test_wlm_orc_config_settings.py similarity index 69% rename from tests/on_wlm/test_wlm_orc_config_settings.py rename to tests/_legacy/on_wlm/test_wlm_orc_config_settings.py index c74f2a497d..f4f14fbb7b 100644 --- a/tests/on_wlm/test_wlm_orc_config_settings.py +++ b/tests/_legacy/on_wlm/test_wlm_orc_config_settings.py @@ -43,61 +43,61 @@ pytestmark = pytest.mark.skip(reason="SmartRedis version is < 0.3.1") -def test_config_methods_on_wlm_single(dbutils, prepare_db, single_db): +def test_config_methods_on_wlm_single(fsutils, prepare_fs, single_fs): """Test all configuration file edit methods on single node WLM db""" - db = prepare_db(single_db).orchestrator + fs = prepare_fs(single_fs).featurestore # test the happy path and ensure all configuration 
file edit methods # successfully execute when given correct key-value pairs - configs = dbutils.get_db_configs() + configs = fsutils.get_fs_configs() for setting, value in configs.items(): logger.debug(f"Setting {setting}={value}") - config_set_method = dbutils.get_config_edit_method(db, setting) + config_set_method = fsutils.get_config_edit_method(fs, setting) config_set_method(value) - # ensure SmartSimError is raised when a clustered database's - # Orchestrator.set_db_conf is given invalid CONFIG key-value pairs - ss_error_configs = dbutils.get_smartsim_error_db_configs() + # ensure SmartSimError is raised when a clustered feature store's + # FeatureStore.set_fs_conf is given invalid CONFIG key-value pairs + ss_error_configs = fsutils.get_smartsim_error_fs_configs() for key, value_list in ss_error_configs.items(): for value in value_list: with pytest.raises(SmartSimError): - db.set_db_conf(key, value) + fs.set_fs_conf(key, value) - # ensure TypeError is raised when a clustered database's - # Orchestrator.set_db_conf is given invalid CONFIG key-value pairs - type_error_configs = dbutils.get_type_error_db_configs() + # ensure TypeError is raised when a clustered feature store's + # FeatureStore.set_fs_conf is given invalid CONFIG key-value pairs + type_error_configs = fsutils.get_type_error_fs_configs() for key, value_list in type_error_configs.items(): for value in value_list: with pytest.raises(TypeError): - db.set_db_conf(key, value) + fs.set_fs_conf(key, value) -def test_config_methods_on_wlm_cluster(dbutils, prepare_db, clustered_db): +def test_config_methods_on_wlm_cluster(fsutils, prepare_fs, clustered_fs): """Test all configuration file edit methods on an active clustered db""" - db = prepare_db(clustered_db).orchestrator + fs = prepare_fs(clustered_fs).featurestore # test the happy path and ensure all configuration file edit methods # successfully execute when given correct key-value pairs - configs = dbutils.get_db_configs() + configs = 
fsutils.get_fs_configs() for setting, value in configs.items(): logger.debug(f"Setting {setting}={value}") - config_set_method = dbutils.get_config_edit_method(db, setting) + config_set_method = fsutils.get_config_edit_method(fs, setting) config_set_method(value) - # ensure SmartSimError is raised when a clustered database's - # Orchestrator.set_db_conf is given invalid CONFIG key-value pairs - ss_error_configs = dbutils.get_smartsim_error_db_configs() + # ensure SmartSimError is raised when a clustered feature store's + # FeatureStore.set_fs_conf is given invalid CONFIG key-value pairs + ss_error_configs = fsutils.get_smartsim_error_fs_configs() for key, value_list in ss_error_configs.items(): for value in value_list: with pytest.raises(SmartSimError): logger.debug(f"Setting {key}={value}") - db.set_db_conf(key, value) + fs.set_fs_conf(key, value) - # ensure TypeError is raised when a clustered database's - # Orchestrator.set_db_conf is given invalid CONFIG key-value pairs - type_error_configs = dbutils.get_type_error_db_configs() + # ensure TypeError is raised when a clustered feature store's + # FeatureStore.set_fs_conf is given invalid CONFIG key-value pairs + type_error_configs = fsutils.get_type_error_fs_configs() for key, value_list in type_error_configs.items(): for value in value_list: with pytest.raises(TypeError): logger.debug(f"Setting {key}={value}") - db.set_db_conf(key, value) + fs.set_fs_conf(key, value) diff --git a/tests/test_alps_settings.py b/tests/_legacy/test_alps_settings.py similarity index 98% rename from tests/test_alps_settings.py rename to tests/_legacy/test_alps_settings.py index b3c4c3bdb4..f96d0e60db 100644 --- a/tests/test_alps_settings.py +++ b/tests/_legacy/test_alps_settings.py @@ -67,7 +67,7 @@ def test_aprun_add_mpmd(): def test_catch_colo_mpmd(): settings = AprunSettings("python") - settings.colocated_db_settings = {"port": 6379, "cpus": 1} + settings.colocated_fs_settings = {"port": 6379, "cpus": 1} settings_2 = 
AprunSettings("python") with pytest.raises(SSUnsupportedError): settings.make_mpmd(settings_2) diff --git a/tests/test_batch_settings.py b/tests/_legacy/test_batch_settings.py similarity index 100% rename from tests/test_batch_settings.py rename to tests/_legacy/test_batch_settings.py diff --git a/tests/test_cli.py b/tests/_legacy/test_cli.py similarity index 97% rename from tests/test_cli.py rename to tests/_legacy/test_cli.py index 1cead76251..c47ea046b7 100644 --- a/tests/test_cli.py +++ b/tests/_legacy/test_cli.py @@ -232,7 +232,7 @@ def test_cli_command_execution(capsys): exp_b_help = "this is my mock help text for build" exp_b_cmd = "build" - dbcli_exec = lambda x, y: mock_execute_custom(msg="Database", good=True) + dbcli_exec = lambda x, y: mock_execute_custom(msg="FeatureStore", good=True) build_exec = lambda x, y: mock_execute_custom(msg="Builder", good=True) menu = [ @@ -249,7 +249,7 @@ def test_cli_command_execution(capsys): captured = capsys.readouterr() # capture new output # show that `smart dbcli` calls the build parser and build execute function - assert "Database" in captured.out + assert "FeatureStore" in captured.out assert ret_val == 0 build_args = ["smart", exp_b_cmd] @@ -446,7 +446,6 @@ def mock_execute(ns: argparse.Namespace, _unparsed: t.Optional[t.List[str]] = No pytest.param( "build", "build_execute", "onnx mocked-build", "--skip-onnx", True, "", "onnx", True, id="Skip Onnx"), pytest.param( "build", "build_execute", "config-dir mocked-build", "--config-dir /foo/bar", True, "", "config-dir", "/foo/bar", id="set torch dir"), pytest.param( "build", "build_execute", "bad-config-dir mocked-build", "--config-dir", False, "error: argument --config-dir", "", "", id="set config dir w/o path"), - pytest.param( "build", "build_execute", "keydb mocked-build", "--keydb", True, "", "keydb", True, id="keydb on"), pytest.param( "clean", "clean_execute", "clobbering mocked-clean", "--clobber", True, "", "clobber", True, id="clean w/clobber"), 
pytest.param("validate", "validate_execute", "port mocked-validate", "--port=12345", True, "", "port", 12345, id="validate w/ manual port"), pytest.param("validate", "validate_execute", "abbrv port mocked-validate", "-p 12345", True, "", "port", 12345, id="validate w/ manual abbreviated port"), @@ -669,13 +668,13 @@ def mock_operation(*args, **kwargs) -> int: def test_cli_full_dbcli_execute(capsys, monkeypatch): """Ensure that the execute method of dbcli is called""" exp_retval = 0 - exp_output = "mocked-get_db_path utility" + exp_output = "mocked-get_fs_path utility" def mock_operation(*args, **kwargs) -> int: return exp_output - # mock out the internal get_db_path method so we don't actually do file system ops - monkeypatch.setattr(smartsim._core._cli.dbcli, "get_db_path", mock_operation) + # mock out the internal get_fs_path method so we don't actually do file system ops + monkeypatch.setattr(smartsim._core._cli.dbcli, "get_fs_path", mock_operation) command = "dbcli" cfg = MenuItemConfig(command, f"test {command} help text", dbcli_execute) @@ -702,7 +701,7 @@ def mock_operation(*args, **kwargs) -> int: print(exp_output) return exp_retval - # mock out the internal get_db_path method so we don't actually do file system ops + # mock out the internal get_fs_path method so we don't actually do file system ops monkeypatch.setattr(smartsim._core._cli.site, "get_install_path", mock_operation) command = "site" @@ -730,9 +729,8 @@ def mock_operation(*args, **kwargs) -> int: print(exp_output) return exp_retval - # mock out the internal get_db_path method so we don't actually do file system ops + # mock out the internal get_fs_path method so we don't actually do file system ops monkeypatch.setattr(smartsim._core._cli.build, "tabulate", mock_operation) - monkeypatch.setattr(smartsim._core._cli.build, "build_database", mock_operation) monkeypatch.setattr(smartsim._core._cli.build, "build_redis_ai", mock_operation) command = "build" diff --git a/tests/test_collector_manager.py 
b/tests/_legacy/test_collector_manager.py similarity index 93% rename from tests/test_collector_manager.py rename to tests/_legacy/test_collector_manager.py index 56add1ef7d..98e87c2ad6 100644 --- a/tests/test_collector_manager.py +++ b/tests/_legacy/test_collector_manager.py @@ -246,13 +246,13 @@ async def test_collector_manager_collect_filesink( @pytest.mark.asyncio async def test_collector_manager_collect_integration( - test_dir: str, mock_entity: MockCollectorEntityFunc, prepare_db, local_db, mock_sink + test_dir: str, mock_entity: MockCollectorEntityFunc, prepare_fs, local_fs, mock_sink ) -> None: """Ensure that all collectors are executed and some metric is retrieved""" - db = prepare_db(local_db).orchestrator - entity1 = mock_entity(port=db.ports[0], name="e1", telemetry_on=True) - entity2 = mock_entity(port=db.ports[0], name="e2", telemetry_on=True) + fs = prepare_fs(local_fs).featurestore + entity1 = mock_entity(port=fs.ports[0], name="e1", telemetry_on=True) + entity2 = mock_entity(port=fs.ports[0], name="e2", telemetry_on=True) # todo: consider a MockSink so i don't have to save the last value in the collector sinks = [mock_sink(), mock_sink(), mock_sink()] @@ -337,24 +337,24 @@ async def snooze() -> None: @pytest.mark.parametrize( "e_type,telemetry_on", [ - pytest.param("model", False, id="models"), - pytest.param("model", True, id="models, telemetry enabled"), + pytest.param("application", False, id="applications"), + pytest.param("application", True, id="applications, telemetry enabled"), pytest.param("ensemble", False, id="ensemble"), pytest.param("ensemble", True, id="ensemble, telemetry enabled"), - pytest.param("orchestrator", False, id="orchestrator"), - pytest.param("orchestrator", True, id="orchestrator, telemetry enabled"), - pytest.param("dbnode", False, id="dbnode"), - pytest.param("dbnode", True, id="dbnode, telemetry enabled"), + pytest.param("featurestore", False, id="featurestore"), + pytest.param("featurestore", True, id="featurestore, 
telemetry enabled"), + pytest.param("fsnode", False, id="fsnode"), + pytest.param("fsnode", True, id="fsnode, telemetry enabled"), ], ) @pytest.mark.asyncio -async def test_collector_manager_find_nondb( +async def test_collector_manager_find_nonfs( mock_entity: MockCollectorEntityFunc, e_type: str, telemetry_on: bool, ) -> None: """Ensure that the number of collectors returned for entity types match expectations - NOTE: even orchestrator returns 0 mapped collectors because no collector output + NOTE: even featurestore returns 0 mapped collectors because no collector output paths are set on the entity""" entity = mock_entity(port=1234, name="e1", type=e_type, telemetry_on=telemetry_on) manager = CollectorManager(timeout_ms=10000) @@ -371,7 +371,7 @@ async def test_collector_manager_find_nondb( async def test_collector_manager_find_db(mock_entity: MockCollectorEntityFunc) -> None: """Ensure that the manifest allows individually enabling a given collector""" entity: JobEntity = mock_entity( - port=1234, name="entity1", type="model", telemetry_on=True + port=1234, name="entity1", type="application", telemetry_on=True ) manager = CollectorManager() @@ -383,7 +383,7 @@ async def test_collector_manager_find_db(mock_entity: MockCollectorEntityFunc) - # 1. ensure DBConnectionCountCollector is mapped entity = mock_entity( - port=1234, name="entity1", type="orchestrator", telemetry_on=True + port=1234, name="entity1", type="featurestore", telemetry_on=True ) entity.collectors["client"] = "mock/path.csv" manager = CollectorManager() @@ -397,7 +397,7 @@ async def test_collector_manager_find_db(mock_entity: MockCollectorEntityFunc) - # 3. 
ensure DBConnectionCountCollector is mapped entity = mock_entity( - port=1234, name="entity1", type="orchestrator", telemetry_on=True + port=1234, name="entity1", type="featurestore", telemetry_on=True ) entity.collectors["client_count"] = "mock/path.csv" manager = CollectorManager() @@ -411,7 +411,7 @@ async def test_collector_manager_find_db(mock_entity: MockCollectorEntityFunc) - # ensure DbMemoryCollector is mapped entity = mock_entity( - port=1234, name="entity1", type="orchestrator", telemetry_on=True + port=1234, name="entity1", type="featurestore", telemetry_on=True ) entity.collectors["memory"] = "mock/path.csv" manager = CollectorManager() @@ -429,7 +429,7 @@ async def test_collector_manager_find_entity_disabled( mock_entity: MockCollectorEntityFunc, ) -> None: """Ensure that disabling telemetry on the entity results in no collectors""" - entity: JobEntity = mock_entity(port=1234, name="entity1", type="orchestrator") + entity: JobEntity = mock_entity(port=1234, name="entity1", type="featurestore") # set paths for all known collectors entity.collectors["client"] = "mock/path.csv" @@ -457,7 +457,7 @@ async def test_collector_manager_find_entity_unmapped( ) -> None: """Ensure that an entity type that is not mapped results in no collectors""" entity: JobEntity = mock_entity( - port=1234, name="entity1", type="model", telemetry_on=True + port=1234, name="entity1", type="application", telemetry_on=True ) manager = CollectorManager() diff --git a/tests/test_collector_sink.py b/tests/_legacy/test_collector_sink.py similarity index 100% rename from tests/test_collector_sink.py rename to tests/_legacy/test_collector_sink.py diff --git a/tests/test_collectors.py b/tests/_legacy/test_collectors.py similarity index 94% rename from tests/test_collectors.py rename to tests/_legacy/test_collectors.py index 2eb61d62da..a474632c2b 100644 --- a/tests/test_collectors.py +++ b/tests/_legacy/test_collectors.py @@ -29,7 +29,7 @@ import pytest -import 
smartsim._core.entrypoints.telemetrymonitor +import smartsim._core.entrypoints.telemetry_monitor import smartsim._core.utils.telemetry.collector from conftest import MockCollectorEntityFunc, MockSink from smartsim._core.utils.telemetry.collector import ( @@ -42,7 +42,7 @@ # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a -PrepareDB = t.Callable[[dict], smartsim.experiment.Orchestrator] +PrepareFS = t.Callable[[dict], smartsim.experiment.FeatureStore] @pytest.mark.asyncio @@ -173,15 +173,15 @@ async def test_dbmemcollector_collect( async def test_dbmemcollector_integration( mock_entity: MockCollectorEntityFunc, mock_sink: MockSink, - prepare_db: PrepareDB, - local_db: dict, + prepare_fs: PrepareFS, + local_fs: dict, monkeypatch: pytest.MonkeyPatch, ) -> None: - """Integration test with a real orchestrator instance to ensure + """Integration test with a real feature store instance to ensure output data matches expectations and proper db client API uage""" - db = prepare_db(local_db).orchestrator - entity = mock_entity(port=db.ports[0], telemetry_on=True) + fs = prepare_fs(local_fs).featurestore + entity = mock_entity(port=fs.ports[0], telemetry_on=True) sink = mock_sink() collector = DBMemoryCollector(entity, sink) @@ -273,15 +273,15 @@ async def test_dbconn_count_collector_collect( async def test_dbconncollector_integration( mock_entity: MockCollectorEntityFunc, mock_sink: MockSink, - prepare_db: PrepareDB, - local_db: dict, + prepare_fs: PrepareFS, + local_fs: dict, monkeypatch: pytest.MonkeyPatch, ) -> None: - """Integration test with a real orchestrator instance to ensure + """Integration test with a real feature store instance to ensure output data matches expectations and proper db client API uage""" - db = prepare_db(local_db).orchestrator - entity = mock_entity(port=db.ports[0], telemetry_on=True) + fs = prepare_fs(local_fs).featurestore + entity = mock_entity(port=fs.ports[0], telemetry_on=True) sink = mock_sink() 
collector = DBConnectionCollector(entity, sink) diff --git a/tests/_legacy/test_colo_model_local.py b/tests/_legacy/test_colo_model_local.py new file mode 100644 index 0000000000..1ab97c4cc3 --- /dev/null +++ b/tests/_legacy/test_colo_model_local.py @@ -0,0 +1,324 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import sys + +import pytest + +from smartsim import Experiment +from smartsim.entity import Application +from smartsim.error import SSUnsupportedError +from smartsim.status import JobStatus + +# The tests in this file belong to the slow_tests group +pytestmark = pytest.mark.slow_tests + + +if sys.platform == "darwin": + supported_fss = ["tcp", "deprecated"] +else: + supported_fss = ["uds", "tcp", "deprecated"] + +is_mac = sys.platform == "darwin" + + +@pytest.mark.skipif(not is_mac, reason="MacOS-only test") +def test_macosx_warning(fileutils, test_dir, coloutils): + fs_args = {"custom_pinning": [1]} + fs_type = "uds" # Test is insensitive to choice of fs + + exp = Experiment( + "colocated_application_defaults", launcher="local", exp_path=test_dir + ) + with pytest.warns( + RuntimeWarning, + match="CPU pinning is not supported on MacOSX. Ignoring pinning specification.", + ): + _ = coloutils.setup_test_colo( + fileutils, + fs_type, + exp, + "send_data_local_smartredis.py", + fs_args, + ) + + +def test_unsupported_limit_app(fileutils, test_dir, coloutils): + fs_args = {"limit_app_cpus": True} + fs_type = "uds" # Test is insensitive to choice of fs + + exp = Experiment( + "colocated_application_defaults", launcher="local", exp_path=test_dir + ) + with pytest.raises(SSUnsupportedError): + coloutils.setup_test_colo( + fileutils, + fs_type, + exp, + "send_data_local_smartredis.py", + fs_args, + ) + + +@pytest.mark.skipif(is_mac, reason="Unsupported on MacOSX") +@pytest.mark.parametrize("custom_pinning", [1, "10", "#", 1.0, ["a"], [1.0]]) +def test_unsupported_custom_pinning(fileutils, test_dir, coloutils, custom_pinning): + fs_type = "uds" # Test is insensitive to choice of fs + fs_args = {"custom_pinning": custom_pinning} + + exp = Experiment( + "colocated_application_defaults", launcher="local", exp_path=test_dir + ) + with pytest.raises(TypeError): + coloutils.setup_test_colo( + fileutils, + fs_type, + exp, + "send_data_local_smartredis.py", + fs_args, + ) + + 
+@pytest.mark.skipif(is_mac, reason="Unsupported on MacOSX") +@pytest.mark.parametrize( + "pin_list, num_cpus, expected", + [ + pytest.param(None, 2, "0,1", id="Automatic creation of pinned cpu list"), + pytest.param([1, 2], 2, "1,2", id="Individual ids only"), + pytest.param([range(2), 3], 3, "0,1,3", id="Mixed ranges and individual ids"), + pytest.param(range(3), 3, "0,1,2", id="Range only"), + pytest.param( + [range(8, 10), range(6, 1, -2)], 4, "2,4,6,8,9", id="Multiple ranges" + ), + ], +) +def test_create_pinning_string(pin_list, num_cpus, expected): + assert Application._create_pinning_string(pin_list, num_cpus) == expected + + +@pytest.mark.parametrize("fs_type", supported_fss) +def test_launch_colocated_application_defaults( + fileutils, test_dir, coloutils, fs_type, launcher="local" +): + """Test the launch of a application with a colocated feature store and local launcher""" + + fs_args = {} + + exp = Experiment( + "colocated_application_defaults", launcher=launcher, exp_path=test_dir + ) + colo_application = coloutils.setup_test_colo( + fileutils, + fs_type, + exp, + "send_data_local_smartredis.py", + fs_args, + ) + + if is_mac: + true_pinning = None + else: + true_pinning = "0" + assert ( + colo_application.run_settings.colocated_fs_settings["custom_pinning"] + == true_pinning + ) + exp.generate(colo_application) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all(stat == JobStatus.COMPLETED for stat in statuses) + + # test restarting the colocated application + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all(stat == JobStatus.COMPLETED for stat in statuses), f"Statuses {statuses}" + + +@pytest.mark.parametrize("fs_type", supported_fss) +def test_launch_multiple_colocated_applications( + fileutils, test_dir, coloutils, wlmutils, fs_type, launcher="local" +): + """Test the concurrent launch of two applications with a colocated feature store and local 
launcher""" + + fs_args = {} + + exp = Experiment("multi_colo_applications", launcher=launcher, exp_path=test_dir) + colo_applications = [ + coloutils.setup_test_colo( + fileutils, + fs_type, + exp, + "send_data_local_smartredis.py", + fs_args, + colo_application_name="colo0", + port=wlmutils.get_test_port(), + ), + coloutils.setup_test_colo( + fileutils, + fs_type, + exp, + "send_data_local_smartredis.py", + fs_args, + colo_application_name="colo1", + port=wlmutils.get_test_port() + 1, + ), + ] + exp.generate(*colo_applications) + exp.start(*colo_applications, block=True) + statuses = exp.get_status(*colo_applications) + assert all(stat == JobStatus.COMPLETED for stat in statuses) + + # test restarting the colocated application + exp.start(*colo_applications, block=True) + statuses = exp.get_status(*colo_applications) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) + + +@pytest.mark.parametrize("fs_type", supported_fss) +def test_colocated_application_disable_pinning( + fileutils, test_dir, coloutils, fs_type, launcher="local" +): + exp = Experiment( + "colocated_application_pinning_auto_1cpu", launcher=launcher, exp_path=test_dir + ) + fs_args = { + "fs_cpus": 1, + "custom_pinning": [], + } + # Check to make sure that the CPU mask was correctly generated + colo_application = coloutils.setup_test_colo( + fileutils, + fs_type, + exp, + "send_data_local_smartredis.py", + fs_args, + ) + assert colo_application.run_settings.colocated_fs_settings["custom_pinning"] is None + exp.generate(colo_application) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) + + +@pytest.mark.parametrize("fs_type", supported_fss) +def test_colocated_application_pinning_auto_2cpu( + fileutils, test_dir, coloutils, fs_type, launcher="local" +): + exp = Experiment( + "colocated_application_pinning_auto_2cpu", launcher=launcher, exp_path=test_dir + ) + + fs_args = { + 
"fs_cpus": 2, + } + + # Check to make sure that the CPU mask was correctly generated + colo_application = coloutils.setup_test_colo( + fileutils, + fs_type, + exp, + "send_data_local_smartredis.py", + fs_args, + ) + if is_mac: + true_pinning = None + else: + true_pinning = "0,1" + assert ( + colo_application.run_settings.colocated_fs_settings["custom_pinning"] + == true_pinning + ) + exp.generate(colo_application) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) + + +@pytest.mark.skipif(is_mac, reason="unsupported on MacOSX") +@pytest.mark.parametrize("fs_type", supported_fss) +def test_colocated_application_pinning_range( + fileutils, test_dir, coloutils, fs_type, launcher="local" +): + # Check to make sure that the CPU mask was correctly generated + + exp = Experiment( + "colocated_application_pinning_manual", launcher=launcher, exp_path=test_dir + ) + + fs_args = {"fs_cpus": 2, "custom_pinning": range(2)} + + colo_application = coloutils.setup_test_colo( + fileutils, + fs_type, + exp, + "send_data_local_smartredis.py", + fs_args, + ) + assert ( + colo_application.run_settings.colocated_fs_settings["custom_pinning"] == "0,1" + ) + exp.generate(colo_application) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) + + +@pytest.mark.skipif(is_mac, reason="unsupported on MacOSX") +@pytest.mark.parametrize("fs_type", supported_fss) +def test_colocated_application_pinning_list( + fileutils, test_dir, coloutils, fs_type, launcher="local" +): + # Check to make sure that the CPU mask was correctly generated + + exp = Experiment( + "colocated_application_pinning_manual", launcher=launcher, exp_path=test_dir + ) + + fs_args = {"fs_cpus": 1, "custom_pinning": [1]} + + colo_application = coloutils.setup_test_colo( + fileutils, + fs_type, + exp, + 
"send_data_local_smartredis.py", + fs_args, + ) + assert colo_application.run_settings.colocated_fs_settings["custom_pinning"] == "1" + exp.generate(colo_application) + exp.start(colo_application, block=True) + statuses = exp.get_status(colo_application) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) + + +def test_colo_uds_verifies_socket_file_name(test_dir, launcher="local"): + exp = Experiment(f"colo_uds_wrong_name", launcher=launcher, exp_path=test_dir) + + colo_settings = exp.create_run_settings(exe=sys.executable, exe_args=["--version"]) + + colo_application = exp.create_application("wrong_uds_socket_name", colo_settings) + + with pytest.raises(ValueError): + colo_application.colocate_fs_uds(unix_socket="this is not a valid name!") diff --git a/tests/test_colo_model_lsf.py b/tests/_legacy/test_colo_model_lsf.py similarity index 70% rename from tests/test_colo_model_lsf.py rename to tests/_legacy/test_colo_model_lsf.py index 5e1c449cca..17e75caee6 100644 --- a/tests/test_colo_model_lsf.py +++ b/tests/_legacy/test_colo_model_lsf.py @@ -30,7 +30,7 @@ import smartsim.settings.base from smartsim import Experiment -from smartsim.entity import Model +from smartsim.entity import Application from smartsim.settings.lsfSettings import JsrunSettings # The tests in this file belong to the group_a group @@ -47,29 +47,29 @@ class ExpectationMet(Exception): def show_expectation_met(*args, **kwargs): - raise ExpectationMet("mock._prep_colocated_db") + raise ExpectationMet("mock._prep_colocated_fs") def test_jsrun_prep(fileutils, coloutils, monkeypatch): """Ensure that JsrunSettings prep method is executed as expected""" monkeypatch.setattr(smartsim.settings.base, "expand_exe_path", lambda x: "/bin/{x}") # mock the prep method to raise an exception that short circuits test when goal is met - monkeypatch.setattr(JsrunSettings, "_prep_colocated_db", show_expectation_met) + monkeypatch.setattr(JsrunSettings, "_prep_colocated_fs", show_expectation_met) - db_args = 
{"custom_pinning": [1]} - db_type = "uds" # Test is insensitive to choice of db + fs_args = {"custom_pinning": [1]} + fs_type = "uds" # Test is insensitive to choice of fs - exp = Experiment("colocated_model_lsf", launcher="lsf") + exp = Experiment("colocated_application_lsf", launcher="lsf") - with pytest.raises(ExpectationMet, match="mock._prep_colocated_db") as ex: + with pytest.raises(ExpectationMet, match="mock._prep_colocated_fs") as ex: run_settings = JsrunSettings("foo") coloutils.setup_test_colo( fileutils, - db_type, + fs_type, exp, "send_data_local_smartredis.py", - db_args, + fs_args, colo_settings=run_settings, ) @@ -78,25 +78,25 @@ def test_non_js_run_prep(fileutils, coloutils, monkeypatch): """Ensure that RunSettings does not attempt to call a prep method""" monkeypatch.setattr(smartsim.settings.base, "expand_exe_path", lambda x: "/bin/{x}") # mock prep method to ensure that the exception isn't thrown w/non-JsrunSettings arg - monkeypatch.setattr(JsrunSettings, "_prep_colocated_db", show_expectation_met) + monkeypatch.setattr(JsrunSettings, "_prep_colocated_fs", show_expectation_met) - db_args = {"custom_pinning": [1]} - db_type = "tcp" # Test is insensitive to choice of db + fs_args = {"custom_pinning": [1]} + fs_type = "tcp" # Test is insensitive to choice of fs - exp = Experiment("colocated_model_lsf", launcher="lsf") + exp = Experiment("colocated_application_lsf", launcher="lsf") run_settings = smartsim.settings.base.RunSettings("foo") - colo_model: Model = coloutils.setup_test_colo( + colo_application: Application = coloutils.setup_test_colo( fileutils, - db_type, + fs_type, exp, "send_data_local_smartredis.py", - db_args, + fs_args, colo_settings=run_settings, ) - assert colo_model + assert colo_application @pytest.mark.parametrize( @@ -119,30 +119,30 @@ def test_jsrun_prep_cpu_per_flag_set_check( exp_value, test_value, ): - """Ensure that _prep_colocated_db honors basic cpu_per_rs config and allows a + """Ensure that _prep_colocated_fs honors 
basic cpu_per_rs config and allows a valid input parameter to result in the correct output. If no expected input (or incorrect key) is given, the default should be returned using default config key""" monkeypatch.setattr(smartsim.settings.base, "expand_exe_path", lambda x: "/bin/{x}") - # excluding "db_cpus" should result in default value in comparison & output - db_args = {"custom_pinning": [1]} - db_type = "uds" # Test is insensitive to choice of db + # excluding "fs_cpus" should result in default value in comparison & output + fs_args = {"custom_pinning": [1]} + fs_type = "uds" # Test is insensitive to choice of fs - exp = Experiment("colocated_model_lsf", launcher="lsf") + exp = Experiment("colocated_application_lsf", launcher="lsf") run_args = {run_arg_key: test_value} run_settings = JsrunSettings("foo", run_args=run_args) - colo_model: Model = coloutils.setup_test_colo( + colo_application: Application = coloutils.setup_test_colo( fileutils, - db_type, + fs_type, exp, "send_data_local_smartredis.py", - db_args, + fs_args, colo_settings=run_settings, ) - assert colo_model.run_settings.run_args[exp_run_arg_key] == exp_value + assert colo_application.run_settings.run_args[exp_run_arg_key] == exp_value @pytest.mark.parametrize( @@ -151,14 +151,14 @@ def test_jsrun_prep_cpu_per_flag_set_check( pytest.param("cpu_per_rs", "cpu_per_rs", 11, 11, id="cpu_per_rs matches input"), pytest.param("c", "c", 22, 22, id="c matches input"), pytest.param( - "cpu_per_rs", "cpu_per_rsx", 3, 33, id="key typo: db_cpus out (not default)" + "cpu_per_rs", "cpu_per_rsx", 3, 33, id="key typo: fs_cpus out (not default)" ), pytest.param( - "cpu_per_rs", "cx", 3, 44, id="key typo: get db_cpus out (not default)" + "cpu_per_rs", "cx", 3, 44, id="key typo: get fs_cpus out (not default)" ), ], ) -def test_jsrun_prep_db_cpu_override( +def test_jsrun_prep_fs_cpu_override( fileutils, coloutils, monkeypatch, @@ -167,42 +167,42 @@ def test_jsrun_prep_db_cpu_override( exp_value, test_value, ): - 
"""Ensure that both cpu_per_rs and c input config override db_cpus""" + """Ensure that both cpu_per_rs and c input config override fs_cpus""" monkeypatch.setattr(smartsim.settings.base, "expand_exe_path", lambda x: "/bin/{x}") - # setting "db_cpus" should result in non-default value in comparison & output - db_args = {"custom_pinning": [1], "db_cpus": 3} - db_type = "tcp" # Test is insensitive to choice of db + # setting "fs_cpus" should result in non-default value in comparison & output + fs_args = {"custom_pinning": [1], "fs_cpus": 3} + fs_type = "tcp" # Test is insensitive to choice of fs - exp = Experiment("colocated_model_lsf", launcher="lsf") + exp = Experiment("colocated_application_lsf", launcher="lsf") run_args = {run_arg_key: test_value} run_settings = JsrunSettings("foo", run_args=run_args) - colo_model: Model = coloutils.setup_test_colo( + colo_application: Application = coloutils.setup_test_colo( fileutils, - db_type, + fs_type, exp, "send_data_local_smartredis.py", - db_args, + fs_args, colo_settings=run_settings, ) - assert colo_model.run_settings.run_args[exp_run_arg_key] == exp_value + assert colo_application.run_settings.run_args[exp_run_arg_key] == exp_value @pytest.mark.parametrize( "exp_run_arg_key,run_arg_key,exp_value,test_value", [ pytest.param( - "cpu_per_rs", "cpu_per_rs", 8, 3, id="cpu_per_rs swaps to db_cpus" + "cpu_per_rs", "cpu_per_rs", 8, 3, id="cpu_per_rs swaps to fs_cpus" ), - pytest.param("c", "c", 8, 4, id="c swaps to db_cpus"), - pytest.param("cpu_per_rs", "cpu_per_rsx", 8, 5, id="key typo: db_cpus out"), - pytest.param("cpu_per_rs", "cx", 8, 6, id="key typo: get db_cpus out"), + pytest.param("c", "c", 8, 4, id="c swaps to fs_cpus"), + pytest.param("cpu_per_rs", "cpu_per_rsx", 8, 5, id="key typo: fs_cpus out"), + pytest.param("cpu_per_rs", "cx", 8, 6, id="key typo: get fs_cpus out"), ], ) -def test_jsrun_prep_db_cpu_replacement( +def test_jsrun_prep_fs_cpu_replacement( fileutils, coloutils, monkeypatch, @@ -211,28 +211,28 @@ def 
test_jsrun_prep_db_cpu_replacement( exp_value, test_value, ): - """Ensure that db_cpus default is used if user config suggests underutilizing resources""" + """Ensure that fs_cpus default is used if user config suggests underutilizing resources""" monkeypatch.setattr(smartsim.settings.base, "expand_exe_path", lambda x: "/bin/{x}") - # setting "db_cpus" should result in non-default value in comparison & output - db_args = {"custom_pinning": [1], "db_cpus": 8} - db_type = "uds" # Test is insensitive to choice of db + # setting "fs_cpus" should result in non-default value in comparison & output + fs_args = {"custom_pinning": [1], "fs_cpus": 8} + fs_type = "uds" # Test is insensitive to choice of fs - exp = Experiment("colocated_model_lsf", launcher="lsf") + exp = Experiment("colocated_application_lsf", launcher="lsf") run_args = {run_arg_key: test_value} run_settings = JsrunSettings("foo", run_args=run_args) - colo_model: Model = coloutils.setup_test_colo( + colo_application: Application = coloutils.setup_test_colo( fileutils, - db_type, + fs_type, exp, "send_data_local_smartredis.py", - db_args, + fs_args, colo_settings=run_settings, ) - assert colo_model.run_settings.run_args[exp_run_arg_key] == exp_value + assert colo_application.run_settings.run_args[exp_run_arg_key] == exp_value @pytest.mark.parametrize( @@ -265,22 +265,24 @@ def test_jsrun_prep_rs_per_host( required to meet limitations (e.g. 
rs_per_host MUST equal 1)""" monkeypatch.setattr(smartsim.settings.base, "expand_exe_path", lambda x: "/bin/{x}") - db_args = {"custom_pinning": [1]} - db_type = "tcp" # Test is insensitive to choice of db + fs_args = {"custom_pinning": [1]} + fs_type = "tcp" # Test is insensitive to choice of fs - exp = Experiment("colocated_model_lsf", launcher="lsf") + exp = Experiment("colocated_application_lsf", launcher="lsf") run_args = {run_arg_key: test_value} run_settings = JsrunSettings("foo", run_args=run_args) - colo_model: Model = coloutils.setup_test_colo( + colo_application: Application = coloutils.setup_test_colo( fileutils, - db_type, + fs_type, exp, "send_data_local_smartredis.py", - db_args, + fs_args, colo_settings=run_settings, ) - # NOTE: _prep_colocated_db sets this to a string & not an integer - assert str(colo_model.run_settings.run_args[exp_run_arg_key]) == str(exp_value) + # NOTE: _prep_colocated_fs sets this to a string & not an integer + assert str(colo_application.run_settings.run_args[exp_run_arg_key]) == str( + exp_value + ) diff --git a/tests/test_config.py b/tests/_legacy/test_config.py similarity index 100% rename from tests/test_config.py rename to tests/_legacy/test_config.py diff --git a/tests/test_containers.py b/tests/_legacy/test_containers.py similarity index 85% rename from tests/test_containers.py rename to tests/_legacy/test_containers.py index 5d0f933fff..cc16d9f0d7 100644 --- a/tests/test_containers.py +++ b/tests/_legacy/test_containers.py @@ -34,7 +34,7 @@ from smartsim import Experiment, status from smartsim.entity import Ensemble from smartsim.settings.containers import Singularity -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a @@ -99,17 +99,17 @@ def test_singularity_basic(fileutils, test_dir): run_settings = exp.create_run_settings( "python3", "sleep.py --time=3", container=container ) - model = 
exp.create_model("singularity_basic", run_settings) + application = exp.create_application("singularity_basic", run_settings) script = fileutils.get_test_conf_path("sleep.py") - model.attach_generator_files(to_copy=[script]) - exp.generate(model) + application.attach_generator_files(to_copy=[script]) + exp.generate(application) - exp.start(model, summary=False) + exp.start(application, summary=False) # get and confirm status - stat = exp.get_status(model)[0] - assert stat == SmartSimStatus.STATUS_COMPLETED + stat = exp.get_status(application)[0] + assert stat == JobStatus.COMPLETED print(exp.summary()) @@ -127,32 +127,32 @@ def test_singularity_args(fileutils, test_dir): run_settings = exp.create_run_settings( "python3", "test/check_dirs.py", container=container ) - model = exp.create_model("singularity_args", run_settings) + application = exp.create_application("singularity_args", run_settings) script = fileutils.get_test_conf_path("check_dirs.py") - model.attach_generator_files(to_copy=[script]) - exp.generate(model) + application.attach_generator_files(to_copy=[script]) + exp.generate(application) - exp.start(model, summary=False) + exp.start(application, summary=False) # get and confirm status - stat = exp.get_status(model)[0] - assert stat == SmartSimStatus.STATUS_COMPLETED + stat = exp.get_status(application)[0] + assert stat == JobStatus.COMPLETED print(exp.summary()) @pytest.mark.skipif(not singularity_exists, reason="Test needs singularity to run") -def test_singularity_smartredis(local_experiment, prepare_db, local_db, fileutils): +def test_singularity_smartredis(local_experiment, prepare_fs, local_fs, fileutils): """Run two processes, each process puts a tensor on the DB, then accesses the other process's tensor. - Finally, the tensor is used to run a model. + Finally, the tensor is used to run a application. 
Note: This is a containerized port of test_smartredis.py """ # create and start a database - db = prepare_db(local_db).orchestrator - local_experiment.reconnect_orchestrator(db.checkpoint_file) + fs = prepare_fs(local_fs).featurestore + local_experiment.reconnect_feature_store(fs.checkpoint_file) container = Singularity(containerURI) @@ -175,10 +175,10 @@ def test_singularity_smartredis(local_experiment, prepare_db, local_db, fileutil local_experiment.generate(ensemble) - # start the models + # start the applications local_experiment.start(ensemble, summary=False) # get and confirm statuses statuses = local_experiment.get_status(ensemble) - if not all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]): + if not all([stat == JobStatus.COMPLETED for stat in statuses]): assert False # client ensemble failed diff --git a/tests/test_controller.py b/tests/_legacy/test_controller.py similarity index 90% rename from tests/test_controller.py rename to tests/_legacy/test_controller.py index 1498727085..ad0c98fe88 100644 --- a/tests/test_controller.py +++ b/tests/_legacy/test_controller.py @@ -30,8 +30,8 @@ from smartsim._core.control.controller import Controller from smartsim._core.launcher.step import Step -from smartsim.database.orchestrator import Orchestrator -from smartsim.entity.ensemble import Ensemble +from smartsim.builders.ensemble import Ensemble +from smartsim.database.orchestrator import FeatureStore from smartsim.settings.slurmSettings import SbatchSettings, SrunSettings controller = Controller() @@ -40,7 +40,9 @@ bs = SbatchSettings() ens = Ensemble("ens", params={}, run_settings=rs, batch_settings=bs, replicas=3) -orc = Orchestrator(db_nodes=3, batch=True, launcher="slurm", run_command="srun") +feature_store = FeatureStore( + fs_nodes=3, batch=True, launcher="slurm", run_command="srun" +) class MockStep(Step): @@ -58,7 +60,7 @@ def get_launch_cmd(self): "collection", [ pytest.param(ens, id="Ensemble"), - pytest.param(orc, id="Database"), + 
pytest.param(feature_store, id="FeatureStore"), ], ) def test_controller_batch_step_creation_preserves_entity_order(collection, monkeypatch): diff --git a/tests/test_controller_errors.py b/tests/_legacy/test_controller_errors.py similarity index 74% rename from tests/test_controller_errors.py rename to tests/_legacy/test_controller_errors.py index 2d623cdd1a..5ae05d70ad 100644 --- a/tests/test_controller_errors.py +++ b/tests/_legacy/test_controller_errors.py @@ -29,10 +29,10 @@ from smartsim._core.control import Controller, Manifest from smartsim._core.launcher.step import Step -from smartsim._core.launcher.step.dragonStep import DragonStep -from smartsim.database import Orchestrator -from smartsim.entity import Model -from smartsim.entity.ensemble import Ensemble +from smartsim._core.launcher.step.dragon_step import DragonStep +from smartsim.builders.ensemble import Ensemble +from smartsim.database import FeatureStore +from smartsim.entity import Application from smartsim.error import SmartSimError, SSUnsupportedError from smartsim.error.errors import SSUnsupportedError from smartsim.settings import RunSettings, SrunSettings @@ -41,22 +41,28 @@ pytestmark = pytest.mark.group_a entity_settings = SrunSettings("echo", ["spam", "eggs"]) -model_dup_setting = RunSettings("echo", ["spam_1", "eggs_2"]) -model = Model("model_name", run_settings=entity_settings, params={}, path="") -# Model entity slightly different but with same name -model_2 = Model("model_name", run_settings=model_dup_setting, params={}, path="") +application_dup_setting = RunSettings("echo", ["spam_1", "eggs_2"]) +application = Application( + "application_name", run_settings=entity_settings, params={}, path="" +) +# Application entity slightly different but with same name +application_2 = Application( + "application_name", run_settings=application_dup_setting, params={}, path="" +) ens = Ensemble("ensemble_name", params={}, run_settings=entity_settings, replicas=2) # Ensemble entity slightly different 
but with same name ens_2 = Ensemble("ensemble_name", params={}, run_settings=entity_settings, replicas=3) -orc = Orchestrator(db_nodes=3, batch=True, launcher="slurm", run_command="srun") +feature_store = FeatureStore( + fs_nodes=3, batch=True, launcher="slurm", run_command="srun" +) -def test_finished_entity_orc_error(): - """Orchestrators are never 'finished', either run forever or stopped by user""" - orc = Orchestrator() +def test_finished_entity_feature_store_error(): + """FeatureStores are never 'finished', either run forever or stopped by user""" + feature_store = FeatureStore() cont = Controller(launcher="local") with pytest.raises(TypeError): - cont.finished(orc) + cont.finished(feature_store) def test_finished_entity_wrong_type(): @@ -67,12 +73,12 @@ def test_finished_entity_wrong_type(): def test_finished_not_found(): - """Ask if model is finished that hasnt been launched by this experiment""" + """Ask if application is finished that hasnt been launched by this experiment""" rs = RunSettings("python") - model = Model("hello", {}, "./", rs) + application = Application("hello", {}, "./", rs) cont = Controller(launcher="local") with pytest.raises(ValueError): - cont.finished(model) + cont.finished(application) def test_entity_status_wrong_type(): @@ -101,26 +107,26 @@ def test_no_launcher(): cont.init_launcher(None) -def test_wrong_orchestrator(wlmutils): +def test_wrong_feature_store(wlmutils): # lo interface to avoid warning from SmartSim - orc = Orchestrator( + feature_store = FeatureStore( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, interface="lo", run_command="aprun", launcher="pbs", ) cont = Controller(launcher="local") - manifest = Manifest(orc) + manifest = Manifest(feature_store) with pytest.raises(SmartSimError): cont._launch("exp_name", "exp_path", manifest) -def test_bad_orc_checkpoint(): +def test_bad_feature_store_checkpoint(): checkpoint = "./bad-checkpoint" cont = Controller(launcher="local") with pytest.raises(FileNotFoundError): 
- cont.reload_saved_db(checkpoint) + cont.reload_saved_fs(checkpoint) class MockStep(Step): @@ -136,13 +142,13 @@ def get_launch_cmd(self): "entity", [ pytest.param(ens, id="Ensemble_running"), - pytest.param(model, id="Model_running"), - pytest.param(orc, id="Orch_running"), + pytest.param(application, id="Application_running"), + pytest.param(feature_store, id="Feature_store_running"), ], ) def test_duplicate_running_entity(test_dir, wlmutils, entity): """This test validates that users cannot reuse entity names - that are running in JobManager.jobs or JobManager.db_jobs + that are running in JobManager.jobs or JobManager.fs_jobs """ step_settings = RunSettings("echo") step = MockStep("mock-step", test_dir, step_settings) @@ -156,10 +162,13 @@ def test_duplicate_running_entity(test_dir, wlmutils, entity): @pytest.mark.parametrize( "entity", - [pytest.param(ens, id="Ensemble_running"), pytest.param(model, id="Model_running")], + [ + pytest.param(ens, id="Ensemble_running"), + pytest.param(application, id="Application_running"), + ], ) def test_restarting_entity(test_dir, wlmutils, entity): - """Validate restarting a completed Model/Ensemble job""" + """Validate restarting a completed Application/Ensemble job""" step_settings = RunSettings("echo") test_launcher = wlmutils.get_test_launcher() step = MockStep("mock-step", test_dir, step_settings) @@ -171,28 +180,28 @@ def test_restarting_entity(test_dir, wlmutils, entity): controller._launch_step(step, entity=entity) -def test_restarting_orch(test_dir, wlmutils): - """Validate restarting a completed Orchestrator job""" +def test_restarting_feature_store(test_dir, wlmutils): + """Validate restarting a completed FeatureStore job""" step_settings = RunSettings("echo") test_launcher = wlmutils.get_test_launcher() step = MockStep("mock-step", test_dir, step_settings) step.meta["status_dir"] = test_dir - orc.path = test_dir + feature_store.path = test_dir controller = Controller(test_launcher) - controller._jobs.add_job(orc.name, 
job_id="1234", entity=orc) - controller._jobs.move_to_completed(controller._jobs.db_jobs.get(orc.name)) - controller._launch_step(step, entity=orc) + controller._jobs.add_job(feature_store.name, job_id="1234", entity=feature_store) + controller._jobs.move_to_completed(controller._jobs.fs_jobs.get(feature_store.name)) + controller._launch_step(step, entity=feature_store) @pytest.mark.parametrize( "entity,entity_2", [ pytest.param(ens, ens_2, id="Ensemble_running"), - pytest.param(model, model_2, id="Model_running"), + pytest.param(application, application_2, id="Application_running"), ], ) def test_starting_entity(test_dir, wlmutils, entity, entity_2): - """Test launching a job of Model/Ensemble with same name in completed""" + """Test launching a job of Application/Ensemble with same name in completed""" step_settings = RunSettings("echo") step = MockStep("mock-step", test_dir, step_settings) test_launcher = wlmutils.get_test_launcher() diff --git a/tests/test_dbnode.py b/tests/_legacy/test_dbnode.py similarity index 72% rename from tests/test_dbnode.py rename to tests/_legacy/test_dbnode.py index 04845344cb..7111f5ce5f 100644 --- a/tests/test_dbnode.py +++ b/tests/_legacy/test_dbnode.py @@ -33,28 +33,28 @@ import pytest from smartsim import Experiment -from smartsim.database import Orchestrator -from smartsim.entity.dbnode import DBNode, LaunchedShardData +from smartsim.database import FeatureStore +from smartsim.entity.dbnode import FSNode, LaunchedShardData from smartsim.error.errors import SmartSimError # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a -def test_parse_db_host_error(): - orc = Orchestrator() - orc.entities[0].path = "not/a/path" - # Fail to obtain database hostname +def test_parse_fs_host_error(): + feature_store = FeatureStore() + feature_store.entities[0].path = "not/a/path" + # Fail to obtain feature store hostname with pytest.raises(SmartSimError): - orc.entities[0].host + feature_store.entities[0].host 
-def test_hosts(local_experiment, prepare_db, local_db): - db = prepare_db(local_db).orchestrator - orc = local_experiment.reconnect_orchestrator(db.checkpoint_file) +def test_hosts(local_experiment, prepare_fs, local_fs): + fs = prepare_fs(local_fs).featurestore + feature_store = local_experiment.reconnect_feature_store(fs.checkpoint_file) - hosts = orc.hosts - assert len(hosts) == orc.db_nodes == 1 + hosts = feature_store.hosts + assert len(hosts) == feature_store.fs_nodes == 1 def _random_shard_info(): @@ -81,7 +81,7 @@ def test_launched_shard_info_can_be_serialized(): @pytest.mark.parametrize("limit", [None, 1]) -def test_db_node_can_parse_launched_shard_info(limit): +def test_fs_node_can_parse_launched_shard_info(limit): rand_shards = [_random_shard_info() for _ in range(3)] with io.StringIO(textwrap.dedent("""\ This is some file like str @@ -90,7 +90,7 @@ def test_db_node_can_parse_launched_shard_info(limit): SMARTSIM_ORC_SHARD_INFO: {} ^^^^^^^^^^^^^^^^^^^^^^^ We should be able to parse the serialized - launched db info from this file if the line is + launched fs info from this file if the line is prefixed with this tag. Here are two more for good measure: @@ -99,28 +99,28 @@ def test_db_node_can_parse_launched_shard_info(limit): All other lines should be ignored. 
""").format(*(json.dumps(s.to_dict()) for s in rand_shards))) as stream: - parsed_shards = DBNode._parse_launched_shard_info_from_iterable(stream, limit) + parsed_shards = FSNode._parse_launched_shard_info_from_iterable(stream, limit) if limit is not None: rand_shards = rand_shards[:limit] assert rand_shards == parsed_shards def test_set_host(): - orc = Orchestrator() - orc.entities[0].set_hosts(["host"]) - assert orc.entities[0].host == "host" + feature_store = FeatureStore() + feature_store.entities[0].set_hosts(["host"]) + assert feature_store.entities[0].host == "host" @pytest.mark.parametrize("nodes, mpmd", [[3, False], [3, True], [1, False]]) -def test_db_id_and_name(mpmd, nodes, wlmutils): +def test_fs_id_and_name(mpmd, nodes, wlmutils): if nodes > 1 and wlmutils.get_test_launcher() not in pytest.wlm_options: - pytest.skip(reason="Clustered DB can only be checked on WLMs") - orc = Orchestrator( - db_identifier="test_db", - db_nodes=nodes, + pytest.skip(reason="Clustered fs can only be checked on WLMs") + feature_store = FeatureStore( + fs_identifier="test_fs", + fs_nodes=nodes, single_cmd=mpmd, launcher=wlmutils.get_test_launcher(), ) - for i, node in enumerate(orc.entities): - assert node.name == f"{orc.name}_{i}" - assert node.db_identifier == orc.db_identifier + for i, node in enumerate(feature_store.entities): + assert node.name == f"{feature_store.name}_{i}" + assert node.fs_identifier == feature_store.fs_identifier diff --git a/tests/test_dragon_client.py b/tests/_legacy/test_dragon_client.py similarity index 97% rename from tests/test_dragon_client.py rename to tests/_legacy/test_dragon_client.py index 80257b6107..054f6f0d12 100644 --- a/tests/test_dragon_client.py +++ b/tests/_legacy/test_dragon_client.py @@ -30,7 +30,7 @@ import pytest -from smartsim._core.launcher.step.dragonStep import DragonBatchStep, DragonStep +from smartsim._core.launcher.step.dragon_step import DragonBatchStep, DragonStep from smartsim.settings import DragonRunSettings from 
smartsim.settings.slurmSettings import SbatchSettings @@ -39,8 +39,8 @@ import smartsim._core.entrypoints.dragon_client as dragon_client -from smartsim._core.schemas.dragonRequests import * -from smartsim._core.schemas.dragonResponses import * +from smartsim._core.schemas.dragon_requests import * +from smartsim._core.schemas.dragon_responses import * @pytest.fixture diff --git a/tests/test_dragon_installer.py b/tests/_legacy/test_dragon_installer.py similarity index 94% rename from tests/test_dragon_installer.py rename to tests/_legacy/test_dragon_installer.py index a58d711721..8ce7404c5f 100644 --- a/tests/test_dragon_installer.py +++ b/tests/_legacy/test_dragon_installer.py @@ -434,7 +434,7 @@ def test_install_package_no_wheel(test_dir: str, extraction_dir: pathlib.Path): def test_install_macos(monkeypatch: pytest.MonkeyPatch, extraction_dir: pathlib.Path): - """Verify that installation exits cleanly if installing on unsupported platform""" + """Verify that installation exits cleanly if installing on unsupported platform.""" with monkeypatch.context() as ctx: ctx.setattr(sys, "platform", "darwin") @@ -444,6 +444,36 @@ def test_install_macos(monkeypatch: pytest.MonkeyPatch, extraction_dir: pathlib. assert result == 1 +@pytest.mark.parametrize( + "version, exp_result", + [ + pytest.param("0.9", 2, id="0.9 DNE In Public Repo"), + pytest.param("0.91", 2, id="0.91 DNE In Public Repo"), + pytest.param("0.10", 0, id="0.10 Exists In Public Repo"), + pytest.param("0.19", 2, id="0.19 DNE In Public Repo"), + ], +) +def test_install_specify_asset_version( + monkeypatch: pytest.MonkeyPatch, + extraction_dir: pathlib.Path, + version: str, + exp_result: int, +): + """Verify that installation completes as expected when fed a variety of + version numbers that can or cannot be found on release assets of the + public dragon repository. 
+ + :param extraction_dir: file system path where the dragon package should + be downloaded and extracted + :param version: Dragon version number to attempt to install + :param exp_result: Expected return code from the call to `install_dragon` + """ + request = DragonInstallRequest(extraction_dir, version=version) + + result = install_dragon(request) + assert result == exp_result + + def test_create_dotenv(monkeypatch: pytest.MonkeyPatch, test_dir: str): """Verify that attempting to create a .env file without any existing file or container directory works""" diff --git a/tests/test_dragon_launcher.py b/tests/_legacy/test_dragon_launcher.py similarity index 98% rename from tests/test_dragon_launcher.py rename to tests/_legacy/test_dragon_launcher.py index a894757918..c4f241b24b 100644 --- a/tests/test_dragon_launcher.py +++ b/tests/_legacy/test_dragon_launcher.py @@ -42,17 +42,17 @@ create_dotenv, ) from smartsim._core.config.config import get_config -from smartsim._core.launcher.dragon.dragonLauncher import ( +from smartsim._core.launcher.dragon.dragon_launcher import ( DragonConnector, DragonLauncher, ) -from smartsim._core.launcher.dragon.dragonSockets import ( +from smartsim._core.launcher.dragon.dragon_sockets import ( get_authenticator, get_secure_socket, ) -from smartsim._core.launcher.step.dragonStep import DragonBatchStep, DragonStep -from smartsim._core.schemas.dragonRequests import DragonBootstrapRequest -from smartsim._core.schemas.dragonResponses import ( +from smartsim._core.launcher.step.dragon_step import DragonBatchStep, DragonStep +from smartsim._core.schemas.dragon_requests import DragonBootstrapRequest +from smartsim._core.schemas.dragon_responses import ( DragonHandshakeResponse, DragonRunResponse, ) diff --git a/tests/test_dragon_run_policy.py b/tests/_legacy/test_dragon_run_policy.py similarity index 97% rename from tests/test_dragon_run_policy.py rename to tests/_legacy/test_dragon_run_policy.py index 5e8642c052..14219f9a32 100644 --- 
a/tests/test_dragon_run_policy.py +++ b/tests/_legacy/test_dragon_run_policy.py @@ -28,7 +28,7 @@ import pytest -from smartsim._core.launcher.step.dragonStep import DragonBatchStep, DragonStep +from smartsim._core.launcher.step.dragon_step import DragonBatchStep, DragonStep from smartsim.settings.dragonRunSettings import DragonRunSettings from smartsim.settings.slurmSettings import SbatchSettings @@ -36,7 +36,7 @@ from dragon.infrastructure.policy import Policy import smartsim._core.entrypoints.dragon as drg - from smartsim._core.launcher.dragon.dragonBackend import DragonBackend + from smartsim._core.launcher.dragon.dragon_backend import DragonBackend dragon_loaded = True except: @@ -45,8 +45,8 @@ # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b -from smartsim._core.schemas.dragonRequests import * -from smartsim._core.schemas.dragonResponses import * +from smartsim._core.schemas.dragon_requests import * +from smartsim._core.schemas.dragon_responses import * @pytest.fixture diff --git a/tests/test_dragon_run_request.py b/tests/_legacy/test_dragon_run_request.py similarity index 95% rename from tests/test_dragon_run_request.py rename to tests/_legacy/test_dragon_run_request.py index 62ac572eb2..94e7a5dd97 100644 --- a/tests/test_dragon_run_request.py +++ b/tests/_legacy/test_dragon_run_request.py @@ -33,23 +33,19 @@ import pydantic.error_wrappers import pytest -from smartsim._core.launcher.dragon.pqueue import NodePrioritizer - # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b dragon = pytest.importorskip("dragon") from smartsim._core.config import CONFIG -from smartsim._core.schemas.dragonRequests import * -from smartsim._core.schemas.dragonResponses import * -from smartsim._core.utils.helpers import create_short_id_str -from smartsim.status import TERMINAL_STATUSES, SmartSimStatus - -if t.TYPE_CHECKING: - from smartsim._core.launcher.dragon.dragonBackend import ( - DragonBackend, - 
ProcessGroupInfo, - ) +from smartsim._core.launcher.dragon.dragon_backend import ( + DragonBackend, + ProcessGroupInfo, +) +from smartsim._core.launcher.dragon.pqueue import NodePrioritizer +from smartsim._core.schemas.dragon_requests import * +from smartsim._core.schemas.dragon_responses import * +from smartsim.status import TERMINAL_STATUSES, InvalidJobStatus, JobStatus class GroupStateMock(MagicMock): @@ -105,8 +101,6 @@ def get_mock_backend( ), ) - from smartsim._core.launcher.dragon.dragonBackend import DragonBackend - dragon_backend = DragonBackend(pid=99999) # NOTE: we're manually updating these values due to issue w/mocking namespaces @@ -131,7 +125,7 @@ def set_mock_group_infos( process_mock.configure_mock(**{"returncode": 0}) dragon_mock.configure_mock(**{"native.process.Process.return_value": process_mock}) monkeypatch.setitem(sys.modules, "dragon", dragon_mock) - from smartsim._core.launcher.dragon.dragonBackend import ProcessGroupInfo + from smartsim._core.launcher.dragon.dragon_backend import ProcessGroupInfo running_group = MagicMock(status="Running") error_group = MagicMock(status="Error") @@ -139,7 +133,7 @@ def set_mock_group_infos( group_infos = { "abc123-1": ProcessGroupInfo( - SmartSimStatus.STATUS_RUNNING, + JobStatus.RUNNING, running_group, [123], [], @@ -147,7 +141,7 @@ def set_mock_group_infos( MagicMock(), ), "del999-2": ProcessGroupInfo( - SmartSimStatus.STATUS_CANCELLED, + JobStatus.CANCELLED, error_group, [124], [-9], @@ -155,7 +149,7 @@ def set_mock_group_infos( MagicMock(), ), "c101vz-3": ProcessGroupInfo( - SmartSimStatus.STATUS_COMPLETED, + JobStatus.COMPLETED, MagicMock(), [125, 126], [0], @@ -163,7 +157,7 @@ def set_mock_group_infos( MagicMock(), ), "0ghjk1-4": ProcessGroupInfo( - SmartSimStatus.STATUS_FAILED, + JobStatus.FAILED, error_group, [127], [-1], @@ -171,7 +165,7 @@ def set_mock_group_infos( MagicMock(), ), "ljace0-5": ProcessGroupInfo( - SmartSimStatus.STATUS_NEVER_STARTED, None, [], [], [], None + 
InvalidJobStatus.NEVER_STARTED, None, [], [], [], None ), } @@ -236,7 +230,7 @@ def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None: assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]] assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]] - dragon_backend._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED + dragon_backend._group_infos[step_id].status = JobStatus.CANCELLED dragon_backend._update() assert not dragon_backend._running_steps @@ -264,7 +258,7 @@ def test_deny_run_request(monkeypatch: pytest.MonkeyPatch) -> None: assert run_resp.error_message == "Cannot satisfy request, server is shutting down." step_id = run_resp.step_id - assert dragon_backend.group_infos[step_id].status == SmartSimStatus.STATUS_FAILED + assert dragon_backend.group_infos[step_id].status == JobStatus.FAILED def test_run_request_with_empty_policy(monkeypatch: pytest.MonkeyPatch) -> None: @@ -331,7 +325,7 @@ def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None: assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]] assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]] - dragon_backend._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED + dragon_backend._group_infos[step_id].status = JobStatus.CANCELLED dragon_backend._update() assert not dragon_backend._running_steps @@ -360,7 +354,7 @@ def test_stop_request(monkeypatch: pytest.MonkeyPatch) -> None: running_steps = [ step_id for step_id, group in group_infos.items() - if group.status == SmartSimStatus.STATUS_RUNNING + if group.status == JobStatus.RUNNING ] step_id_to_stop = running_steps[0] @@ -375,10 +369,7 @@ def test_stop_request(monkeypatch: pytest.MonkeyPatch) -> None: dragon_backend._update() assert len(dragon_backend._stop_requests) == 0 - assert ( - dragon_backend._group_infos[step_id_to_stop].status - == SmartSimStatus.STATUS_CANCELLED - ) + assert 
dragon_backend._group_infos[step_id_to_stop].status == JobStatus.CANCELLED assert len(dragon_backend._allocated_hosts) == 0 assert len(dragon_backend._prioritizer.unassigned()) == 3 @@ -409,7 +400,7 @@ def test_shutdown_request( if kill_jobs: for group_info in dragon_backend.group_infos.values(): if not group_info.status in TERMINAL_STATUSES: - group_info.status = SmartSimStatus.STATUS_FAILED + group_info.status = JobStatus.FAILED group_info.return_codes = [-9] group_info.process_group = None group_info.redir_workers = None diff --git a/tests/test_dragon_run_request_nowlm.py b/tests/_legacy/test_dragon_run_request_nowlm.py similarity index 97% rename from tests/test_dragon_run_request_nowlm.py rename to tests/_legacy/test_dragon_run_request_nowlm.py index 3dd7099c89..98f5b706da 100644 --- a/tests/test_dragon_run_request_nowlm.py +++ b/tests/_legacy/test_dragon_run_request_nowlm.py @@ -30,8 +30,8 @@ # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a -from smartsim._core.schemas.dragonRequests import * -from smartsim._core.schemas.dragonResponses import * +from smartsim._core.schemas.dragon_requests import * +from smartsim._core.schemas.dragon_responses import * def test_run_request_with_null_policy(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/tests/test_dragon_runsettings.py b/tests/_legacy/test_dragon_runsettings.py similarity index 100% rename from tests/test_dragon_runsettings.py rename to tests/_legacy/test_dragon_runsettings.py diff --git a/tests/test_dragon_step.py b/tests/_legacy/test_dragon_step.py similarity index 98% rename from tests/test_dragon_step.py rename to tests/_legacy/test_dragon_step.py index f933fb7bc2..3dbdf114ea 100644 --- a/tests/test_dragon_step.py +++ b/tests/_legacy/test_dragon_step.py @@ -32,7 +32,7 @@ import pytest -from smartsim._core.launcher.step.dragonStep import DragonBatchStep, DragonStep +from smartsim._core.launcher.step.dragon_step import DragonBatchStep, DragonStep from 
smartsim.settings import DragonRunSettings from smartsim.settings.pbsSettings import QsubBatchSettings from smartsim.settings.slurmSettings import SbatchSettings @@ -41,8 +41,8 @@ pytestmark = pytest.mark.group_a -from smartsim._core.schemas.dragonRequests import * -from smartsim._core.schemas.dragonResponses import * +from smartsim._core.schemas.dragon_requests import * +from smartsim._core.schemas.dragon_responses import * @pytest.fixture diff --git a/tests/_legacy/test_ensemble.py b/tests/_legacy/test_ensemble.py new file mode 100644 index 0000000000..62c7d8d4f7 --- /dev/null +++ b/tests/_legacy/test_ensemble.py @@ -0,0 +1,306 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +from copy import deepcopy + +import pytest + +from smartsim import Experiment +from smartsim.builders import Ensemble +from smartsim.entity import Application +from smartsim.error import EntityExistsError, SSUnsupportedError, UserStrategyError +from smartsim.settings import RunSettings + +# The tests in this file belong to the slow_tests group +pytestmark = pytest.mark.slow_tests + + +""" +Test ensemble creation + +TODO: test to add +- test batch settings/run_setting combinations and errors +- test replica creation +""" + +# ---- helpers ------------------------------------------------------ + + +def step_values(param_names, param_values, n_applications=0): + permutations = [] + for p in zip(*param_values): + permutations.append(dict(zip(param_names, p))) + return permutations + + +# bad permutation strategy that doesn't return +# a list of dictionaries +def bad_strategy(names, values, n_applications=0): + return -1 + + +# test bad perm strategy that returns a list but of lists +# not dictionaries +def bad_strategy_2(names, values, n_applications=0): + return [values] + + +rs = RunSettings("python", exe_args="sleep.py") + +# ----- Test param generation ---------------------------------------- + + +def test_all_perm(): + """Test all permutation strategy""" + params = {"h": [5, 6]} + ensemble = Ensemble("all_perm", params, run_settings=rs, perm_strat="all_perm") + assert len(ensemble) == 2 + assert ensemble.entities[0].params["h"] == "5" + 
assert ensemble.entities[1].params["h"] == "6" + + +def test_step(): + """Test step strategy""" + params = {"h": [5, 6], "g": [7, 8]} + ensemble = Ensemble("step", params, run_settings=rs, perm_strat="step") + assert len(ensemble) == 2 + + application_1_params = {"h": "5", "g": "7"} + assert ensemble.entities[0].params == application_1_params + + application_2_params = {"h": "6", "g": "8"} + assert ensemble.entities[1].params == application_2_params + + +def test_random(): + """Test random strategy""" + random_ints = [4, 5, 6, 7, 8] + params = {"h": random_ints} + ensemble = Ensemble( + "random_test", + params, + run_settings=rs, + perm_strat="random", + n_applications=len(random_ints), + ) + assert len(ensemble) == len(random_ints) + assigned_params = [m.params["h"] for m in ensemble.entities] + assert all([int(x) in random_ints for x in assigned_params]) + + ensemble = Ensemble( + "random_test", + params, + run_settings=rs, + perm_strat="random", + n_applications=len(random_ints) - 1, + ) + assert len(ensemble) == len(random_ints) - 1 + assigned_params = [m.params["h"] for m in ensemble.entities] + assert all([int(x) in random_ints for x in assigned_params]) + + +def test_user_strategy(): + """Test a user provided strategy""" + params = {"h": [5, 6], "g": [7, 8]} + ensemble = Ensemble("step", params, run_settings=rs, perm_strat=step_values) + assert len(ensemble) == 2 + + application_1_params = {"h": "5", "g": "7"} + assert ensemble.entities[0].params == application_1_params + + application_2_params = {"h": "6", "g": "8"} + assert ensemble.entities[1].params == application_2_params + + +# ----- Application arguments ------------------------------------- + + +def test_arg_params(): + """Test parameterized exe arguments""" + params = {"H": [5, 6], "g_param": ["a", "b"]} + + # Copy rs to avoid modifying referenced object + rs_copy = deepcopy(rs) + rs_orig_args = rs_copy.exe_args + ensemble = Ensemble( + "step", + params=params, + params_as_args=list(params.keys()), 
+ run_settings=rs_copy, + perm_strat="step", + ) + assert len(ensemble) == 2 + + exe_args_0 = rs_orig_args + ["-H", "5", "--g_param=a"] + assert ensemble.entities[0].run_settings.exe_args == exe_args_0 + + exe_args_1 = rs_orig_args + ["-H", "6", "--g_param=b"] + assert ensemble.entities[1].run_settings.exe_args == exe_args_1 + + +def test_arg_and_application_params_step(): + """Test parameterized exe arguments combined with + application parameters and step strategy + """ + params = {"H": [5, 6], "g_param": ["a", "b"], "h": [5, 6], "g": [7, 8]} + + # Copy rs to avoid modifying referenced object + rs_copy = deepcopy(rs) + rs_orig_args = rs_copy.exe_args + ensemble = Ensemble( + "step", + params, + params_as_args=["H", "g_param"], + run_settings=rs_copy, + perm_strat="step", + ) + assert len(ensemble) == 2 + + exe_args_0 = rs_orig_args + ["-H", "5", "--g_param=a"] + assert ensemble.entities[0].run_settings.exe_args == exe_args_0 + + exe_args_1 = rs_orig_args + ["-H", "6", "--g_param=b"] + assert ensemble.entities[1].run_settings.exe_args == exe_args_1 + + application_1_params = {"H": "5", "g_param": "a", "h": "5", "g": "7"} + assert ensemble.entities[0].params == application_1_params + + application_2_params = {"H": "6", "g_param": "b", "h": "6", "g": "8"} + assert ensemble.entities[1].params == application_2_params + + +def test_arg_and_application_params_all_perms(): + """Test parameterized exe arguments combined with + application parameters and all_perm strategy + """ + params = {"h": [5, 6], "g_param": ["a", "b"]} + + # Copy rs to avoid modifying referenced object + rs_copy = deepcopy(rs) + rs_orig_args = rs_copy.exe_args + ensemble = Ensemble( + "step", + params, + params_as_args=["g_param"], + run_settings=rs_copy, + perm_strat="all_perm", + ) + assert len(ensemble) == 4 + + exe_args_0 = rs_orig_args + ["--g_param=a"] + assert ensemble.entities[0].run_settings.exe_args == exe_args_0 + assert ensemble.entities[2].run_settings.exe_args == exe_args_0 + + 
exe_args_1 = rs_orig_args + ["--g_param=b"] + assert ensemble.entities[1].run_settings.exe_args == exe_args_1 + assert ensemble.entities[3].run_settings.exe_args == exe_args_1 + + application_0_params = {"g_param": "a", "h": "5"} + assert ensemble.entities[0].params == application_0_params + application_1_params = {"g_param": "b", "h": "5"} + assert ensemble.entities[1].params == application_1_params + application_2_params = {"g_param": "a", "h": "6"} + assert ensemble.entities[2].params == application_2_params + application_3_params = {"g_param": "b", "h": "6"} + assert ensemble.entities[3].params == application_3_params + + +# ----- Error Handling -------------------------------------- + + +# unknown permuation strategy +def test_unknown_perm_strat(): + bad_strat = "not-a-strategy" + with pytest.raises(SSUnsupportedError): + e = Ensemble("ensemble", {}, run_settings=rs, perm_strat=bad_strat) + + +def test_bad_perm_strat(): + params = {"h": [2, 3]} + with pytest.raises(UserStrategyError): + e = Ensemble("ensemble", params, run_settings=rs, perm_strat=bad_strategy) + + +def test_bad_perm_strat_2(): + params = {"h": [2, 3]} + with pytest.raises(UserStrategyError): + e = Ensemble("ensemble", params, run_settings=rs, perm_strat=bad_strategy_2) + + +# bad argument type in params +def test_incorrect_param_type(): + # can either be a list, str, or int + params = {"h": {"h": [5]}} + with pytest.raises(TypeError): + e = Ensemble("ensemble", params, run_settings=rs) + + +def test_add_application_type(): + params = {"h": 5} + e = Ensemble("ensemble", params, run_settings=rs) + with pytest.raises(TypeError): + # should be a Application not string + e.add_application("application") + + +def test_add_existing_application(): + params_1 = {"h": 5} + params_2 = {"z": 6} + application_1 = Application("identical_name", params_1, "", rs) + application_2 = Application("identical_name", params_2, "", rs) + e = Ensemble("ensemble", params_1, run_settings=rs) + 
e.add_application(application_1) + with pytest.raises(EntityExistsError): + e.add_application(application_2) + + +# ----- Other -------------------------------------- + + +def test_applications_property(): + params = {"h": [5, 6, 7, 8]} + e = Ensemble("test", params, run_settings=rs) + applications = e.applications + assert applications == [application for application in e] + + +def test_key_prefixing(): + params_1 = {"h": [5, 6, 7, 8]} + params_2 = {"z": 6} + e = Ensemble("test", params_1, run_settings=rs) + application = Application("application", params_2, "", rs) + e.add_application(application) + assert e.query_key_prefixing() == False + e.enable_key_prefixing() + assert e.query_key_prefixing() == True + + +def test_ensemble_type(): + exp = Experiment("name") + ens_settings = RunSettings("python") + ensemble = exp.create_ensemble("name", replicas=4, run_settings=ens_settings) + assert ensemble.type == "Ensemble" diff --git a/tests/test_entitylist.py b/tests/_legacy/test_entitylist.py similarity index 100% rename from tests/test_entitylist.py rename to tests/_legacy/test_entitylist.py diff --git a/tests/_legacy/test_experiment.py b/tests/_legacy/test_experiment.py new file mode 100644 index 0000000000..70ae5f1efc --- /dev/null +++ b/tests/_legacy/test_experiment.py @@ -0,0 +1,372 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import os +import os.path as osp +import pathlib +import shutil +import typing as t + +import pytest + +from smartsim import Experiment +from smartsim._core.config import CONFIG +from smartsim._core.config.config import Config +from smartsim._core.utils import serialize +from smartsim.database import FeatureStore +from smartsim.entity import Application +from smartsim.error import SmartSimError +from smartsim.error.errors import SSUnsupportedError +from smartsim.settings import RunSettings +from smartsim.status import InvalidJobStatus + +if t.TYPE_CHECKING: + import conftest + + +# The tests in this file belong to the slow_tests group +pytestmark = pytest.mark.slow_tests + + +def test_application_prefix(test_dir: str) -> None: + exp_name = "test_prefix" + exp = Experiment(exp_name) + + application = exp.create_application( + "application", + path=test_dir, + run_settings=RunSettings("python"), + enable_key_prefixing=True, + ) + assert application._key_prefixing_enabled == True + + +def test_application_no_name(): + exp = Experiment("test_application_no_name") + with pytest.raises(AttributeError): + _ = exp.create_application(name=None, run_settings=RunSettings("python")) + + +def 
test_ensemble_no_name(): + exp = Experiment("test_ensemble_no_name") + with pytest.raises(AttributeError): + _ = exp.create_ensemble( + name=None, run_settings=RunSettings("python"), replicas=2 + ) + + +def test_bad_exp_path() -> None: + with pytest.raises(NotADirectoryError): + exp = Experiment("test", "not-a-directory") + + +def test_type_exp_path() -> None: + with pytest.raises(TypeError): + exp = Experiment("test", ["this-is-a-list-dummy"]) + + +def test_stop_type() -> None: + """Wrong argument type given to stop""" + exp = Experiment("name") + with pytest.raises(TypeError): + exp.stop("application") + + +def test_finished_new_application() -> None: + # finished should fail as this application hasn't been + # launched yet. + + application = Application("name", {}, "./", RunSettings("python")) + exp = Experiment("test") + with pytest.raises(ValueError): + exp.finished(application) + + +def test_status_typeerror() -> None: + exp = Experiment("test") + with pytest.raises(TypeError): + exp.get_status([]) + + +def test_status_pre_launch() -> None: + application = Application("name", {}, "./", RunSettings("python")) + exp = Experiment("test") + assert exp.get_status(application)[0] == InvalidJobStatus.NEVER_STARTED + + +def test_bad_ensemble_init_no_rs(test_dir: str) -> None: + """params supplied without run settings""" + exp = Experiment("test", exp_path=test_dir) + with pytest.raises(SmartSimError): + exp.create_ensemble("name", {"param1": 1}) + + +def test_bad_ensemble_init_no_params(test_dir: str) -> None: + """params supplied without run settings""" + exp = Experiment("test", exp_path=test_dir) + with pytest.raises(SmartSimError): + exp.create_ensemble("name", run_settings=RunSettings("python")) + + +def test_bad_ensemble_init_no_rs_bs(test_dir: str) -> None: + """ensemble init without run settings or batch settings""" + exp = Experiment("test", exp_path=test_dir) + with pytest.raises(SmartSimError): + exp.create_ensemble("name") + + +def 
test_stop_entity(test_dir: str) -> None: + exp_name = "test_stop_entity" + exp = Experiment(exp_name, exp_path=test_dir) + m = exp.create_application( + "application", path=test_dir, run_settings=RunSettings("sleep", "5") + ) + exp.start(m, block=False) + assert exp.finished(m) == False + exp.stop(m) + assert exp.finished(m) == True + + +def test_poll(test_dir: str) -> None: + # Ensure that a SmartSimError is not raised + exp_name = "test_exp_poll" + exp = Experiment(exp_name, exp_path=test_dir) + application = exp.create_application( + "application", path=test_dir, run_settings=RunSettings("sleep", "5") + ) + exp.start(application, block=False) + exp.poll(interval=1) + exp.stop(application) + + +def test_summary(test_dir: str) -> None: + exp_name = "test_exp_summary" + exp = Experiment(exp_name, exp_path=test_dir) + m = exp.create_application( + "application", path=test_dir, run_settings=RunSettings("echo", "Hello") + ) + exp.start(m) + summary_str = exp.summary(style="plain") + print(summary_str) + + summary_lines = summary_str.split("\n") + assert 2 == len(summary_lines) + + headers, values = [s.split() for s in summary_lines] + headers = ["Index"] + headers + + row = dict(zip(headers, values)) + assert m.name == row["Name"] + assert m.type == row["Entity-Type"] + assert 0 == int(row["RunID"]) + assert 0 == int(row["Returncode"]) + + +def test_launcher_detection( + wlmutils: "conftest.WLMUtils", monkeypatch: pytest.MonkeyPatch +) -> None: + if wlmutils.get_test_launcher() == "pals": + pytest.skip(reason="Launcher detection cannot currently detect pbs vs pals") + if wlmutils.get_test_launcher() == "local": + monkeypatch.setenv("PATH", "") # Remove all WLMs from PATH + if wlmutils.get_test_launcher() == "dragon": + pytest.skip(reason="Launcher detection cannot currently detect dragon") + + exp = Experiment("test-launcher-detection", launcher="auto") + + assert exp._launcher == wlmutils.get_test_launcher() + + +def test_enable_disable_telemetry( + monkeypatch: 
pytest.MonkeyPatch, test_dir: str, config: Config +) -> None: + # Global telemetry defaults to `on` and can be modified by + # setting the value of env var SMARTSIM_FLAG_TELEMETRY to 0/1 + monkeypatch.setattr(os, "environ", {}) + exp = Experiment("my-exp", exp_path=test_dir) + exp.telemetry.enable() + assert exp.telemetry.is_enabled + + exp.telemetry.disable() + assert not exp.telemetry.is_enabled + + exp.telemetry.enable() + assert exp.telemetry.is_enabled + + exp.telemetry.disable() + assert not exp.telemetry.is_enabled + + exp.start() + mani_path = ( + pathlib.Path(test_dir) / config.telemetry_subdir / serialize.MANIFEST_FILENAME + ) + assert mani_path.exists() + + +def test_telemetry_default( + monkeypatch: pytest.MonkeyPatch, test_dir: str, config: Config +) -> None: + """Ensure the default values for telemetry configuration match expectation + that experiment telemetry is on""" + + # If env var related to telemetry doesn't exist, experiment should default to True + monkeypatch.setattr(os, "environ", {}) + exp = Experiment("my-exp", exp_path=test_dir) + assert exp.telemetry.is_enabled + + # If telemetry disabled in env, should get False + monkeypatch.setenv("SMARTSIM_FLAG_TELEMETRY", "0") + exp = Experiment("my-exp", exp_path=test_dir) + assert not exp.telemetry.is_enabled + + # If telemetry enabled in env, should get True + monkeypatch.setenv("SMARTSIM_FLAG_TELEMETRY", "1") + exp = Experiment("my-exp", exp_path=test_dir) + assert exp.telemetry.is_enabled + + +def test_error_on_cobalt() -> None: + with pytest.raises(SSUnsupportedError): + exp = Experiment("cobalt_exp", launcher="cobalt") + + +def test_default_feature_store_path( + monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" +) -> None: + """Ensure the default file structure is created for FeatureStore""" + + exp_name = "default-feature-store-path" + exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) + monkeypatch.setattr(exp._control, "start", 
lambda *a, **kw: ...) + db = exp.create_feature_store( + port=wlmutils.get_test_port(), interface=wlmutils.get_test_interface() + ) + exp.start(db) + feature_store_path = pathlib.Path(test_dir) / db.name + assert feature_store_path.exists() + assert db.path == str(feature_store_path) + + +def test_default_application_path( + monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" +) -> None: + """Ensure the default file structure is created for Application""" + + exp_name = "default-application-path" + exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) + monkeypatch.setattr(exp._control, "start", lambda *a, **kw: ...) + settings = exp.create_run_settings(exe="echo", exe_args="hello") + application = exp.create_application(name="application_name", run_settings=settings) + exp.start(application) + application_path = pathlib.Path(test_dir) / application.name + assert application_path.exists() + assert application.path == str(application_path) + + +def test_default_ensemble_path( + monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" +) -> None: + """Ensure the default file structure is created for Ensemble""" + + exp_name = "default-ensemble-path" + exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) + monkeypatch.setattr(exp._control, "start", lambda *a, **kw: ...) 
+ settings = exp.create_run_settings(exe="echo", exe_args="hello") + ensemble = exp.create_ensemble( + name="ensemble_name", run_settings=settings, replicas=2 + ) + exp.start(ensemble) + ensemble_path = pathlib.Path(test_dir) / ensemble.name + assert ensemble_path.exists() + assert ensemble.path == str(ensemble_path) + for member in ensemble.applications: + member_path = ensemble_path / member.name + assert member_path.exists() + assert member.path == str(ensemble_path / member.name) + + +def test_user_feature_store_path( + monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" +) -> None: + """Ensure a relative path is used to created FeatureStore folder""" + + exp_name = "default-feature-store-path" + exp = Experiment(exp_name, launcher="local", exp_path=test_dir) + monkeypatch.setattr(exp._control, "start", lambda *a, **kw: ...) + db = exp.create_feature_store( + port=wlmutils.get_test_port(), + interface=wlmutils.get_test_interface(), + path="./testing_folder1234", + ) + exp.start(db) + feature_store_path = pathlib.Path(osp.abspath("./testing_folder1234")) + assert feature_store_path.exists() + assert db.path == str(feature_store_path) + shutil.rmtree(feature_store_path) + + +def test_default_application_with_path( + monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" +) -> None: + """Ensure a relative path is used to created Application folder""" + + exp_name = "default-ensemble-path" + exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) + monkeypatch.setattr(exp._control, "start", lambda *a, **kw: ...) 
+ settings = exp.create_run_settings(exe="echo", exe_args="hello") + application = exp.create_application( + name="application_name", run_settings=settings, path="./testing_folder1234" + ) + exp.start(application) + application_path = pathlib.Path(osp.abspath("./testing_folder1234")) + assert application_path.exists() + assert application.path == str(application_path) + shutil.rmtree(application_path) + + +def test_default_ensemble_with_path( + monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" +) -> None: + """Ensure a relative path is used to created Ensemble folder""" + + exp_name = "default-ensemble-path" + exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) + monkeypatch.setattr(exp._control, "start", lambda *a, **kw: ...) + settings = exp.create_run_settings(exe="echo", exe_args="hello") + ensemble = exp.create_ensemble( + name="ensemble_name", + run_settings=settings, + path="./testing_folder1234", + replicas=2, + ) + exp.start(ensemble) + ensemble_path = pathlib.Path(osp.abspath("./testing_folder1234")) + assert ensemble_path.exists() + assert ensemble.path == str(ensemble_path) + for member in ensemble.applications: + member_path = ensemble_path / member.name + assert member_path.exists() + assert member.path == str(member_path) + shutil.rmtree(ensemble_path) diff --git a/tests/test_fixtures.py b/tests/_legacy/test_fixtures.py similarity index 70% rename from tests/test_fixtures.py rename to tests/_legacy/test_fixtures.py index ea753374e7..15823e1581 100644 --- a/tests/test_fixtures.py +++ b/tests/_legacy/test_fixtures.py @@ -29,7 +29,7 @@ import pytest from smartsim import Experiment -from smartsim.database import Orchestrator +from smartsim.database import FeatureStore from smartsim.error import SmartSimError from smartsim.error.errors import SSUnsupportedError @@ -37,20 +37,20 @@ pytestmark = pytest.mark.group_a -def test_db_fixtures(local_experiment, local_db, prepare_db): - db = 
prepare_db(local_db).orchestrator - local_experiment.reconnect_orchestrator(db.checkpoint_file) - assert db.is_active() - local_experiment.stop(db) +def test_db_fixtures(local_experiment, local_fs, prepare_fs): + fs = prepare_fs(local_fs).featurestore + local_experiment.reconnect_feature_store(fs.checkpoint_file) + assert fs.is_active() + local_experiment.stop(fs) -def test_create_new_db_fixture_if_stopped(local_experiment, local_db, prepare_db): +def test_create_new_fs_fixture_if_stopped(local_experiment, local_fs, prepare_fs): # Run this twice to make sure that there is a stopped database - output = prepare_db(local_db) - local_experiment.reconnect_orchestrator(output.orchestrator.checkpoint_file) - local_experiment.stop(output.orchestrator) - - output = prepare_db(local_db) - assert output.new_db - local_experiment.reconnect_orchestrator(output.orchestrator.checkpoint_file) - assert output.orchestrator.is_active() + output = prepare_fs(local_fs) + local_experiment.reconnect_feature_store(output.featurestore.checkpoint_file) + local_experiment.stop(output.featurestore) + + output = prepare_fs(local_fs) + assert output.new_fs + local_experiment.reconnect_feature_store(output.featurestore.checkpoint_file) + assert output.featurestore.is_active() diff --git a/tests/_legacy/test_generator.py b/tests/_legacy/test_generator.py new file mode 100644 index 0000000000..c3bfcad648 --- /dev/null +++ b/tests/_legacy/test_generator.py @@ -0,0 +1,381 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import filecmp +from os import path as osp + +import pytest +from tabulate import tabulate + +from smartsim import Experiment +from smartsim._core.generation import Generator +from smartsim.database import FeatureStore +from smartsim.settings import RunSettings + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + + +rs = RunSettings("python", exe_args="sleep.py") + + +""" +Test the generation of files and input data for an experiment + +TODO + - test lists of inputs for each file type + - test empty directories + - test re-generation + +""" + + +def get_gen_file(fileutils, filename): + return fileutils.get_test_conf_path(osp.join("generator_files", filename)) + + +def test_ensemble(fileutils, test_dir): + exp = Experiment("gen-test", launcher="local") + + gen = Generator(test_dir) + params = {"THERMO": [10, 20, 30], "STEPS": [10, 20, 30]} + ensemble = exp.create_ensemble("test", params=params, run_settings=rs) + + config = 
get_gen_file(fileutils, "in.atm") + ensemble.attach_generator_files(to_configure=config) + gen.generate_experiment(ensemble) + + assert len(ensemble) == 9 + assert osp.isdir(osp.join(test_dir, "test")) + for i in range(9): + assert osp.isdir(osp.join(test_dir, "test/test_" + str(i))) + + +def test_ensemble_overwrite(fileutils, test_dir): + exp = Experiment("gen-test-overwrite", launcher="local") + + gen = Generator(test_dir, overwrite=True) + + params = {"THERMO": [10, 20, 30], "STEPS": [10, 20, 30]} + ensemble = exp.create_ensemble("test", params=params, run_settings=rs) + + config = get_gen_file(fileutils, "in.atm") + ensemble.attach_generator_files(to_configure=[config]) + gen.generate_experiment(ensemble) + + # re generate without overwrite + config = get_gen_file(fileutils, "in.atm") + ensemble.attach_generator_files(to_configure=[config]) + gen.generate_experiment(ensemble) + + assert len(ensemble) == 9 + assert osp.isdir(osp.join(test_dir, "test")) + for i in range(9): + assert osp.isdir(osp.join(test_dir, "test/test_" + str(i))) + + +def test_ensemble_overwrite_error(fileutils, test_dir): + exp = Experiment("gen-test-overwrite-error", launcher="local") + + gen = Generator(test_dir) + + params = {"THERMO": [10, 20, 30], "STEPS": [10, 20, 30]} + ensemble = exp.create_ensemble("test", params=params, run_settings=rs) + + config = get_gen_file(fileutils, "in.atm") + ensemble.attach_generator_files(to_configure=[config]) + gen.generate_experiment(ensemble) + + # re generate without overwrite + config = get_gen_file(fileutils, "in.atm") + ensemble.attach_generator_files(to_configure=[config]) + with pytest.raises(FileExistsError): + gen.generate_experiment(ensemble) + + +def test_full_exp(fileutils, test_dir, wlmutils): + exp = Experiment("gen-test", test_dir, launcher="local") + + application = exp.create_application("application", run_settings=rs) + script = fileutils.get_test_conf_path("sleep.py") + application.attach_generator_files(to_copy=script) + + 
feature_store = FeatureStore(wlmutils.get_test_port()) + params = {"THERMO": [10, 20, 30], "STEPS": [10, 20, 30]} + ensemble = exp.create_ensemble("test_ens", params=params, run_settings=rs) + + config = get_gen_file(fileutils, "in.atm") + ensemble.attach_generator_files(to_configure=config) + exp.generate(feature_store, ensemble, application) + + # test for ensemble + assert osp.isdir(osp.join(test_dir, "test_ens/")) + for i in range(9): + assert osp.isdir(osp.join(test_dir, "test_ens/test_ens_" + str(i))) + + # test for feature_store dir + assert osp.isdir(osp.join(test_dir, feature_store.name)) + + # test for application file + assert osp.isdir(osp.join(test_dir, "application")) + assert osp.isfile(osp.join(test_dir, "application/sleep.py")) + + +def test_dir_files(fileutils, test_dir): + """test the generation of applications with files that + are directories with subdirectories and files + """ + + exp = Experiment("gen-test", test_dir, launcher="local") + + params = {"THERMO": [10, 20, 30], "STEPS": [10, 20, 30]} + ensemble = exp.create_ensemble("dir_test", params=params, run_settings=rs) + conf_dir = get_gen_file(fileutils, "test_dir") + ensemble.attach_generator_files(to_configure=conf_dir) + + exp.generate(ensemble, tag="@") + + assert osp.isdir(osp.join(test_dir, "dir_test/")) + for i in range(9): + application_path = osp.join(test_dir, "dir_test/dir_test_" + str(i)) + assert osp.isdir(application_path) + assert osp.isdir(osp.join(application_path, "test_dir_1")) + assert osp.isfile(osp.join(application_path, "test.in")) + + +def test_print_files(fileutils, test_dir, capsys): + """Test the stdout print of files attached to an ensemble""" + + exp = Experiment("print-attached-files-test", test_dir, launcher="local") + + ensemble = exp.create_ensemble("dir_test", replicas=1, run_settings=rs) + ensemble.entities = [] + + ensemble.print_attached_files() + captured = capsys.readouterr() + assert captured.out == "The ensemble is empty, no files to show.\n" + + 
params = {"THERMO": [10, 20], "STEPS": [20, 30]} + ensemble = exp.create_ensemble("dir_test", params=params, run_settings=rs) + gen_dir = get_gen_file(fileutils, "test_dir") + symlink_dir = get_gen_file(fileutils, "to_symlink_dir") + copy_dir = get_gen_file(fileutils, "to_copy_dir") + + ensemble.print_attached_files() + captured = capsys.readouterr() + expected_out = ( + tabulate( + [ + [application.name, "No file attached to this application."] + for application in ensemble.applications + ], + headers=["Application name", "Files"], + tablefmt="grid", + ) + + "\n" + ) + + assert captured.out == expected_out + + ensemble.attach_generator_files() + ensemble.print_attached_files() + captured = capsys.readouterr() + expected_out = ( + tabulate( + [ + [application.name, "No file attached to this entity."] + for application in ensemble.applications + ], + headers=["Application name", "Files"], + tablefmt="grid", + ) + + "\n" + ) + assert captured.out == expected_out + + ensemble.attach_generator_files( + to_configure=[gen_dir, copy_dir], to_copy=copy_dir, to_symlink=symlink_dir + ) + + expected_out = tabulate( + [ + ["Copy", copy_dir], + ["Symlink", symlink_dir], + ["Configure", f"{gen_dir}\n{copy_dir}"], + ], + headers=["Strategy", "Files"], + tablefmt="grid", + ) + + assert all( + str(application.files) == expected_out for application in ensemble.applications + ) + + expected_out_multi = ( + tabulate( + [[application.name, expected_out] for application in ensemble.applications], + headers=["Application name", "Files"], + tablefmt="grid", + ) + + "\n" + ) + ensemble.print_attached_files() + + captured = capsys.readouterr() + assert captured.out == expected_out_multi + + +def test_multiple_tags(fileutils, test_dir): + """Test substitution of multiple tagged parameters on same line""" + + exp = Experiment("test-multiple-tags", test_dir) + application_params = {"port": 6379, "password": "unbreakable_password"} + application_settings = RunSettings("bash", 
"multi_tags_template.sh") + parameterized_application = exp.create_application( + "multi-tags", run_settings=application_settings, params=application_params + ) + config = get_gen_file(fileutils, "multi_tags_template.sh") + parameterized_application.attach_generator_files(to_configure=[config]) + exp.generate(parameterized_application, overwrite=True) + exp.start(parameterized_application, block=True) + + with open(osp.join(parameterized_application.path, "multi-tags.out")) as f: + log_content = f.read() + assert "My two parameters are 6379 and unbreakable_password, OK?" in log_content + + +def test_generation_log(fileutils, test_dir): + """Test that an error is issued when a tag is unused and make_fatal is True""" + + exp = Experiment("gen-log-test", test_dir, launcher="local") + + params = {"THERMO": [10, 20], "STEPS": [10, 20]} + ensemble = exp.create_ensemble("dir_test", params=params, run_settings=rs) + conf_file = get_gen_file(fileutils, "in.atm") + ensemble.attach_generator_files(to_configure=conf_file) + + def not_header(line): + """you can add other general checks in here""" + return not line.startswith("Generation start date and time:") + + exp.generate(ensemble, verbose=True) + + log_file = osp.join(test_dir, "smartsim_params.txt") + ground_truth = get_gen_file( + fileutils, osp.join("log_params", "smartsim_params.txt") + ) + + with open(log_file) as f1, open(ground_truth) as f2: + assert not not_header(f1.readline()) + f1 = filter(not_header, f1) + f2 = filter(not_header, f2) + assert all(x == y for x, y in zip(f1, f2)) + + for entity in ensemble: + assert filecmp.cmp( + osp.join(entity.path, "smartsim_params.txt"), + get_gen_file( + fileutils, + osp.join("log_params", "dir_test", entity.name, "smartsim_params.txt"), + ), + ) + + +def test_config_dir(fileutils, test_dir): + """Test the generation and configuration of applications with + tagged files that are directories with subdirectories and files + """ + exp = Experiment("config-dir", 
launcher="local") + + gen = Generator(test_dir) + + params = {"PARAM0": [0, 1], "PARAM1": [2, 3]} + ensemble = exp.create_ensemble("test", params=params, run_settings=rs) + + config = get_gen_file(fileutils, "tag_dir_template") + ensemble.attach_generator_files(to_configure=config) + gen.generate_experiment(ensemble) + + assert osp.isdir(osp.join(test_dir, "test")) + + def _check_generated(test_num, param_0, param_1): + conf_test_dir = osp.join(test_dir, "test", f"test_{test_num}") + assert osp.isdir(conf_test_dir) + assert osp.isdir(osp.join(conf_test_dir, "nested_0")) + assert osp.isdir(osp.join(conf_test_dir, "nested_1")) + + with open(osp.join(conf_test_dir, "nested_0", "tagged_0.sh")) as f: + line = f.readline() + assert line.strip() == f'echo "Hello with parameter 0 = {param_0}"' + + with open(osp.join(conf_test_dir, "nested_1", "tagged_1.sh")) as f: + line = f.readline() + assert line.strip() == f'echo "Hello with parameter 1 = {param_1}"' + + _check_generated(0, 0, 2) + _check_generated(1, 0, 3) + _check_generated(2, 1, 2) + _check_generated(3, 1, 3) + + +def test_no_gen_if_file_not_exist(fileutils): + """Test that generation of file with non-existant config + raises a FileNotFound exception + """ + exp = Experiment("file-not-found", launcher="local") + ensemble = exp.create_ensemble("test", params={"P": [0, 1]}, run_settings=rs) + config = get_gen_file(fileutils, "path_not_exist") + with pytest.raises(FileNotFoundError): + ensemble.attach_generator_files(to_configure=config) + + +def test_no_gen_if_symlink_to_dir(fileutils): + """Test that when configuring a directory containing a symlink + a ValueError exception is raised to prevent circular file + structure configuration + """ + exp = Experiment("circular-config-files", launcher="local") + ensemble = exp.create_ensemble("test", params={"P": [0, 1]}, run_settings=rs) + config = get_gen_file(fileutils, "circular_config") + with pytest.raises(ValueError): + 
ensemble.attach_generator_files(to_configure=config) + + +def test_no_file_overwrite(): + exp = Experiment("test_no_file_overwrite", launcher="local") + ensemble = exp.create_ensemble("test", params={"P": [0, 1]}, run_settings=rs) + with pytest.raises(ValueError): + ensemble.attach_generator_files( + to_configure=["/normal/file.txt", "/path/to/smartsim_params.txt"] + ) + with pytest.raises(ValueError): + ensemble.attach_generator_files( + to_symlink=["/normal/file.txt", "/path/to/smartsim_params.txt"] + ) + with pytest.raises(ValueError): + ensemble.attach_generator_files( + to_copy=["/normal/file.txt", "/path/to/smartsim_params.txt"] + ) diff --git a/tests/test_helpers.py b/tests/_legacy/test_helpers.py similarity index 88% rename from tests/test_helpers.py rename to tests/_legacy/test_helpers.py index 523ed7191c..7b453905cb 100644 --- a/tests/test_helpers.py +++ b/tests/_legacy/test_helpers.py @@ -30,12 +30,32 @@ import pytest from smartsim._core.utils import helpers -from smartsim._core.utils.helpers import cat_arg_and_value +from smartsim._core.utils.helpers import cat_arg_and_value, unpack +from smartsim.entity.application import Application +from smartsim.launchable.job import Job +from smartsim.settings.launch_settings import LaunchSettings # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a +def test_unpack_iterates_over_nested_jobs_in_expected_order(wlmutils): + launch_settings = LaunchSettings(wlmutils.get_test_launcher()) + app = Application("app_name", exe="python") + job_1 = Job(app, launch_settings) + job_2 = Job(app, launch_settings) + job_3 = Job(app, launch_settings) + job_4 = Job(app, launch_settings) + job_5 = Job(app, launch_settings) + + assert ( + [job_1, job_2, job_3, job_4, job_5] + == list(unpack([job_1, [job_2, job_3], job_4, [job_5]])) + == list(unpack([job_1, job_2, [job_3, job_4], job_5])) + == list(unpack([job_1, [job_2, [job_3, job_4], job_5]])) + ) + + def test_double_dash_concat(): result = 
cat_arg_and_value("--foo", "FOO") assert result == "--foo=FOO" diff --git a/tests/test_indirect.py b/tests/_legacy/test_indirect.py similarity index 99% rename from tests/test_indirect.py rename to tests/_legacy/test_indirect.py index 8143029689..7766b5825c 100644 --- a/tests/test_indirect.py +++ b/tests/_legacy/test_indirect.py @@ -54,7 +54,7 @@ [ pytest.param("indirect.py", {"+name", "+command", "+entity_type", "+telemetry_dir", "+working_dir"}, id="no args"), pytest.param("indirect.py -c echo +entity_type ttt +telemetry_dir ddd +output_file ooo +working_dir www +error_file eee", {"+command"}, id="cmd typo"), - pytest.param("indirect.py -t orchestrator +command ccc +telemetry_dir ddd +output_file ooo +working_dir www +error_file eee", {"+entity_type"}, id="etype typo"), + pytest.param("indirect.py -t featurestore +command ccc +telemetry_dir ddd +output_file ooo +working_dir www +error_file eee", {"+entity_type"}, id="etype typo"), pytest.param("indirect.py -d /foo/bar +entity_type ttt +command ccc +output_file ooo +working_dir www +error_file eee", {"+telemetry_dir"}, id="dir typo"), pytest.param("indirect.py +entity_type ttt +telemetry_dir ddd +output_file ooo +working_dir www +error_file eee", {"+command"}, id="no cmd"), pytest.param("indirect.py +command ccc +telemetry_dir ddd +output_file ooo +working_dir www +error_file eee", {"+entity_type"}, id="no etype"), diff --git a/tests/test_interrupt.py b/tests/_legacy/test_interrupt.py similarity index 84% rename from tests/test_interrupt.py rename to tests/_legacy/test_interrupt.py index c38ae02251..1b134a8848 100644 --- a/tests/test_interrupt.py +++ b/tests/_legacy/test_interrupt.py @@ -46,15 +46,15 @@ def keyboard_interrupt(pid): def test_interrupt_blocked_jobs(test_dir): """ - Launches and polls a model and an ensemble with two more models. + Launches and polls a application and an ensemble with two more applications. 
Once polling starts, the SIGINT signal is sent to the main thread, and consequently, all running jobs are killed. """ exp_name = "test_interrupt_blocked_jobs" exp = Experiment(exp_name, exp_path=test_dir) - model = exp.create_model( - "interrupt_blocked_model", + application = exp.create_application( + "interrupt_blocked_application", path=test_dir, run_settings=RunSettings("sleep", "100"), ) @@ -71,20 +71,20 @@ def test_interrupt_blocked_jobs(test_dir): keyboard_interrupt_thread.start() with pytest.raises(KeyboardInterrupt): - exp.start(model, ensemble, block=True, kill_on_interrupt=True) + exp.start(application, ensemble, block=True, kill_on_interrupt=True) time.sleep(2) # allow time for jobs to be stopped active_jobs = exp._control._jobs.jobs - active_db_jobs = exp._control._jobs.db_jobs + active_fs_jobs = exp._control._jobs.fs_jobs completed_jobs = exp._control._jobs.completed - assert len(active_jobs) + len(active_db_jobs) == 0 + assert len(active_jobs) + len(active_fs_jobs) == 0 assert len(completed_jobs) == num_jobs def test_interrupt_multi_experiment_unblocked_jobs(test_dir): """ - Starts two Experiments, each having one model - and an ensemble with two more models. Since + Starts two Experiments, each having one application + and an ensemble with two more applications. Since blocking is False, the main thread sleeps until the SIGINT signal is sent, resulting in both Experiment's running jobs to be killed. 
@@ -94,8 +94,8 @@ def test_interrupt_multi_experiment_unblocked_jobs(test_dir): experiments = [Experiment(exp_names[i], exp_path=test_dir) for i in range(2)] jobs_per_experiment = [0] * len(experiments) for i, experiment in enumerate(experiments): - model = experiment.create_model( - "interrupt_model_" + str(i), + application = experiment.create_application( + "interrupt_application_" + str(i), path=test_dir, run_settings=RunSettings("sleep", "100"), ) @@ -114,13 +114,13 @@ def test_interrupt_multi_experiment_unblocked_jobs(test_dir): with pytest.raises(KeyboardInterrupt): for experiment in experiments: - experiment.start(model, ensemble, block=False, kill_on_interrupt=True) + experiment.start(application, ensemble, block=False, kill_on_interrupt=True) keyboard_interrupt_thread.join() # since jobs aren't blocked, wait for SIGINT time.sleep(2) # allow time for jobs to be stopped for i, experiment in enumerate(experiments): active_jobs = experiment._control._jobs.jobs - active_db_jobs = experiment._control._jobs.db_jobs + active_fs_jobs = experiment._control._jobs.fs_jobs completed_jobs = experiment._control._jobs.completed - assert len(active_jobs) + len(active_db_jobs) == 0 + assert len(active_jobs) + len(active_fs_jobs) == 0 assert len(completed_jobs) == jobs_per_experiment[i] diff --git a/tests/test_launch_errors.py b/tests/_legacy/test_launch_errors.py similarity index 72% rename from tests/test_launch_errors.py rename to tests/_legacy/test_launch_errors.py index 21b3184e5e..d545bffe4e 100644 --- a/tests/test_launch_errors.py +++ b/tests/_legacy/test_launch_errors.py @@ -28,10 +28,10 @@ import pytest from smartsim import Experiment -from smartsim.database import Orchestrator +from smartsim.database import FeatureStore from smartsim.error import SSUnsupportedError from smartsim.settings import JsrunSettings, RunSettings -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # The tests in this file belong to the group_a group 
pytestmark = pytest.mark.group_a @@ -41,7 +41,7 @@ def test_unsupported_run_settings(test_dir): exp_name = "test-unsupported-run-settings" exp = Experiment(exp_name, launcher="slurm", exp_path=test_dir) bad_settings = JsrunSettings("echo", "hello") - model = exp.create_model("bad_rs", bad_settings) + model = exp.create_application("bad_rs", bad_settings) with pytest.raises(SSUnsupportedError): exp.start(model) @@ -54,25 +54,29 @@ def test_model_failure(fileutils, test_dir): script = fileutils.get_test_conf_path("bad.py") settings = RunSettings("python", f"{script} --time=3") - M1 = exp.create_model("m1", path=test_dir, run_settings=settings) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings) exp.start(M1, block=True) statuses = exp.get_status(M1) - assert all([stat == SmartSimStatus.STATUS_FAILED for stat in statuses]) + assert all([stat == JobStatus.FAILED for stat in statuses]) -def test_orchestrator_relaunch(test_dir, wlmutils): - """Test when users try to launch second orchestrator""" - exp_name = "test-orc-on-relaunch" +def test_feature_store_relaunch(test_dir, wlmutils): + """Test when users try to launch second FeatureStore""" + exp_name = "test-feature-store-on-relaunch" exp = Experiment(exp_name, launcher="local", exp_path=test_dir) - orc = Orchestrator(port=wlmutils.get_test_port(), db_identifier="orch_1") - orc.set_path(test_dir) - orc_1 = Orchestrator(port=wlmutils.get_test_port() + 1, db_identifier="orch_2") - orc_1.set_path(test_dir) + feature_store = FeatureStore( + port=wlmutils.get_test_port(), fs_identifier="feature_store_1" + ) + feature_store.set_path(test_dir) + feature_store_1 = FeatureStore( + port=wlmutils.get_test_port() + 1, fs_identifier="feature_store_2" + ) + feature_store_1.set_path(test_dir) try: - exp.start(orc) - exp.start(orc_1) + exp.start(feature_store) + exp.start(feature_store_1) finally: - exp.stop(orc) - exp.stop(orc_1) + exp.stop(feature_store) + exp.stop(feature_store_1) diff --git 
a/tests/test_local_launch.py b/tests/_legacy/test_local_launch.py similarity index 84% rename from tests/test_local_launch.py rename to tests/_legacy/test_local_launch.py index 85687e0142..b638f515e1 100644 --- a/tests/test_local_launch.py +++ b/tests/_legacy/test_local_launch.py @@ -27,7 +27,7 @@ import pytest from smartsim import Experiment -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a @@ -38,19 +38,19 @@ """ -def test_models(fileutils, test_dir): - exp_name = "test-models-local-launch" +def test_applications(fileutils, test_dir): + exp_name = "test-applications-local-launch" exp = Experiment(exp_name, launcher="local", exp_path=test_dir) script = fileutils.get_test_conf_path("sleep.py") settings = exp.create_run_settings("python", f"{script} --time=3") - M1 = exp.create_model("m1", path=test_dir, run_settings=settings) - M2 = exp.create_model("m2", path=test_dir, run_settings=settings) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings) + M2 = exp.create_application("m2", path=test_dir, run_settings=settings) exp.start(M1, M2, block=True, summary=True) statuses = exp.get_status(M1, M2) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) def test_ensemble(fileutils, test_dir): @@ -64,4 +64,4 @@ def test_ensemble(fileutils, test_dir): exp.start(ensemble, block=True, summary=True) statuses = exp.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) diff --git a/tests/test_local_multi_run.py b/tests/_legacy/test_local_multi_run.py similarity index 80% rename from tests/test_local_multi_run.py rename to tests/_legacy/test_local_multi_run.py index a2c1d70ee9..a3762595ef 100644 --- a/tests/test_local_multi_run.py +++ 
b/tests/_legacy/test_local_multi_run.py @@ -27,7 +27,7 @@ import pytest from smartsim import Experiment -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a @@ -38,21 +38,21 @@ """ -def test_models(fileutils, test_dir): - exp_name = "test-models-local-launch" +def test_applications(fileutils, test_dir): + exp_name = "test-applications-local-launch" exp = Experiment(exp_name, launcher="local", exp_path=test_dir) script = fileutils.get_test_conf_path("sleep.py") settings = exp.create_run_settings("python", f"{script} --time=5") - M1 = exp.create_model("m1", path=test_dir, run_settings=settings) - M2 = exp.create_model("m2", path=test_dir, run_settings=settings) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings) + M2 = exp.create_application("m2", path=test_dir, run_settings=settings) exp.start(M1, block=False) statuses = exp.get_status(M1) - assert all([stat != SmartSimStatus.STATUS_FAILED for stat in statuses]) + assert all([stat != JobStatus.FAILED for stat in statuses]) - # start another while first model is running + # start another while first application is running exp.start(M2, block=True) statuses = exp.get_status(M1, M2) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) diff --git a/tests/test_local_restart.py b/tests/_legacy/test_local_restart.py similarity index 81% rename from tests/test_local_restart.py rename to tests/_legacy/test_local_restart.py index 2556c55977..5f22c96a0f 100644 --- a/tests/test_local_restart.py +++ b/tests/_legacy/test_local_restart.py @@ -27,34 +27,34 @@ import pytest from smartsim import Experiment -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b """ -Test restarting ensembles and 
models. +Test restarting ensembles and applications. """ def test_restart(fileutils, test_dir): - exp_name = "test-models-local-restart" + exp_name = "test-applications-local-restart" exp = Experiment(exp_name, launcher="local", exp_path=test_dir) script = fileutils.get_test_conf_path("sleep.py") settings = exp.create_run_settings("python", f"{script} --time=3") - M1 = exp.create_model("m1", path=test_dir, run_settings=settings) + M1 = exp.create_application("m1", path=test_dir, run_settings=settings) exp.start(M1, block=True) statuses = exp.get_status(M1) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) - # restart the model + # restart the application exp.start(M1, block=True) statuses = exp.get_status(M1) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) def test_ensemble(fileutils, test_dir): @@ -68,9 +68,9 @@ def test_ensemble(fileutils, test_dir): exp.start(ensemble, block=True) statuses = exp.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) # restart the ensemble exp.start(ensemble, block=True) statuses = exp.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) diff --git a/tests/test_logs.py b/tests/_legacy/test_logs.py similarity index 99% rename from tests/test_logs.py rename to tests/_legacy/test_logs.py index a187baa2a3..42c3335760 100644 --- a/tests/test_logs.py +++ b/tests/_legacy/test_logs.py @@ -204,7 +204,7 @@ def thrower(_self): sleep_rs.set_nodes(1) sleep_rs.set_tasks(1) - sleep = exp.create_model("SleepModel", sleep_rs) + sleep = exp.create_application("SleepModel", sleep_rs) exp.generate(sleep) exp.start(sleep, block=True) except Exception as 
ex: diff --git a/tests/test_lsf_parser.py b/tests/_legacy/test_lsf_parser.py similarity index 92% rename from tests/test_lsf_parser.py rename to tests/_legacy/test_lsf_parser.py index abd27eb5ae..0234ee4e90 100644 --- a/tests/test_lsf_parser.py +++ b/tests/_legacy/test_lsf_parser.py @@ -26,7 +26,7 @@ import pytest -from smartsim._core.launcher.lsf import lsfParser +from smartsim._core.launcher.lsf import lsf_parser # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b @@ -37,7 +37,7 @@ def test_parse_bsub(): output = "Job <12345> is submitted to queue ." - step_id = lsfParser.parse_bsub(output) + step_id = lsf_parser.parse_bsub(output) assert step_id == "12345" @@ -57,15 +57,15 @@ def test_parse_bsub_error(): "Not a member of the specified project: . You are currently a member of the following projects:\n" "ABC123" ) - parsed_error = lsfParser.parse_bsub_error(output) + parsed_error = lsf_parser.parse_bsub_error(output) assert error == parsed_error output = "NOT A PARSABLE ERROR\nBUT STILL AN ERROR MESSAGE" - parsed_error = lsfParser.parse_bsub_error(output) + parsed_error = lsf_parser.parse_bsub_error(output) assert output == parsed_error output = " \n" - parsed_error = lsfParser.parse_bsub_error(output) + parsed_error = lsf_parser.parse_bsub_error(output) assert parsed_error == "LSF run error" @@ -79,7 +79,7 @@ def test_parse_bsub_nodes(fileutils): "1234567 smartsim RUN batch login1 batch3:a01n02:a01n02:a01n02:a01n02:a01n02:a01n06:a01n06:a01n06:a01n06:a01n06 SmartSim Jul 24 12:53\n" ) nodes = ["batch3", "a01n02", "a01n06"] - parsed_nodes = lsfParser.parse_bjobs_nodes(output) + parsed_nodes = lsf_parser.parse_bjobs_nodes(output) assert nodes == parsed_nodes @@ -98,7 +98,7 @@ def test_parse_max_step_id(): " 4 0 1 various various 137 Killed\n" " 5 0 3 various various 137 Killed\n" ) - parsed_id = lsfParser.parse_max_step_id_from_jslist(output) + parsed_id = lsf_parser.parse_max_step_id_from_jslist(output) assert parsed_id == "9" @@ 
-121,6 +121,6 @@ def test_parse_jslist(): " 1 1 4 various various 0 Running\n" " 11 1 1 1 1 1 Running\n" ) - parsed_result = lsfParser.parse_jslist_stepid(output, "1") + parsed_result = lsf_parser.parse_jslist_stepid(output, "1") result = ("Running", "0") assert parsed_result == result diff --git a/tests/test_lsf_settings.py b/tests/_legacy/test_lsf_settings.py similarity index 99% rename from tests/test_lsf_settings.py rename to tests/_legacy/test_lsf_settings.py index fcb3516483..64dbd001cc 100644 --- a/tests/test_lsf_settings.py +++ b/tests/_legacy/test_lsf_settings.py @@ -144,7 +144,7 @@ def test_jsrun_mpmd(): def test_catch_colo_mpmd(): settings = JsrunSettings("python") - settings.colocated_db_settings = {"port": 6379, "cpus": 1} + settings.colocated_fs_settings = {"port": 6379, "cpus": 1} settings_2 = JsrunSettings("python") with pytest.raises(SSUnsupportedError): settings.make_mpmd(settings_2) diff --git a/tests/test_manifest.py b/tests/_legacy/test_manifest.py similarity index 71% rename from tests/test_manifest.py rename to tests/_legacy/test_manifest.py index f4a1b0afb5..ae65d44d83 100644 --- a/tests/test_manifest.py +++ b/tests/_legacy/test_manifest.py @@ -41,13 +41,15 @@ from smartsim._core.control.manifest import ( _LaunchedManifestMetadata as LaunchedManifestMetadata, ) -from smartsim._core.launcher.step import Step -from smartsim.database import Orchestrator -from smartsim.entity import Ensemble, Model -from smartsim.entity.dbobject import DBModel, DBScript +from smartsim.database import FeatureStore +from smartsim.entity.dbobject import FSModel, FSScript from smartsim.error import SmartSimError from smartsim.settings import RunSettings +if t.TYPE_CHECKING: + from smartsim._core.launcher.step import Step + from smartsim.entity import Ensemble, Model + # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b @@ -55,7 +57,7 @@ # ---- create entities for testing -------- _EntityResult = t.Tuple[ - Experiment, t.Tuple[Model, 
Model], Ensemble, Orchestrator, DBModel, DBScript + Experiment, t.Tuple["Model", "Model"], "Ensemble", FeatureStore, FSModel, FSScript ] @@ -68,12 +70,12 @@ def entities(test_dir: str) -> _EntityResult: model_2 = exp.create_model("model_1", run_settings=rs) ensemble = exp.create_ensemble("ensemble", run_settings=rs, replicas=1) - orc = Orchestrator() + orc = FeatureStore() orc_1 = deepcopy(orc) orc_1.name = "orc2" - db_script = DBScript("some-script", "def main():\n print('hello world')\n") - db_model = DBModel("some-model", "TORCH", b"some-model-bytes") + db_script = FSScript("some-script", "def main():\n print('hello world')\n") + db_model = FSModel("some-model", "TORCH", b"some-model-bytes") return exp, (model, model_2), ensemble, orc, db_model, db_script @@ -86,7 +88,7 @@ def test_separate(entities: _EntityResult) -> None: assert len(manifest.models) == 1 assert manifest.ensembles[0] == ensemble assert len(manifest.ensembles) == 1 - assert manifest.dbs[0] == orc + assert manifest.fss[0] == feature_store def test_separate_type() -> None: @@ -98,7 +100,7 @@ def test_name_collision(entities: _EntityResult) -> None: _, (model, model_2), _, _, _, _ = entities with pytest.raises(SmartSimError): - _ = Manifest(model, model_2) + _ = Manifest(application, application_2) def test_catch_empty_ensemble(entities: _EntityResult) -> None: @@ -124,36 +126,36 @@ class Person: @pytest.mark.parametrize( - "target_obj, target_prop, target_value, has_db_objects", + "target_obj, target_prop, target_value, has_fs_objects", [ - pytest.param(None, None, None, False, id="No DB Objects"), - pytest.param("m0", "dbm", "dbm", True, id="Model w/ DB Model"), - pytest.param("m0", "dbs", "dbs", True, id="Model w/ DB Script"), - pytest.param("ens", "dbm", "dbm", True, id="Ensemble w/ DB Model"), - pytest.param("ens", "dbs", "dbs", True, id="Ensemble w/ DB Script"), - pytest.param("ens_0", "dbm", "dbm", True, id="Ensemble Member w/ DB Model"), - pytest.param("ens_0", "dbs", "dbs", True, 
id="Ensemble Member w/ DB Script"), + pytest.param(None, None, None, False, id="No FS Objects"), + pytest.param("a0", "fsm", "fsm", True, id="Model w/ FS Model"), + pytest.param("a0", "fss", "fss", True, id="Model w/ FS Script"), + pytest.param("ens", "fsm", "fsm", True, id="Ensemble w/ FS Model"), + pytest.param("ens", "fss", "fss", True, id="Ensemble w/ FS Script"), + pytest.param("ens_0", "fsm", "fsm", True, id="Ensemble Member w/ FS Model"), + pytest.param("ens_0", "fss", "fss", True, id="Ensemble Member w/ FS Script"), ], ) -def test_manifest_detects_db_objects( +def test_manifest_detects_fs_objects( monkeypatch: pytest.MonkeyPatch, target_obj: str, target_prop: str, target_value: str, - has_db_objects: bool, + has_fs_objects: bool, entities: _EntityResult, ) -> None: - _, (model, _), ensemble, _, db_model, db_script = entities + _, (app, _), ensemble, _, fs_model, fs_script = entities target_map = { - "m0": model, - "dbm": db_model, - "dbs": db_script, + "a0": app, + "fsm": fs_model, + "fss": fs_script, "ens": ensemble, "ens_0": ensemble.entities[0], } prop_map = { - "dbm": "_db_models", - "dbs": "_db_scripts", + "fsm": "_fs_models", + "fss": "_fs_scripts", } if target_obj: patch = ( @@ -163,43 +165,45 @@ def test_manifest_detects_db_objects( ) monkeypatch.setattr(*patch) - assert Manifest(model, ensemble).has_db_objects == has_db_objects + assert Manifest(model, ensemble).has_fs_objects == has_fs_objects def test_launched_manifest_transform_data(entities: _EntityResult) -> None: - _, (model, model_2), ensemble, orc, _, _ = entities + _, (application, application_2), ensemble, feature_store, _, _ = entities - models = [(model, 1), (model_2, 2)] + applications = [(application, 1), (application_2, 2)] ensembles = [(ensemble, [(m, i) for i, m in enumerate(ensemble.entities)])] - dbs = [(orc, [(n, i) for i, n in enumerate(orc.entities)])] - lmb = LaunchedManifest( + fss = [(feature_store, [(n, i) for i, n in enumerate(feature_store.entities)])] + launched = 
LaunchedManifest( metadata=LaunchedManifestMetadata("name", "path", "launcher", "run_id"), - models=models, # type: ignore + applications=applications, # type: ignore ensembles=ensembles, # type: ignore - databases=dbs, # type: ignore + featurestores=fss, # type: ignore ) - transformed = lmb.map(lambda x: str(x)) + transformed = launched.map(lambda x: str(x)) - assert transformed.models == tuple((m, str(i)) for m, i in models) + assert transformed.applications == tuple((m, str(i)) for m, i in applications) assert transformed.ensembles[0][1] == tuple((m, str(i)) for m, i in ensembles[0][1]) - assert transformed.databases[0][1] == tuple((n, str(i)) for n, i in dbs[0][1]) + assert transformed.featurestores[0][1] == tuple((n, str(i)) for n, i in fss[0][1]) def test_launched_manifest_builder_correctly_maps_data(entities: _EntityResult) -> None: - _, (model, model_2), ensemble, orc, _, _ = entities + _, (application, application_2), ensemble, feature_store, _, _ = entities lmb = LaunchedManifestBuilder( "name", "path", "launcher name", str(uuid4()) ) # type: ignore - lmb.add_model(model, 1) - lmb.add_model(model_2, 1) + lmb.add_application(application, 1) + lmb.add_application(application_2, 1) lmb.add_ensemble(ensemble, [i for i in range(len(ensemble.entities))]) - lmb.add_database(orc, [i for i in range(len(orc.entities))]) + lmb.add_feature_store( + feature_store, [i for i in range(len(feature_store.entities))] + ) manifest = lmb.finalize() - assert len(manifest.models) == 2 + assert len(manifest.applications) == 2 assert len(manifest.ensembles) == 1 - assert len(manifest.databases) == 1 + assert len(manifest.featurestores) == 1 def test_launced_manifest_builder_raises_if_lens_do_not_match( @@ -213,7 +217,7 @@ def test_launced_manifest_builder_raises_if_lens_do_not_match( with pytest.raises(ValueError): lmb.add_ensemble(ensemble, list(range(123))) with pytest.raises(ValueError): - lmb.add_database(orc, list(range(123))) + lmb.add_feature_store(feature_store, 
list(range(123))) def test_launched_manifest_builer_raises_if_attaching_data_to_empty_collection( @@ -221,7 +225,7 @@ def test_launched_manifest_builer_raises_if_attaching_data_to_empty_collection( ) -> None: _, _, ensemble, _, _, _ = entities - lmb: LaunchedManifestBuilder[t.Tuple[str, Step]] = LaunchedManifestBuilder( + lmb: LaunchedManifestBuilder[t.Tuple[str, "Step"]] = LaunchedManifestBuilder( "name", "path", "launcher", str(uuid4()) ) monkeypatch.setattr(ensemble, "entities", []) @@ -231,7 +235,7 @@ def test_launched_manifest_builer_raises_if_attaching_data_to_empty_collection( def test_lmb_and_launched_manifest_have_same_paths_for_launched_metadata() -> None: exp_path = "/path/to/some/exp" - lmb: LaunchedManifestBuilder[t.Tuple[str, Step]] = LaunchedManifestBuilder( + lmb: LaunchedManifestBuilder[t.Tuple[str, "Step"]] = LaunchedManifestBuilder( "exp_name", exp_path, "launcher", str(uuid4()) ) manifest = lmb.finalize() diff --git a/tests/test_model.py b/tests/_legacy/test_model.py similarity index 75% rename from tests/test_model.py rename to tests/_legacy/test_model.py index 152ce20584..f8a84deb8d 100644 --- a/tests/test_model.py +++ b/tests/_legacy/test_model.py @@ -24,6 +24,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import typing as t from uuid import uuid4 import numpy as np @@ -32,11 +33,14 @@ from smartsim import Experiment from smartsim._core.control.manifest import LaunchedManifestBuilder from smartsim._core.launcher.step import SbatchStep, SrunStep -from smartsim.entity import Ensemble, Model -from smartsim.entity.model import _parse_model_parameters +from smartsim.builders import Ensemble +from smartsim.entity import Application from smartsim.error import EntityExistsError, SSUnsupportedError from smartsim.settings import RunSettings, SbatchSettings, SrunSettings -from smartsim.settings.mpiSettings import _BaseMPISettings + +# from smartsim.settings.mpiSettings import + +_BaseMPISettings = t.Any # The tests in this file belong to the slow_tests group pytestmark = pytest.mark.slow_tests @@ -46,7 +50,7 @@ def test_register_incoming_entity_preexists(): exp = Experiment("experiment", launcher="local") rs = RunSettings("python", exe_args="sleep.py") ensemble = exp.create_ensemble(name="ensemble", replicas=1, run_settings=rs) - m = exp.create_model("model", run_settings=rs) + m = exp.create_application("application", run_settings=rs) m.register_incoming_entity(ensemble["ensemble_0"]) assert len(m.incoming_entities) == 1 with pytest.raises(EntityExistsError): @@ -56,36 +60,38 @@ def test_register_incoming_entity_preexists(): def test_disable_key_prefixing(): exp = Experiment("experiment", launcher="local") rs = RunSettings("python", exe_args="sleep.py") - m = exp.create_model("model", run_settings=rs) + m = exp.create_application("application", run_settings=rs) m.disable_key_prefixing() assert m.query_key_prefixing() == False -def test_catch_colo_mpmd_model(): +def test_catch_colo_mpmd_application(): exp = Experiment("experiment", launcher="local") rs = _BaseMPISettings("python", exe_args="sleep.py", fail_if_missing_exec=False) - # make it an mpmd model + # make it an mpmd application rs_2 = _BaseMPISettings("python", exe_args="sleep.py", fail_if_missing_exec=False) 
rs.make_mpmd(rs_2) - model = exp.create_model("bad_colo_model", rs) + application = exp.create_application("bad_colo_application", rs) # make it colocated which should raise and error with pytest.raises(SSUnsupportedError): - model.colocate_db() + application.colocate_fs() -def test_attach_batch_settings_to_model(): +def test_attach_batch_settings_to_application(): exp = Experiment("experiment", launcher="slurm") bs = SbatchSettings() rs = SrunSettings("python", exe_args="sleep.py") - model_wo_bs = exp.create_model("test_model", run_settings=rs) - assert model_wo_bs.batch_settings is None + application_wo_bs = exp.create_application("test_application", run_settings=rs) + assert application_wo_bs.batch_settings is None - model_w_bs = exp.create_model("test_model_2", run_settings=rs, batch_settings=bs) - assert isinstance(model_w_bs.batch_settings, SbatchSettings) + application_w_bs = exp.create_application( + "test_application_2", run_settings=rs, batch_settings=bs + ) + assert isinstance(application_w_bs.batch_settings, SbatchSettings) @pytest.fixture @@ -118,53 +124,57 @@ def launch_step_nop(self, step, entity): return _monkeypatch_exp_controller -def test_model_with_batch_settings_makes_batch_step( +def test_application_with_batch_settings_makes_batch_step( monkeypatch_exp_controller, test_dir ): exp = Experiment("experiment", launcher="slurm", exp_path=test_dir) bs = SbatchSettings() rs = SrunSettings("python", exe_args="sleep.py") - model = exp.create_model("test_model", run_settings=rs, batch_settings=bs) + application = exp.create_application( + "test_application", run_settings=rs, batch_settings=bs + ) entity_steps = monkeypatch_exp_controller(exp) - exp.start(model) + exp.start(application) assert len(entity_steps) == 1 step, entity = entity_steps[0] - assert isinstance(entity, Model) + assert isinstance(entity, Application) assert isinstance(step, SbatchStep) -def test_model_without_batch_settings_makes_run_step( +def 
test_application_without_batch_settings_makes_run_step( monkeypatch, monkeypatch_exp_controller, test_dir ): exp = Experiment("experiment", launcher="slurm", exp_path=test_dir) rs = SrunSettings("python", exe_args="sleep.py") - model = exp.create_model("test_model", run_settings=rs) + application = exp.create_application("test_application", run_settings=rs) # pretend we are in an allocation to not raise alloc err monkeypatch.setenv("SLURM_JOB_ID", "12345") entity_steps = monkeypatch_exp_controller(exp) - exp.start(model) + exp.start(application) assert len(entity_steps) == 1 step, entity = entity_steps[0] - assert isinstance(entity, Model) + assert isinstance(entity, Application) assert isinstance(step, SrunStep) -def test_models_batch_settings_are_ignored_in_ensemble( +def test_applications_batch_settings_are_ignored_in_ensemble( monkeypatch_exp_controller, test_dir ): exp = Experiment("experiment", launcher="slurm", exp_path=test_dir) bs_1 = SbatchSettings(nodes=5) rs = SrunSettings("python", exe_args="sleep.py") - model = exp.create_model("test_model", run_settings=rs, batch_settings=bs_1) + application = exp.create_application( + "test_application", run_settings=rs, batch_settings=bs_1 + ) bs_2 = SbatchSettings(nodes=10) ens = exp.create_ensemble("test_ensemble", batch_settings=bs_2) - ens.add_model(model) + ens.add_application(application) entity_steps = monkeypatch_exp_controller(exp) exp.start(ens) @@ -176,18 +186,7 @@ def test_models_batch_settings_are_ignored_in_ensemble( assert step.batch_settings.batch_args["nodes"] == "10" assert len(step.step_cmds) == 1 step_cmd = step.step_cmds[0] - assert any("srun" in tok for tok in step_cmd) # call the model using run settings + assert any( + "srun" in tok for tok in step_cmd + ) # call the application using run settings assert not any("sbatch" in tok for tok in step_cmd) # no sbatch in sbatch - - -@pytest.mark.parametrize("dtype", [int, np.float32, str]) -def test_good_model_params(dtype): - print(dtype(0.6)) - 
params = {"foo": dtype(0.6)} - assert all(isinstance(val, str) for val in _parse_model_parameters(params).values()) - - -@pytest.mark.parametrize("bad_val", [["eggs"], {"n": 5}, object]) -def test_bad_model_params(bad_val): - with pytest.raises(TypeError): - _parse_model_parameters({"foo": bad_val}) diff --git a/tests/test_modelwriter.py b/tests/_legacy/test_modelwriter.py similarity index 86% rename from tests/test_modelwriter.py rename to tests/_legacy/test_modelwriter.py index a857d7c5f0..9aab51e619 100644 --- a/tests/test_modelwriter.py +++ b/tests/_legacy/test_modelwriter.py @@ -31,7 +31,7 @@ import pytest -from smartsim._core.generation.modelwriter import ModelWriter +from smartsim._core.generation.modelwriter import ApplicationWriter from smartsim.error.errors import ParameterWriterError, SmartSimError from smartsim.settings import RunSettings @@ -62,9 +62,9 @@ def test_write_easy_configs(fileutils, test_dir): dir_util.copy_tree(conf_path, test_dir) assert path.isdir(test_dir) - # init modelwriter - writer = ModelWriter() - writer.configure_tagged_model_files(glob(test_dir + "/*"), param_dict) + # init ApplicationWriter + writer = ApplicationWriter() + writer.configure_tagged_application_files(glob(test_dir + "/*"), param_dict) written_files = sorted(glob(test_dir + "/*")) correct_files = sorted(glob(correct_path + "*")) @@ -90,11 +90,11 @@ def test_write_med_configs(fileutils, test_dir): dir_util.copy_tree(conf_path, test_dir) assert path.isdir(test_dir) - # init modelwriter - writer = ModelWriter() + # init ApplicationWriter + writer = ApplicationWriter() writer.set_tag(writer.tag, "(;.+;)") assert writer.regex == "(;.+;)" - writer.configure_tagged_model_files(glob(test_dir + "/*"), param_dict) + writer.configure_tagged_application_files(glob(test_dir + "/*"), param_dict) written_files = sorted(glob(test_dir + "/*")) correct_files = sorted(glob(correct_path + "*")) @@ -122,10 +122,10 @@ def test_write_new_tag_configs(fileutils, test_dir): 
dir_util.copy_tree(conf_path, test_dir) assert path.isdir(test_dir) - # init modelwriter - writer = ModelWriter() + # init ApplicationWriter + writer = ApplicationWriter() writer.set_tag("@") - writer.configure_tagged_model_files(glob(test_dir + "/*"), param_dict) + writer.configure_tagged_application_files(glob(test_dir + "/*"), param_dict) written_files = sorted(glob(test_dir + "/*")) correct_files = sorted(glob(correct_path + "*")) @@ -135,13 +135,13 @@ def test_write_new_tag_configs(fileutils, test_dir): def test_mw_error_1(): - writer = ModelWriter() + writer = ApplicationWriter() with pytest.raises(ParameterWriterError): - writer.configure_tagged_model_files("[not/a/path]", {"5": 10}) + writer.configure_tagged_application_files("[not/a/path]", {"5": 10}) def test_mw_error_2(): - writer = ModelWriter() + writer = ApplicationWriter() with pytest.raises(ParameterWriterError): writer._write_changes("[not/a/path]") @@ -157,9 +157,9 @@ def test_write_mw_error_3(fileutils, test_dir): dir_util.copy_tree(conf_path, test_dir) assert path.isdir(test_dir) - # init modelwriter - writer = ModelWriter() + # init ApplicationWriter + writer = ApplicationWriter() with pytest.raises(SmartSimError): - writer.configure_tagged_model_files( + writer.configure_tagged_application_files( glob(test_dir + "/*"), param_dict, make_missing_tags_fatal=True ) diff --git a/tests/test_mpi_settings.py b/tests/_legacy/test_mpi_settings.py similarity index 99% rename from tests/test_mpi_settings.py rename to tests/_legacy/test_mpi_settings.py index 7d8db6e757..40c3f4ce0a 100644 --- a/tests/test_mpi_settings.py +++ b/tests/_legacy/test_mpi_settings.py @@ -173,7 +173,7 @@ def test_mpi_add_mpmd(): def test_catch_colo_mpmd(): settings = _BaseMPISettings(*default_mpi_args, **default_mpi_kwargs) - settings.colocated_db_settings = {"port": 6379, "cpus": 1} + settings.colocated_fs_settings = {"port": 6379, "cpus": 1} settings_2 = _BaseMPISettings(*default_mpi_args, **default_mpi_kwargs) with 
pytest.raises(SSUnsupportedError): settings.make_mpmd(settings_2) diff --git a/tests/test_multidb.py b/tests/_legacy/test_multidb.py similarity index 55% rename from tests/test_multidb.py rename to tests/_legacy/test_multidb.py index 81f21856af..3e48d87522 100644 --- a/tests/test_multidb.py +++ b/tests/_legacy/test_multidb.py @@ -28,11 +28,11 @@ import pytest from smartsim import Experiment -from smartsim.database import Orchestrator +from smartsim.database import FeatureStore from smartsim.entity.entity import SmartSimEntity from smartsim.error.errors import SSDBIDConflictError from smartsim.log import get_logger -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b @@ -40,7 +40,7 @@ logger = get_logger(__name__) -supported_dbs = ["uds", "tcp"] +supported_fss = ["uds", "tcp"] on_wlm = (pytest.test_launcher in pytest.wlm_options,) @@ -52,7 +52,7 @@ def make_entity_context(exp: Experiment, entity: SmartSimEntity): try: yield entity finally: - if exp.get_status(entity)[0] == SmartSimStatus.STATUS_RUNNING: + if exp.get_status(entity)[0] == JobStatus.RUNNING: exp.stop(entity) @@ -66,76 +66,79 @@ def choose_host(wlmutils, index=0): def check_not_failed(exp, *args): statuses = exp.get_status(*args) - assert all(stat is not SmartSimStatus.STATUS_FAILED for stat in statuses) + assert all(stat is not JobStatus.FAILED for stat in statuses) -@pytest.mark.parametrize("db_type", supported_dbs) -def test_db_identifier_standard_then_colo_error( - fileutils, wlmutils, coloutils, db_type, test_dir +@pytest.mark.parametrize("fs_type", supported_fss) +def test_fs_identifier_standard_then_colo_error( + fileutils, wlmutils, coloutils, fs_type, test_dir ): - """Test that it is possible to create_database then colocate_db_uds/colocate_db_tcp - with unique db_identifiers""" + """Test that it is possible to create_feature_store then colocate_fs_uds/colocate_fs_tcp + with 
unique fs_identifiers""" # Set experiment name - exp_name = "test_db_identifier_standard_then_colo" + exp_name = "test_fs_identifier_standard_then_colo" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() test_interface = wlmutils.get_test_interface() test_port = wlmutils.get_test_port() - test_script = fileutils.get_test_conf_path("smartredis/db_id_err.py") + test_script = fileutils.get_test_conf_path("smartredis/fs_id_err.py") # Create SmartSim Experiment exp = Experiment(exp_name, launcher=test_launcher, exp_path=test_dir) - # create regular database - orc = exp.create_database( + # create regular feature store + feature_store = exp.create_feature_store( port=test_port, interface=test_interface, - db_identifier="testdb_colo", + fs_identifier="testdb_colo", hosts=choose_host(wlmutils), ) - assert orc.name == "testdb_colo" + assert feature_store.name == "testdb_colo" - db_args = { + fs_args = { "port": test_port + 1, - "db_cpus": 1, + "fs_cpus": 1, "debug": True, - "db_identifier": "testdb_colo", + "fs_identifier": "testdb_colo", } smartsim_model = coloutils.setup_test_colo( - fileutils, db_type, exp, test_script, db_args, on_wlm=on_wlm + fileutils, fs_type, exp, test_script, fs_args, on_wlm=on_wlm ) assert ( - smartsim_model.run_settings.colocated_db_settings["db_identifier"] + smartsim_model.run_settings.colocated_fs_settings["fs_identifier"] == "testdb_colo" ) - with make_entity_context(exp, orc), make_entity_context(exp, smartsim_model): - exp.start(orc) + with ( + make_entity_context(exp, feature_store), + make_entity_context(exp, smartsim_model), + ): + exp.start(feature_store) with pytest.raises(SSDBIDConflictError) as ex: exp.start(smartsim_model) assert ( - "has already been used. Pass in a unique name for db_identifier" + "has already been used. 
Pass in a unique name for fs_identifier" in ex.value.args[0] ) - check_not_failed(exp, orc) + check_not_failed(exp, feature_store) -@pytest.mark.parametrize("db_type", supported_dbs) -def test_db_identifier_colo_then_standard( - fileutils, wlmutils, coloutils, db_type, test_dir +@pytest.mark.parametrize("fs_type", supported_fss) +def test_fs_identifier_colo_then_standard( + fileutils, wlmutils, coloutils, fs_type, test_dir ): - """Test colocate_db_uds/colocate_db_tcp then create_database with database + """Test colocate_fs_uds/colocate_fs_tcp then create_feature_store with feature store identifiers. """ # Set experiment name - exp_name = "test_db_identifier_colo_then_standard" + exp_name = "test_fs_identifier_colo_then_standard" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -152,52 +155,55 @@ def test_db_identifier_colo_then_standard( colo_settings.set_tasks_per_node(1) # Create the SmartSim Model - smartsim_model = exp.create_model("colocated_model", colo_settings) + smartsim_model = exp.create_application("colocated_model", colo_settings) - db_args = { + fs_args = { "port": test_port, - "db_cpus": 1, + "fs_cpus": 1, "debug": True, - "db_identifier": "testdb_colo", + "fs_identifier": "testdb_colo", } smartsim_model = coloutils.setup_test_colo( fileutils, - db_type, + fs_type, exp, test_script, - db_args, + fs_args, on_wlm=on_wlm, ) assert ( - smartsim_model.run_settings.colocated_db_settings["db_identifier"] + smartsim_model.run_settings.colocated_fs_settings["fs_identifier"] == "testdb_colo" ) - # Create Database - orc = exp.create_database( + # Create feature store + feature_store = exp.create_feature_store( port=test_port + 1, interface=test_interface, - db_identifier="testdb_colo", + fs_identifier="testdb_colo", hosts=choose_host(wlmutils), ) - assert orc.name == "testdb_colo" + assert feature_store.name == "testdb_colo" - with make_entity_context(exp, orc), make_entity_context(exp, smartsim_model): + with ( + 
make_entity_context(exp, feature_store), + make_entity_context(exp, smartsim_model), + ): exp.start(smartsim_model, block=True) - exp.start(orc) + exp.start(feature_store) - check_not_failed(exp, orc, smartsim_model) + check_not_failed(exp, feature_store, smartsim_model) -def test_db_identifier_standard_twice_not_unique(wlmutils, test_dir): - """Test uniqueness of db_identifier several calls to create_database, with non unique names, +def test_fs_identifier_standard_twice_not_unique(wlmutils, test_dir): + """Test uniqueness of fs_identifier several calls to create_feature_store, with non unique names, checking error is raised before exp start is called""" # Set experiment name - exp_name = "test_db_identifier_multiple_create_database_not_unique" + exp_name = "test_fs_identifier_multiple_create_feature_store_not_unique" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -207,42 +213,45 @@ def test_db_identifier_standard_twice_not_unique(wlmutils, test_dir): # Create SmartSim Experiment exp = Experiment(exp_name, launcher=test_launcher, exp_path=test_dir) - # CREATE DATABASE with db_identifier - orc = exp.create_database( + # CREATE feature store with fs_identifier + feature_store = exp.create_feature_store( port=test_port, interface=test_interface, - db_identifier="my_db", + fs_identifier="my_fs", hosts=choose_host(wlmutils), ) - assert orc.name == "my_db" + assert feature_store.name == "my_fs" - orc2 = exp.create_database( + feature_store2 = exp.create_feature_store( port=test_port + 1, interface=test_interface, - db_identifier="my_db", + fs_identifier="my_fs", hosts=choose_host(wlmutils, index=1), ) - assert orc2.name == "my_db" + assert feature_store2.name == "my_fs" - # CREATE DATABASE with db_identifier - with make_entity_context(exp, orc2), make_entity_context(exp, orc): - exp.start(orc) + # CREATE feature store with fs_identifier + with ( + make_entity_context(exp, feature_store2), + make_entity_context(exp, 
feature_store), + ): + exp.start(feature_store) with pytest.raises(SSDBIDConflictError) as ex: - exp.start(orc2) + exp.start(feature_store2) assert ( - "has already been used. Pass in a unique name for db_identifier" + "has already been used. Pass in a unique name for fs_identifier" in ex.value.args[0] ) - check_not_failed(exp, orc) + check_not_failed(exp, feature_store) -def test_db_identifier_create_standard_once(test_dir, wlmutils): - """One call to create database with a database identifier""" +def test_fs_identifier_create_standard_once(test_dir, wlmutils): + """One call to create feature store with a feature store identifier""" # Set experiment name - exp_name = "test_db_identifier_create_standard_once" + exp_name = "test_fs_identifier_create_standard_once" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -252,22 +261,22 @@ def test_db_identifier_create_standard_once(test_dir, wlmutils): # Create the SmartSim Experiment exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) - # Create the SmartSim database - db = exp.create_database( + # Create the SmartSim feature store + fs = exp.create_feature_store( port=test_port, - db_nodes=1, + fs_nodes=1, interface=test_interface, - db_identifier="testdb_reg", + fs_identifier="testdb_reg", hosts=choose_host(wlmutils), ) - with make_entity_context(exp, db): - exp.start(db) + with make_entity_context(exp, fs): + exp.start(fs) - check_not_failed(exp, db) + check_not_failed(exp, fs) -def test_multidb_create_standard_twice(wlmutils, test_dir): - """Multiple calls to create database with unique db_identifiers""" +def test_multifs_create_standard_twice(wlmutils, test_dir): + """Multiple calls to create feature store with unique fs_identifiers""" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -276,36 +285,36 @@ def test_multifs_create_standard_twice(wlmutils, test_dir): # start a new Experiment for this section exp = 
Experiment( - "test_multidb_create_standard_twice", exp_path=test_dir, launcher=test_launcher + "test_multifs_create_standard_twice", exp_path=test_dir, launcher=test_launcher ) - # create and start an instance of the Orchestrator database - db = exp.create_database( + # create and start an instance of the FeatureStore feature store + fs = exp.create_feature_store( port=test_port, interface=test_interface, - db_identifier="testdb_reg", + fs_identifier="testdb_reg", hosts=choose_host(wlmutils, 1), ) - # create database with different db_id - db2 = exp.create_database( + # create feature store with different fs_id + fs2 = exp.create_feature_store( port=test_port + 1, interface=test_interface, - db_identifier="testdb_reg2", + fs_identifier="testdb_reg2", hosts=choose_host(wlmutils, 2), ) # launch - with make_entity_context(exp, db), make_entity_context(exp, db2): - exp.start(db, db2) + with make_entity_context(exp, fs), make_entity_context(exp, fs2): + exp.start(fs, fs2) - with make_entity_context(exp, db), make_entity_context(exp, db2): - exp.start(db, db2) + with make_entity_context(exp, fs), make_entity_context(exp, fs2): + exp.start(fs, fs2) -@pytest.mark.parametrize("db_type", supported_dbs) -def test_multidb_colo_once(fileutils, test_dir, wlmutils, coloutils, db_type): - """create one model with colocated database with db_identifier""" +@pytest.mark.parametrize("fs_type", supported_fss) +def test_multifs_colo_once(fileutils, test_dir, wlmutils, coloutils, fs_type): + """create one model with colocated feature store with fs_identifier""" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -315,7 +324,7 @@ def test_multidb_colo_once(fileutils, test_dir, wlmutils, coloutils, db_type): # start a new Experiment for this section exp = Experiment( - "test_multidb_colo_once", launcher=test_launcher, exp_path=test_dir + "test_multifs_colo_once", launcher=test_launcher, exp_path=test_dir ) # create run settings @@ -324,22 +333,22 
@@ def test_multidb_colo_once(fileutils, test_dir, wlmutils, coloutils, db_type): run_settings.set_tasks_per_node(1) # Create the SmartSim Model - smartsim_model = exp.create_model("smartsim_model", run_settings) + smartsim_model = exp.create_application("smartsim_model", run_settings) - db_args = { + fs_args = { "port": test_port + 1, - "db_cpus": 1, + "fs_cpus": 1, "debug": True, - "db_identifier": "testdb_colo", + "fs_identifier": "testdb_colo", } - # Create model with colocated database + # Create model with colocated feature store smartsim_model = coloutils.setup_test_colo( fileutils, - db_type, + fs_type, exp, test_script, - db_args, + fs_args, on_wlm=on_wlm, ) @@ -349,9 +358,9 @@ def test_multidb_colo_once(fileutils, test_dir, wlmutils, coloutils, db_type): check_not_failed(exp, smartsim_model) -@pytest.mark.parametrize("db_type", supported_dbs) -def test_multidb_standard_then_colo(fileutils, test_dir, wlmutils, coloutils, db_type): - """Create regular database then colocate_db_tcp/uds with unique db_identifiers""" +@pytest.mark.parametrize("fs_type", supported_fss) +def test_multifs_standard_then_colo(fileutils, test_dir, wlmutils, coloutils, fs_type): + """Create regular feature store then colocate_fs_tcp/uds with unique fs_identifiers""" # Retrieve parameters from testing environment test_port = wlmutils.get_test_port() @@ -362,43 +371,43 @@ def test_multidb_standard_then_colo(fileutils, test_dir, wlmutils, coloutils, db # start a new Experiment for this section exp = Experiment( - "test_multidb_standard_then_colo", exp_path=test_dir, launcher=test_launcher + "test_multifs_standard_then_colo", exp_path=test_dir, launcher=test_launcher ) - # create and generate an instance of the Orchestrator database - db = exp.create_database( + # create and generate an instance of the FeatureStore feature store + fs = exp.create_feature_store( port=test_port, interface=test_interface, - db_identifier="testdb_reg", + fs_identifier="testdb_reg", 
hosts=choose_host(wlmutils), ) - db_args = { + fs_args = { "port": test_port + 1, - "db_cpus": 1, + "fs_cpus": 1, "debug": True, - "db_identifier": "testdb_colo", + "fs_identifier": "testdb_colo", } - # Create model with colocated database + # Create model with colocated feature store smartsim_model = coloutils.setup_test_colo( fileutils, - db_type, + fs_type, exp, test_script, - db_args, + fs_args, on_wlm=on_wlm, ) - with make_entity_context(exp, db), make_entity_context(exp, smartsim_model): - exp.start(db) + with make_entity_context(exp, fs), make_entity_context(exp, smartsim_model): + exp.start(fs) exp.start(smartsim_model, block=True) - check_not_failed(exp, smartsim_model, db) + check_not_failed(exp, smartsim_model, fs) -@pytest.mark.parametrize("db_type", supported_dbs) -def test_multidb_colo_then_standard(fileutils, test_dir, wlmutils, coloutils, db_type): - """create regular database then colocate_db_tcp/uds with unique db_identifiers""" +@pytest.mark.parametrize("fs_type", supported_fss) +def test_multifs_colo_then_standard(fileutils, test_dir, wlmutils, coloutils, fs_type): + """create regular feature store then colocate_fs_tcp/uds with unique fs_identifiers""" # Retrieve parameters from testing environment test_port = wlmutils.get_test_port() @@ -411,49 +420,49 @@ def test_multidb_colo_then_standard(fileutils, test_dir, wlmutils, coloutils, db # start a new Experiment exp = Experiment( - "test_multidb_colo_then_standard", exp_path=test_dir, launcher=test_launcher + "test_multifs_colo_then_standard", exp_path=test_dir, launcher=test_launcher ) - db_args = { + fs_args = { "port": test_port, - "db_cpus": 1, + "fs_cpus": 1, "debug": True, - "db_identifier": "testdb_colo", + "fs_identifier": "testdb_colo", } - # Create model with colocated database + # Create model with colocated feature store smartsim_model = coloutils.setup_test_colo( - fileutils, db_type, exp, test_script, db_args, on_wlm=on_wlm + fileutils, fs_type, exp, test_script, fs_args, 
on_wlm=on_wlm ) - # create and start an instance of the Orchestrator database - db = exp.create_database( + # create and start an instance of the FeatureStore feature store + fs = exp.create_feature_store( port=test_port + 1, interface=test_interface, - db_identifier="testdb_reg", + fs_identifier="testdb_reg", hosts=choose_host(wlmutils), ) - with make_entity_context(exp, db), make_entity_context(exp, smartsim_model): + with make_entity_context(exp, fs), make_entity_context(exp, smartsim_model): exp.start(smartsim_model, block=False) - exp.start(db) + exp.start(fs) exp.poll(smartsim_model) - check_not_failed(exp, db, smartsim_model) + check_not_failed(exp, fs, smartsim_model) @pytest.mark.skipif( pytest.test_launcher not in pytest.wlm_options, reason="Not testing WLM integrations", ) -@pytest.mark.parametrize("db_type", supported_dbs) -def test_launch_cluster_orc_single_dbid( - test_dir, coloutils, fileutils, wlmutils, db_type +@pytest.mark.parametrize("fs_type", supported_fss) +def test_launch_cluster_feature_store_single_fsid( + test_dir, coloutils, fileutils, wlmutils, fs_type ): - """test clustered 3-node orchestrator with single command with a database identifier""" + """test clustered 3-node FeatureStore with single command with a feature store identifier""" # TODO detect number of nodes in allocation and skip if not sufficent - exp_name = "test_launch_cluster_orc_single_dbid" + exp_name = "test_launch_cluster_feature_store_single_fsid" launcher = wlmutils.get_test_launcher() test_port = wlmutils.get_test_port() test_script = fileutils.get_test_conf_path("smartredis/multidbid.py") @@ -461,32 +470,35 @@ def test_launch_cluster_orc_single_dbid( # batch = False to launch on existing allocation network_interface = wlmutils.get_test_interface() - orc: Orchestrator = exp.create_database( + feature_store: FeatureStore = exp.create_feature_store( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=False, interface=network_interface, single_cmd=True, 
hosts=wlmutils.get_test_hostlist(), - db_identifier="testdb_reg", + fs_identifier="testdb_reg", ) - db_args = { + fs_args = { "port": test_port, - "db_cpus": 1, + "fs_cpus": 1, "debug": True, - "db_identifier": "testdb_colo", + "fs_identifier": "testdb_colo", } - # Create model with colocated database + # Create model with colocated feature store smartsim_model = coloutils.setup_test_colo( - fileutils, db_type, exp, test_script, db_args, on_wlm=on_wlm + fileutils, fs_type, exp, test_script, fs_args, on_wlm=on_wlm ) - with make_entity_context(exp, orc), make_entity_context(exp, smartsim_model): - exp.start(orc, block=True) + with ( + make_entity_context(exp, feature_store), + make_entity_context(exp, smartsim_model), + ): + exp.start(feature_store, block=True) exp.start(smartsim_model, block=True) - job_dict = exp._control._jobs.get_db_host_addresses() - assert len(job_dict[orc.entities[0].db_identifier]) == 3 + job_dict = exp._control._jobs.get_fs_host_addresses() + assert len(job_dict[feature_store.entities[0].fs_identifier]) == 3 - check_not_failed(exp, orc, smartsim_model) + check_not_failed(exp, feature_store, smartsim_model) diff --git a/tests/test_orc_config_settings.py b/tests/_legacy/test_orc_config_settings.py similarity index 76% rename from tests/test_orc_config_settings.py rename to tests/_legacy/test_orc_config_settings.py index 74d0c1af29..3f32da8db5 100644 --- a/tests/test_orc_config_settings.py +++ b/tests/_legacy/test_orc_config_settings.py @@ -27,7 +27,7 @@ import pytest -from smartsim.database import Orchestrator +from smartsim.database import FeatureStore from smartsim.error import SmartSimError try: @@ -41,41 +41,41 @@ pytestmark = pytest.mark.group_b -def test_config_methods(dbutils, prepare_db, local_db): +def test_config_methods(fsutils, prepare_fs, local_fs): """Test all configuration file edit methods on an active db""" - db = prepare_db(local_db).orchestrator + fs = prepare_fs(local_fs).featurestore # test the happy path and ensure all 
configuration file edit methods # successfully execute when given correct key-value pairs - configs = dbutils.get_db_configs() + configs = fsutils.get_fs_configs() for setting, value in configs.items(): - config_set_method = dbutils.get_config_edit_method(db, setting) + config_set_method = fsutils.get_config_edit_method(fs, setting) config_set_method(value) - # ensure SmartSimError is raised when Orchestrator.set_db_conf + # ensure SmartSimError is raised when FeatureStore.set_fs_conf # is given invalid CONFIG key-value pairs - ss_error_configs = dbutils.get_smartsim_error_db_configs() + ss_error_configs = fsutils.get_smartsim_error_fs_configs() for key, value_list in ss_error_configs.items(): for value in value_list: with pytest.raises(SmartSimError): - db.set_db_conf(key, value) + fs.set_fs_conf(key, value) - # ensure TypeError is raised when Orchestrator.set_db_conf + # ensure TypeError is raised when FeatureStore.set_fs_conf # is given either a key or a value that is not a string - type_error_configs = dbutils.get_type_error_db_configs() + type_error_configs = fsutils.get_type_error_fs_configs() for key, value_list in type_error_configs.items(): for value in value_list: with pytest.raises(TypeError): - db.set_db_conf(key, value) + fs.set_fs_conf(key, value) -def test_config_methods_inactive(dbutils): +def test_config_methods_inactive(fsutils): """Ensure a SmartSimError is raised when trying to - set configurations on an inactive database + set configurations on an inactive feature store """ - db = Orchestrator() - configs = dbutils.get_db_configs() + fs = FeatureStore() + configs = fsutils.get_fs_configs() for setting, value in configs.items(): - config_set_method = dbutils.get_config_edit_method(db, setting) + config_set_method = fsutils.get_config_edit_method(fs, setting) with pytest.raises(SmartSimError): config_set_method(value) diff --git a/tests/test_orchestrator.py b/tests/_legacy/test_orchestrator.py similarity index 56% rename from 
tests/test_orchestrator.py rename to tests/_legacy/test_orchestrator.py index 66fb894f78..5febb8d1bd 100644 --- a/tests/test_orchestrator.py +++ b/tests/_legacy/test_orchestrator.py @@ -31,7 +31,7 @@ import pytest from smartsim import Experiment -from smartsim.database import Orchestrator +from smartsim.database import FeatureStore from smartsim.error import SmartSimError from smartsim.error.errors import SSUnsupportedError @@ -43,48 +43,48 @@ import conftest -def test_orc_parameters() -> None: +def test_feature_store_parameters() -> None: threads_per_queue = 2 inter_op_threads = 2 intra_op_threads = 2 - db = Orchestrator( - db_nodes=1, + fs = FeatureStore( + fs_nodes=1, threads_per_queue=threads_per_queue, inter_op_threads=inter_op_threads, intra_op_threads=intra_op_threads, ) - assert db.queue_threads == threads_per_queue - assert db.inter_threads == inter_op_threads - assert db.intra_threads == intra_op_threads + assert fs.queue_threads == threads_per_queue + assert fs.inter_threads == inter_op_threads + assert fs.intra_threads == intra_op_threads - module_str = db._rai_module + module_str = fs._rai_module assert "THREADS_PER_QUEUE" in module_str assert "INTRA_OP_PARALLELISM" in module_str assert "INTER_OP_PARALLELISM" in module_str def test_is_not_active() -> None: - db = Orchestrator(db_nodes=1) - assert not db.is_active() + fs = FeatureStore(fs_nodes=1) + assert not fs.is_active() -def test_inactive_orc_get_address() -> None: - db = Orchestrator() +def test_inactive_feature_store_get_address() -> None: + fs = FeatureStore() with pytest.raises(SmartSimError): - db.get_address() + fs.get_address() -def test_orc_is_active_functions( +def test_feature_store_is_active_functions( local_experiment, - prepare_db, - local_db, + prepare_fs, + local_fs, ) -> None: - db = prepare_db(local_db).orchestrator - db = local_experiment.reconnect_orchestrator(db.checkpoint_file) - assert db.is_active() + fs = prepare_fs(local_fs).featurestore + fs = 
local_experiment.reconnect_feature_store(fs.checkpoint_file) + assert fs.is_active() - # check if the orchestrator can get the address - assert db.get_address() == [f"127.0.0.1:{db.ports[0]}"] + # check if the feature store can get the address + assert fs.get_address() == [f"127.0.0.1:{fs.ports[0]}"] def test_multiple_interfaces( @@ -101,126 +101,135 @@ def test_multiple_interfaces( net_if_addrs = ["lo", net_if_addrs[0]] port = wlmutils.get_test_port() - db = Orchestrator(port=port, interface=net_if_addrs) - db.set_path(test_dir) + fs = FeatureStore(port=port, interface=net_if_addrs) + fs.set_path(test_dir) - exp.start(db) + exp.start(fs) - # check if the orchestrator is active - assert db.is_active() + # check if the FeatureStore is active + assert fs.is_active() - # check if the orchestrator can get the address + # check if the feature store can get the address correct_address = [f"127.0.0.1:{port}"] - if not correct_address == db.get_address(): - exp.stop(db) + if not correct_address == fs.get_address(): + exp.stop(fs) assert False - exp.stop(db) + exp.stop(fs) -def test_catch_local_db_errors() -> None: - # local database with more than one node not allowed +def test_catch_local_feature_store_errors() -> None: + # local feature store with more than one node not allowed with pytest.raises(SSUnsupportedError): - db = Orchestrator(db_nodes=2) + fs = FeatureStore(fs_nodes=2) - # Run command for local orchestrator not allowed + # Run command for local FeatureStore not allowed with pytest.raises(SmartSimError): - db = Orchestrator(run_command="srun") + fs = FeatureStore(run_command="srun") - # Batch mode for local orchestrator is not allowed + # Batch mode for local FeatureStore is not allowed with pytest.raises(SmartSimError): - db = Orchestrator(batch=True) + fs = FeatureStore(batch=True) ##### PBS ###### def test_pbs_set_run_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: - orc = Orchestrator( + feature_store = FeatureStore( wlmutils.get_test_port(), - 
db_nodes=3, + fs_nodes=3, batch=False, interface="lo", launcher="pbs", run_command="aprun", ) - orc.set_run_arg("account", "ACCOUNT") + feature_store.set_run_arg("account", "ACCOUNT") assert all( - [db.run_settings.run_args["account"] == "ACCOUNT" for db in orc.entities] + [ + fs.run_settings.run_args["account"] == "ACCOUNT" + for fs in feature_store.entities + ] ) - orc.set_run_arg("pes-per-numa-node", "5") + feature_store.set_run_arg("pes-per-numa-node", "5") assert all( - ["pes-per-numa-node" not in db.run_settings.run_args for db in orc.entities] + [ + "pes-per-numa-node" not in fs.run_settings.run_args + for fs in feature_store.entities + ] ) def test_pbs_set_batch_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: - orc = Orchestrator( + feature_store = FeatureStore( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=False, interface="lo", launcher="pbs", run_command="aprun", ) with pytest.raises(SmartSimError): - orc.set_batch_arg("account", "ACCOUNT") + feature_store.set_batch_arg("account", "ACCOUNT") - orc2 = Orchestrator( + feature_store2 = FeatureStore( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=True, interface="lo", launcher="pbs", run_command="aprun", ) - orc2.set_batch_arg("account", "ACCOUNT") - assert orc2.batch_settings.batch_args["account"] == "ACCOUNT" - orc2.set_batch_arg("N", "another_name") - assert "N" not in orc2.batch_settings.batch_args + feature_store2.set_batch_arg("account", "ACCOUNT") + assert feature_store2.batch_settings.batch_args["account"] == "ACCOUNT" + feature_store2.set_batch_arg("N", "another_name") + assert "N" not in feature_store2.batch_settings.batch_args ##### Slurm ###### def test_slurm_set_run_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: - orc = Orchestrator( + feature_store = FeatureStore( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=False, interface="lo", launcher="slurm", run_command="srun", ) - orc.set_run_arg("account", "ACCOUNT") + 
feature_store.set_run_arg("account", "ACCOUNT") assert all( - [db.run_settings.run_args["account"] == "ACCOUNT" for db in orc.entities] + [ + fs.run_settings.run_args["account"] == "ACCOUNT" + for fs in feature_store.entities + ] ) def test_slurm_set_batch_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: - orc = Orchestrator( + feature_store = FeatureStore( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=False, interface="lo", launcher="slurm", run_command="srun", ) with pytest.raises(SmartSimError): - orc.set_batch_arg("account", "ACCOUNT") + feature_store.set_batch_arg("account", "ACCOUNT") - orc2 = Orchestrator( + feature_store2 = FeatureStore( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=True, interface="lo", launcher="slurm", run_command="srun", ) - orc2.set_batch_arg("account", "ACCOUNT") - assert orc2.batch_settings.batch_args["account"] == "ACCOUNT" + feature_store2.set_batch_arg("account", "ACCOUNT") + assert feature_store2.batch_settings.batch_args["account"] == "ACCOUNT" @pytest.mark.parametrize( @@ -230,98 +239,100 @@ def test_slurm_set_batch_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: pytest.param(False, id="Multiple `srun`s"), ], ) -def test_orc_results_in_correct_number_of_shards(single_cmd: bool) -> None: +def test_feature_store_results_in_correct_number_of_shards(single_cmd: bool) -> None: num_shards = 5 - orc = Orchestrator( + feature_store = FeatureStore( port=12345, launcher="slurm", run_command="srun", - db_nodes=num_shards, + fs_nodes=num_shards, batch=False, single_cmd=single_cmd, ) if single_cmd: - assert len(orc.entities) == 1 - (node,) = orc.entities + assert len(feature_store.entities) == 1 + (node,) = feature_store.entities assert len(node.run_settings.mpmd) == num_shards - 1 else: - assert len(orc.entities) == num_shards - assert all(node.run_settings.mpmd == [] for node in orc.entities) + assert len(feature_store.entities) == num_shards + assert all(node.run_settings.mpmd == [] for node in 
feature_store.entities) assert ( - orc.num_shards == orc.db_nodes == sum(node.num_shards for node in orc.entities) + feature_store.num_shards + == feature_store.fs_nodes + == sum(node.num_shards for node in feature_store.entities) ) ###### LSF ###### -def test_catch_orc_errors_lsf(wlmutils: t.Type["conftest.WLMUtils"]) -> None: +def test_catch_feature_store_errors_lsf(wlmutils: t.Type["conftest.WLMUtils"]) -> None: with pytest.raises(SSUnsupportedError): - orc = Orchestrator( + feature_store = FeatureStore( wlmutils.get_test_port(), - db_nodes=2, - db_per_host=2, + fs_nodes=2, + fs_per_host=2, batch=False, launcher="lsf", run_command="jsrun", ) - orc = Orchestrator( + feature_store = FeatureStore( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=False, hosts=["batch", "host1", "host2"], launcher="lsf", run_command="jsrun", ) with pytest.raises(SmartSimError): - orc.set_batch_arg("P", "MYPROJECT") + feature_store.set_batch_arg("P", "MYPROJECT") def test_lsf_set_run_args(wlmutils: t.Type["conftest.WLMUtils"]) -> None: - orc = Orchestrator( + feature_store = FeatureStore( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=True, hosts=["batch", "host1", "host2"], launcher="lsf", run_command="jsrun", ) - orc.set_run_arg("l", "gpu-gpu") - assert all(["l" not in db.run_settings.run_args for db in orc.entities]) + feature_store.set_run_arg("l", "gpu-gpu") + assert all(["l" not in fs.run_settings.run_args for fs in feature_store.entities]) def test_lsf_set_batch_args(wlmutils: t.Type["conftest.WLMUtils"]) -> None: - orc = Orchestrator( + feature_store = FeatureStore( wlmutils.get_test_port(), - db_nodes=3, + fs_nodes=3, batch=True, hosts=["batch", "host1", "host2"], launcher="lsf", run_command="jsrun", ) - assert orc.batch_settings.batch_args["m"] == '"batch host1 host2"' - orc.set_batch_arg("D", "102400000") - assert orc.batch_settings.batch_args["D"] == "102400000" + assert feature_store.batch_settings.batch_args["m"] == '"batch host1 host2"' + 
feature_store.set_batch_arg("D", "102400000") + assert feature_store.batch_settings.batch_args["D"] == "102400000" def test_orc_telemetry(test_dir: str, wlmutils: t.Type["conftest.WLMUtils"]) -> None: - """Ensure the default behavior for an orchestrator is to disable telemetry""" - db = Orchestrator(port=wlmutils.get_test_port()) - db.set_path(test_dir) + """Ensure the default behavior for a feature store is to disable telemetry""" + fs = FeatureStore(port=wlmutils.get_test_port()) + fs.set_path(test_dir) # default is disabled - assert not db.telemetry.is_enabled + assert not fs.telemetry.is_enabled # ensure updating value works as expected - db.telemetry.enable() - assert db.telemetry.is_enabled + fs.telemetry.enable() + assert fs.telemetry.is_enabled # toggle back - db.telemetry.disable() - assert not db.telemetry.is_enabled + fs.telemetry.disable() + assert not fs.telemetry.is_enabled # toggle one more time - db.telemetry.enable() - assert db.telemetry.is_enabled + fs.telemetry.enable() + assert fs.telemetry.is_enabled diff --git a/tests/test_output_files.py b/tests/_legacy/test_output_files.py similarity index 63% rename from tests/test_output_files.py rename to tests/_legacy/test_output_files.py index f3830051c8..55ecfd90a5 100644 --- a/tests/test_output_files.py +++ b/tests/_legacy/test_output_files.py @@ -33,9 +33,9 @@ from smartsim._core.config import CONFIG from smartsim._core.control.controller import Controller, _AnonymousBatchJob from smartsim._core.launcher.step import Step -from smartsim.database.orchestrator import Orchestrator -from smartsim.entity.ensemble import Ensemble -from smartsim.entity.model import Model +from smartsim.builders.ensemble import Ensemble +from smartsim.database.orchestrator import FeatureStore +from smartsim.entity.application import Application from smartsim.settings.base import RunSettings from smartsim.settings.slurmSettings import SbatchSettings, SrunSettings @@ -50,47 +50,71 @@ batch_rs = SrunSettings("echo", ["spam", 
"eggs"]) ens = Ensemble("ens", params={}, run_settings=rs, batch_settings=bs, replicas=3) -orc = Orchestrator(db_nodes=3, batch=True, launcher="slurm", run_command="srun") -model = Model("test_model", params={}, path="", run_settings=rs) -batch_model = Model( - "batch_test_model", params={}, path="", run_settings=batch_rs, batch_settings=bs +feature_store = FeatureStore( + fs_nodes=3, batch=True, launcher="slurm", run_command="srun" ) -anon_batch_model = _AnonymousBatchJob(batch_model) +application = Application("test_application", params={}, path="", run_settings=rs) +batch_application = Application( + "batch_test_application", + params={}, + path="", + run_settings=batch_rs, + batch_settings=bs, +) +anon_batch_application = _AnonymousBatchJob(batch_application) -def test_mutated_model_output(test_dir): - exp_name = "test-mutated-model-output" +def test_mutated_application_output(test_dir): + exp_name = "test-mutated-application-output" exp = Experiment(exp_name, launcher="local", exp_path=test_dir) - test_model = exp.create_model("test_model", path=test_dir, run_settings=rs) - exp.generate(test_model) - exp.start(test_model, block=True) - - assert pathlib.Path(test_model.path).exists() - assert pathlib.Path(test_model.path, f"{test_model.name}.out").is_symlink() - assert pathlib.Path(test_model.path, f"{test_model.name}.err").is_symlink() - - with open(pathlib.Path(test_model.path, f"{test_model.name}.out"), "r") as file: + test_application = exp.create_application( + "test_application", path=test_dir, run_settings=rs + ) + exp.generate(test_application) + exp.start(test_application, block=True) + + assert pathlib.Path(test_application.path).exists() + assert pathlib.Path( + test_application.path, f"{test_application.name}.out" + ).is_symlink() + assert pathlib.Path( + test_application.path, f"{test_application.name}.err" + ).is_symlink() + + with open( + pathlib.Path(test_application.path, f"{test_application.name}.out"), "r" + ) as file: log_contents = 
file.read() assert "spam eggs" in log_contents - first_link = os.readlink(pathlib.Path(test_model.path, f"{test_model.name}.out")) - - test_model.run_settings.exe_args = ["hello", "world"] - exp.generate(test_model, overwrite=True) - exp.start(test_model, block=True) - - assert pathlib.Path(test_model.path).exists() - assert pathlib.Path(test_model.path, f"{test_model.name}.out").is_symlink() - assert pathlib.Path(test_model.path, f"{test_model.name}.err").is_symlink() - - with open(pathlib.Path(test_model.path, f"{test_model.name}.out"), "r") as file: + first_link = os.readlink( + pathlib.Path(test_application.path, f"{test_application.name}.out") + ) + + test_application.run_settings.exe_args = ["hello", "world"] + exp.generate(test_application, overwrite=True) + exp.start(test_application, block=True) + + assert pathlib.Path(test_application.path).exists() + assert pathlib.Path( + test_application.path, f"{test_application.name}.out" + ).is_symlink() + assert pathlib.Path( + test_application.path, f"{test_application.name}.err" + ).is_symlink() + + with open( + pathlib.Path(test_application.path, f"{test_application.name}.out"), "r" + ) as file: log_contents = file.read() assert "hello world" in log_contents - second_link = os.readlink(pathlib.Path(test_model.path, f"{test_model.name}.out")) + second_link = os.readlink( + pathlib.Path(test_application.path, f"{test_application.name}.out") + ) with open(first_link, "r") as file: first_historical_log = file.read() @@ -106,16 +130,16 @@ def test_mutated_model_output(test_dir): def test_get_output_files_with_create_job_step(test_dir): """Testing output files through _create_job_step""" exp_dir = pathlib.Path(test_dir) - status_dir = exp_dir / CONFIG.telemetry_subdir / model.type - step = controller._create_job_step(model, status_dir) - expected_out_path = status_dir / model.name / (model.name + ".out") - expected_err_path = status_dir / model.name / (model.name + ".err") + status_dir = exp_dir / 
CONFIG.telemetry_subdir / application.type + step = controller._create_job_step(application, status_dir) + expected_out_path = status_dir / application.name / (application.name + ".out") + expected_err_path = status_dir / application.name / (application.name + ".err") assert step.get_output_files() == (str(expected_out_path), str(expected_err_path)) @pytest.mark.parametrize( "entity", - [pytest.param(ens, id="ensemble"), pytest.param(orc, id="orchestrator")], + [pytest.param(ens, id="ensemble"), pytest.param(feature_store, id="featurestore")], ) def test_get_output_files_with_create_batch_job_step(entity, test_dir): """Testing output files through _create_batch_job_step""" @@ -137,20 +161,20 @@ def test_get_output_files_with_create_batch_job_step(entity, test_dir): ) -def test_model_get_output_files(test_dir): - """Testing model output files with manual step creation""" +def test_application_get_output_files(test_dir): + """Testing application output files with manual step creation""" exp_dir = pathlib.Path(test_dir) - step = Step(model.name, model.path, model.run_settings) + step = Step(application.name, application.path, application.run_settings) step.meta["status_dir"] = exp_dir / "output_dir" - expected_out_path = step.meta["status_dir"] / (model.name + ".out") - expected_err_path = step.meta["status_dir"] / (model.name + ".err") + expected_out_path = step.meta["status_dir"] / (application.name + ".out") + expected_err_path = step.meta["status_dir"] / (application.name + ".err") assert step.get_output_files() == (str(expected_out_path), str(expected_err_path)) def test_ensemble_get_output_files(test_dir): """Testing ensemble output files with manual step creation""" exp_dir = pathlib.Path(test_dir) - for member in ens.models: + for member in ens.applications: step = Step(member.name, member.path, member.run_settings) step.meta["status_dir"] = exp_dir / "output_dir" expected_out_path = step.meta["status_dir"] / (member.name + ".out") diff --git 
a/tests/test_pals_settings.py b/tests/_legacy/test_pals_settings.py similarity index 99% rename from tests/test_pals_settings.py rename to tests/_legacy/test_pals_settings.py index 8bc23d14d0..4fcf7cae34 100644 --- a/tests/test_pals_settings.py +++ b/tests/_legacy/test_pals_settings.py @@ -33,7 +33,7 @@ import smartsim._core.config.config from smartsim._core.launcher import PBSLauncher -from smartsim._core.launcher.step.mpiStep import MpiexecStep +from smartsim._core.launcher.step.mpi_step import MpiexecStep from smartsim.error import SSUnsupportedError from smartsim.settings import PalsMpiexecSettings diff --git a/tests/test_pbs_parser.py b/tests/_legacy/test_pbs_parser.py similarity index 88% rename from tests/test_pbs_parser.py rename to tests/_legacy/test_pbs_parser.py index ae01ffb19b..b5b7081751 100644 --- a/tests/test_pbs_parser.py +++ b/tests/_legacy/test_pbs_parser.py @@ -28,7 +28,7 @@ import pytest -from smartsim._core.launcher.pbs import pbsParser +from smartsim._core.launcher.pbs import pbs_parser # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b @@ -39,14 +39,14 @@ def test_parse_qsub(): output = "12345.sdb" - step_id = pbsParser.parse_qsub(output) + step_id = pbs_parser.parse_qsub(output) assert step_id == "12345.sdb" def test_parse_qsub_error(): output = "qsub: Unknown queue" error = "Unknown queue" - parsed_error = pbsParser.parse_qsub_error(output) + parsed_error = pbs_parser.parse_qsub_error(output) assert error == parsed_error @@ -58,7 +58,7 @@ def test_parse_qstat_nodes(fileutils): file_path = fileutils.get_test_conf_path("qstat.json") output = Path(file_path).read_text() nodes = ["server_1", "server_2"] - parsed_nodes = pbsParser.parse_qstat_nodes(output) + parsed_nodes = pbs_parser.parse_qstat_nodes(output) assert nodes == parsed_nodes @@ -70,7 +70,7 @@ def test_parse_qstat_status(): "1289903.sdb jobname username 00:00:00 R queue\n" ) status = "R" - parsed_status = pbsParser.parse_qstat_jobid(output, 
"1289903.sdb") + parsed_status = pbs_parser.parse_qstat_jobid(output, "1289903.sdb") assert status == parsed_status @@ -80,7 +80,7 @@ def test_parse_qstat_status_not_found(): "---------------- ---------------- ---------------- -------- - -----\n" "1289903.sdb jobname username 00:00:00 R queue\n" ) - parsed_status = pbsParser.parse_qstat_jobid(output, "9999999.sdb") + parsed_status = pbs_parser.parse_qstat_jobid(output, "9999999.sdb") assert parsed_status is None @@ -90,5 +90,5 @@ def test_parse_qstat_status_json(fileutils): file_path = fileutils.get_test_conf_path("qstat.json") output = Path(file_path).read_text() status = "R" - parsed_status = pbsParser.parse_qstat_jobid_json(output, "16705.sdb") + parsed_status = pbs_parser.parse_qstat_jobid_json(output, "16705.sdb") assert status == parsed_status diff --git a/tests/test_pbs_settings.py b/tests/_legacy/test_pbs_settings.py similarity index 100% rename from tests/test_pbs_settings.py rename to tests/_legacy/test_pbs_settings.py diff --git a/tests/test_preview.py b/tests/_legacy/test_preview.py similarity index 78% rename from tests/test_preview.py rename to tests/_legacy/test_preview.py index a18d107281..25a51671d0 100644 --- a/tests/test_preview.py +++ b/tests/_legacy/test_preview.py @@ -36,11 +36,11 @@ import smartsim import smartsim._core._cli.utils as _utils from smartsim import Experiment -from smartsim._core import Manifest, previewrenderer +from smartsim._core import Manifest, preview_renderer from smartsim._core.config import CONFIG from smartsim._core.control.controller import Controller from smartsim._core.control.job import Job -from smartsim.database import Orchestrator +from smartsim.database import FeatureStore from smartsim.entity.entity import SmartSimEntity from smartsim.error.errors import PreviewFormatError from smartsim.settings import QsubBatchSettings, RunSettings @@ -66,41 +66,41 @@ def preview_object(test_dir) -> t.Dict[str, Job]: """ rs = RunSettings(exe="echo", exe_args="ifname=lo") s = 
SmartSimEntity(name="faux-name", path=test_dir, run_settings=rs) - o = Orchestrator() + o = FeatureStore() o.entity = s - s.db_identifier = "test_db_id" + s.fs_identifier = "test_fs_id" s.ports = [1235] s.num_shards = 1 job = Job("faux-name", "faux-step-id", s, "slurm", True) - active_dbjobs: t.Dict[str, Job] = {"mock_job": job} - return active_dbjobs + active_fsjobs: t.Dict[str, Job] = {"mock_job": job} + return active_fsjobs @pytest.fixture -def preview_object_multidb(test_dir) -> t.Dict[str, Job]: +def preview_object_multifs(test_dir) -> t.Dict[str, Job]: """ - Bare bones orch + Bare bones feature store """ rs = RunSettings(exe="echo", exe_args="ifname=lo") s = SmartSimEntity(name="faux-name", path=test_dir, run_settings=rs) - o = Orchestrator() + o = FeatureStore() o.entity = s - s.db_identifier = "testdb_reg" + s.fs_identifier = "testfs_reg" s.ports = [8750] s.num_shards = 1 job = Job("faux-name", "faux-step-id", s, "slurm", True) rs2 = RunSettings(exe="echo", exe_args="ifname=lo") s2 = SmartSimEntity(name="faux-name_2", path=test_dir, run_settings=rs) - o2 = Orchestrator() + o2 = FeatureStore() o2.entity = s2 - s2.db_identifier = "testdb_reg2" + s2.fs_identifier = "testfs_reg2" s2.ports = [8752] s2.num_shards = 1 job2 = Job("faux-name_2", "faux-step-id_2", s2, "slurm", True) - active_dbjobs: t.Dict[str, Job] = {"mock_job": job, "mock_job2": job2} - return active_dbjobs + active_fsjobs: t.Dict[str, Job] = {"mock_job": job, "mock_job2": job2} + return active_fsjobs def add_batch_resources(wlmutils, batch_settings): @@ -130,7 +130,7 @@ def test_get_ifname_filter(): loader = jinja2.DictLoader(template_dict) env = jinja2.Environment(loader=loader, autoescape=True) - env.filters["get_ifname"] = previewrenderer.get_ifname + env.filters["get_ifname"] = preview_renderer.get_ifname t = env.get_template("ts") @@ -140,14 +140,14 @@ def test_get_ifname_filter(): assert output == expected_output -def test_get_dbtype_filter(): - """Test get_dbtype filter to extract database 
backend from config""" +def test_get_fstype_filter(): + """Test get_fstype filter to extract database backend from config""" - template_str = "{{ config | get_dbtype }}" + template_str = "{{ config | get_fstype }}" template_dict = {"ts": template_str} loader = jinja2.DictLoader(template_dict) env = jinja2.Environment(loader=loader, autoescape=True) - env.filters["get_dbtype"] = previewrenderer.get_dbtype + env.filters["get_fstype"] = preview_renderer.get_fstype t = env.get_template("ts") output = t.render(config=CONFIG.database_cli) @@ -183,7 +183,7 @@ def test_experiment_preview(test_dir, wlmutils): exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) # Execute method for template rendering - output = previewrenderer.render(exp, verbosity_level="debug") + output = preview_renderer.render(exp, verbosity_level="debug") # Evaluate output summary_lines = output.split("\n") @@ -203,7 +203,7 @@ def test_experiment_preview_properties(test_dir, wlmutils): exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) # Execute method for template rendering - output = previewrenderer.render(exp, verbosity_level="debug") + output = preview_renderer.render(exp, verbosity_level="debug") # Evaluate output summary_lines = output.split("\n") @@ -215,44 +215,44 @@ def test_experiment_preview_properties(test_dir, wlmutils): assert exp.launcher == summary_dict["Launcher"] -def test_orchestrator_preview_render(test_dir, wlmutils, choose_host): - """Test correct preview output properties for Orchestrator preview""" +def test_feature_store_preview_render(test_dir, wlmutils, choose_host): + """Test correct preview output properties for FeatureStore preview""" # Prepare entities test_launcher = wlmutils.get_test_launcher() test_interface = wlmutils.get_test_interface() test_port = wlmutils.get_test_port() - exp_name = "test_orchestrator_preview_properties" + exp_name = "test_feature_store_preview_properties" exp = Experiment(exp_name, exp_path=test_dir, 
launcher=test_launcher) # create regular database - orc = exp.create_database( + feature_store = exp.create_feature_store( port=test_port, interface=test_interface, hosts=choose_host(wlmutils), ) - preview_manifest = Manifest(orc) + preview_manifest = Manifest(feature_store) # Execute method for template rendering - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output - assert "Database Identifier" in output + assert "Feature Store Identifier" in output assert "Shards" in output assert "TCP/IP Port(s)" in output assert "Network Interface" in output assert "Type" in output assert "Executable" in output - db_path = _utils.get_db_path() - if db_path: - db_type, _ = db_path.name.split("-", 1) + fs_path = _utils.get_db_path() + if fs_path: + fs_type, _ = fs_path.name.split("-", 1) - assert orc.db_identifier in output - assert str(orc.num_shards) in output - assert orc._interfaces[0] in output - assert db_type in output + assert feature_store.fs_identifier in output + assert str(feature_store.num_shards) in output + assert feature_store._interfaces[0] in output + assert fs_type in output assert CONFIG.database_exe in output - assert orc.run_command in output - assert str(orc.db_nodes) in output + assert feature_store.run_command in output + assert str(feature_store.fs_nodes) in output def test_preview_to_file(test_dir, wlmutils): @@ -268,7 +268,7 @@ def test_preview_to_file(test_dir, wlmutils): path = pathlib.Path(test_dir) / filename # Execute preview method exp.preview( - output_format=previewrenderer.Format.PLAINTEXT, + output_format=preview_renderer.Format.PLAINTEXT, output_filename=str(path), verbosity_level="debug", ) @@ -290,16 +290,16 @@ def test_model_preview(test_dir, wlmutils): rs1 = RunSettings("bash", "multi_tags_template.sh") rs2 = exp.create_run_settings("echo", ["spam", "eggs"]) - hello_world_model = exp.create_model( + 
hello_world_model = exp.create_application( "echo-hello", run_settings=rs1, params=model_params ) - spam_eggs_model = exp.create_model("echo-spam", run_settings=rs2) + spam_eggs_model = exp.create_application("echo-spam", run_settings=rs2) preview_manifest = Manifest(hello_world_model, spam_eggs_model) # Execute preview method - rendered_preview = previewrenderer.render( + rendered_preview = preview_renderer.render( exp, preview_manifest, verbosity_level="debug" ) @@ -333,13 +333,15 @@ def test_model_preview_properties(test_dir, wlmutils): se_param3 = "eggs" rs2 = exp.create_run_settings(se_param1, [se_param2, se_param3]) - hello_world_model = exp.create_model(hw_name, run_settings=rs1, params=model_params) - spam_eggs_model = exp.create_model(se_name, run_settings=rs2) + hello_world_model = exp.create_application( + hw_name, run_settings=rs1, params=model_params + ) + spam_eggs_model = exp.create_application(se_name, run_settings=rs2) preview_manifest = Manifest(hello_world_model, spam_eggs_model) # Execute preview method - rendered_preview = previewrenderer.render( + rendered_preview = preview_renderer.render( exp, preview_manifest, verbosity_level="debug" ) @@ -385,7 +387,7 @@ def test_preview_model_tagged_files(fileutils, test_dir, wlmutils): model_params = {"port": 6379, "password": "unbreakable_password"} model_settings = RunSettings("bash", "multi_tags_template.sh") - hello_world_model = exp.create_model( + hello_world_model = exp.create_application( "echo-hello", run_settings=model_settings, params=model_params ) @@ -398,7 +400,7 @@ def test_preview_model_tagged_files(fileutils, test_dir, wlmutils): preview_manifest = Manifest(hello_world_model) # Execute preview method - rendered_preview = previewrenderer.render( + rendered_preview = preview_renderer.render( exp, preview_manifest, verbosity_level="debug" ) @@ -417,19 +419,19 @@ def test_model_key_prefixing(test_dir, wlmutils): test_launcher = wlmutils.get_test_launcher() exp = Experiment(exp_name, 
exp_path=test_dir, launcher=test_launcher) - db = exp.create_database(port=6780, interface="lo") - exp.generate(db, overwrite=True) + fs = exp.create_feature_store(port=6780, interface="lo") + exp.generate(fs, overwrite=True) rs1 = exp.create_run_settings("echo", ["hello", "world"]) - model = exp.create_model("model_test", run_settings=rs1) + model = exp.create_application("model_test", run_settings=rs1) # enable key prefixing on model model.enable_key_prefixing() exp.generate(model, overwrite=True) - preview_manifest = Manifest(db, model) + preview_manifest = Manifest(fs, model) # Execute preview method - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Key Prefix" in output @@ -467,7 +469,7 @@ def test_ensembles_preview(test_dir, wlmutils): ) preview_manifest = Manifest(ensemble) - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Ensemble Name" in output @@ -491,14 +493,14 @@ def test_preview_models_and_ensembles(test_dir, wlmutils): hw_name = "echo-hello" se_name = "echo-spam" ens_name = "echo-ensemble" - hello_world_model = exp.create_model(hw_name, run_settings=rs1) - spam_eggs_model = exp.create_model(se_name, run_settings=rs2) + hello_world_model = exp.create_application(hw_name, run_settings=rs1) + spam_eggs_model = exp.create_application(se_name, run_settings=rs2) hello_ensemble = exp.create_ensemble(ens_name, run_settings=rs1, replicas=3) exp.generate(hello_world_model, spam_eggs_model, hello_ensemble) preview_manifest = Manifest(hello_world_model, spam_eggs_model, hello_ensemble) - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Models" in 
output @@ -520,8 +522,8 @@ def test_ensemble_preview_client_configuration(test_dir, wlmutils): "test-preview-ensemble-clientconfig", exp_path=test_dir, launcher=test_launcher ) # Create Orchestrator - db = exp.create_database(port=6780, interface="lo") - exp.generate(db, overwrite=True) + fs = exp.create_feature_store(port=6780, interface="lo") + exp.generate(fs, overwrite=True) rs1 = exp.create_run_settings("echo", ["hello", "world"]) # Create ensemble ensemble = exp.create_ensemble("fd_simulation", run_settings=rs1, replicas=2) @@ -530,42 +532,42 @@ def test_ensemble_preview_client_configuration(test_dir, wlmutils): exp.generate(ensemble, overwrite=True) rs2 = exp.create_run_settings("echo", ["spam", "eggs"]) # Create model - ml_model = exp.create_model("tf_training", rs2) + ml_model = exp.create_application("tf_training", rs2) for sim in ensemble.entities: ml_model.register_incoming_entity(sim) exp.generate(ml_model, overwrite=True) - preview_manifest = Manifest(db, ml_model, ensemble) + preview_manifest = Manifest(fs, ml_model, ensemble) # Call preview renderer for testing output - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Client Configuration" in output - assert "Database Identifier" in output - assert "Database Backend" in output + assert "Feature Store Identifier" in output + assert "Feature Store Backend" in output assert "Type" in output -def test_ensemble_preview_client_configuration_multidb(test_dir, wlmutils): +def test_ensemble_preview_client_configuration_multifs(test_dir, wlmutils): """ Test preview of client configuration and key prefixing in Ensemble preview - with multiple databases + with multiple feature stores """ # Prepare entities test_launcher = wlmutils.get_test_launcher() exp = Experiment( - "test-preview-multidb-clinet-config", exp_path=test_dir, launcher=test_launcher + 
"test-preview-multifs-clinet-config", exp_path=test_dir, launcher=test_launcher ) - # Create Orchestrator - db1_dbid = "db_1" - db1 = exp.create_database(port=6780, interface="lo", db_identifier=db1_dbid) - exp.generate(db1, overwrite=True) - # Create another Orchestrator - db2_dbid = "db_2" - db2 = exp.create_database(port=6784, interface="lo", db_identifier=db2_dbid) - exp.generate(db2, overwrite=True) + # Create feature store + fs1_fsid = "fs_1" + fs1 = exp.create_feature_store(port=6780, interface="lo", fs_identifier=fs1_fsid) + exp.generate(fs1, overwrite=True) + # Create another feature store + fs2_fsid = "fs_2" + fs2 = exp.create_feature_store(port=6784, interface="lo", fs_identifier=fs2_fsid) + exp.generate(fs2, overwrite=True) rs1 = exp.create_run_settings("echo", ["hello", "world"]) # Create ensemble @@ -575,24 +577,24 @@ def test_ensemble_preview_client_configuration_multidb(test_dir, wlmutils): exp.generate(ensemble, overwrite=True) rs2 = exp.create_run_settings("echo", ["spam", "eggs"]) # Create model - ml_model = exp.create_model("tf_training", rs2) + ml_model = exp.create_application("tf_training", rs2) for sim in ensemble.entities: ml_model.register_incoming_entity(sim) exp.generate(ml_model, overwrite=True) - preview_manifest = Manifest(db1, db2, ml_model, ensemble) + preview_manifest = Manifest(fs1, fs2, ml_model, ensemble) # Call preview renderer for testing output - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Client Configuration" in output - assert "Database Identifier" in output - assert "Database Backend" in output + assert "Feature Store Identifier" in output + assert "Feature Store Backend" in output assert "TCP/IP Port(s)" in output assert "Type" in output - assert db1_dbid in output - assert db2_dbid in output + assert fs1_fsid in output + assert fs2_fsid in output def 
test_ensemble_preview_attached_files(fileutils, test_dir, wlmutils): @@ -628,7 +630,7 @@ def test_ensemble_preview_attached_files(fileutils, test_dir, wlmutils): preview_manifest = Manifest(ensemble) # Call preview renderer for testing output - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Tagged Files for Model Configuration" in output @@ -649,12 +651,12 @@ def test_ensemble_preview_attached_files(fileutils, test_dir, wlmutils): assert "generator_files/to_symlink_dir" in link -def test_preview_colocated_db_model_ensemble(fileutils, test_dir, wlmutils, mlutils): +def test_preview_colocated_fs_model_ensemble(fileutils, test_dir, wlmutils, mlutils): """ - Test preview of DBModel on colocated ensembles + Test preview of FSModel on colocated ensembles """ - exp_name = "test-preview-colocated-db-model-ensemble" + exp_name = "test-preview-colocated-fs-model-ensemble" test_launcher = wlmutils.get_test_launcher() test_interface = wlmutils.get_test_interface() test_port = wlmutils.get_test_port() @@ -674,7 +676,7 @@ def test_preview_colocated_db_model_ensemble(fileutils, test_dir, wlmutils, mlut ) # Create colocated SmartSim Model - colo_model = exp.create_model("colocated_model", colo_settings) + colo_model = exp.create_application("colocated_model", colo_settings) # Create and save ML model to filesystem content = "empty test" @@ -693,10 +695,10 @@ def test_preview_colocated_db_model_ensemble(fileutils, test_dir, wlmutils, mlut outputs="Identity", ) - # Colocate a database with the first ensemble members + # Colocate a feature store with the first ensemble members for i, entity in enumerate(colo_ensemble): - entity.colocate_db_tcp( - port=test_port + i, db_cpus=1, debug=True, ifname=test_interface + entity.colocate_fs_tcp( + port=test_port + i, fs_cpus=1, debug=True, ifname=test_interface ) # Add ML models to each ensemble member 
to make sure they # do not conflict with other ML models @@ -715,10 +717,10 @@ def test_preview_colocated_db_model_ensemble(fileutils, test_dir, wlmutils, mlut # Add another ensemble member colo_ensemble.add_model(colo_model) - # Colocate a database with the new ensemble member - colo_model.colocate_db_tcp( + # Colocate a feature store with the new ensemble member + colo_model.colocate_fs_tcp( port=test_port + len(colo_ensemble) - 1, - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) @@ -743,7 +745,7 @@ def test_preview_colocated_db_model_ensemble(fileutils, test_dir, wlmutils, mlut preview_manifest = Manifest(colo_ensemble) # Execute preview method - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Models" in output @@ -764,12 +766,12 @@ def test_preview_colocated_db_model_ensemble(fileutils, test_dir, wlmutils, mlut assert model_outputs in output -def test_preview_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlutils): +def test_preview_colocated_fs_script_ensemble(fileutils, test_dir, wlmutils, mlutils): """ - Test preview of DB Scripts on colocated DB from ensemble + Test preview of FS Scripts on colocated FS from ensemble """ - exp_name = "test-preview-colocated-db-script" + exp_name = "test-preview-colocated-fs-script" test_launcher = wlmutils.get_test_launcher() test_interface = wlmutils.get_test_interface() @@ -778,7 +780,7 @@ def test_preview_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlu test_num_gpus = mlutils.get_test_num_gpus() if pytest.test_device == "GPU" else 1 expected_torch_script = "torchscript.py" - test_script = fileutils.get_test_conf_path("run_dbscript_smartredis.py") + test_script = fileutils.get_test_conf_path("run_fsscript_smartredis.py") torch_script = fileutils.get_test_conf_path(expected_torch_script) # Create SmartSim Experiment @@ -794,15 +796,15 @@ def 
test_preview_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlu ) # Create a SmartSim model - colo_model = exp.create_model("colocated_model", colo_settings) + colo_model = exp.create_application("colocated_model", colo_settings) - # Colocate a db with each ensemble entity and add a script + # Colocate a fs with each ensemble entity and add a script # to each entity via file for i, entity in enumerate(colo_ensemble): entity.disable_key_prefixing() - entity.colocate_db_tcp( + entity.colocate_fs_tcp( port=test_port + i, - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) @@ -815,10 +817,10 @@ def test_preview_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlu first_device=0, ) - # Colocate a db with the non-ensemble Model - colo_model.colocate_db_tcp( + # Colocate a fs with the non-ensemble Model + colo_model.colocate_fs_tcp( port=test_port + len(colo_ensemble), - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) @@ -848,16 +850,16 @@ def test_preview_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlu ) # Assert we have added one model to the ensemble - assert len(colo_ensemble._db_scripts) == 1 + assert len(colo_ensemble._fs_scripts) == 1 # Assert we have added both models to each entity - assert all([len(entity._db_scripts) == 2 for entity in colo_ensemble]) + assert all([len(entity._fs_scripts) == 2 for entity in colo_ensemble]) exp.generate(colo_ensemble) preview_manifest = Manifest(colo_ensemble) # Execute preview method - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Torch Scripts" in output @@ -872,7 +874,7 @@ def test_preview_colocated_db_script_ensemble(fileutils, test_dir, wlmutils, mlu def test_preview_active_infrastructure(wlmutils, test_dir, preview_object): - """Test active infrastructure without other orchestrators""" + """Test active infrastructure 
without other feature stores""" # Prepare entities test_launcher = wlmutils.get_test_launcher() @@ -880,12 +882,12 @@ def test_preview_active_infrastructure(wlmutils, test_dir, preview_object): exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) # Execute method for template rendering - output = previewrenderer.render( - exp, active_dbjobs=preview_object, verbosity_level="debug" + output = preview_renderer.render( + exp, active_fsjobs=preview_object, verbosity_level="debug" ) assert "Active Infrastructure" in output - assert "Database Identifier" in output + assert "Feature Store Identifier" in output assert "Shards" in output assert "Network Interface" in output assert "Type" in output @@ -897,48 +899,48 @@ def test_preview_orch_active_infrastructure( ): """ Test correct preview output properties for active infrastructure preview - with other orchestrators + with other feature stores """ # Prepare entities test_launcher = wlmutils.get_test_launcher() test_interface = wlmutils.get_test_interface() test_port = wlmutils.get_test_port() - exp_name = "test_orchestrator_active_infrastructure_preview" + exp_name = "test_feature_store_active_infrastructure_preview" exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) - orc2 = exp.create_database( + feature_store2 = exp.create_feature_store( port=test_port, interface=test_interface, hosts=choose_host(wlmutils), - db_identifier="orc_2", + fs_identifier="fs_2", ) - orc3 = exp.create_database( + feature_store3 = exp.create_feature_store( port=test_port, interface=test_interface, hosts=choose_host(wlmutils), - db_identifier="orc_3", + fs_identifier="fs_3", ) - preview_manifest = Manifest(orc2, orc3) + preview_manifest = Manifest(feature_store2, feature_store3) # Execute method for template rendering - output = previewrenderer.render( - exp, preview_manifest, active_dbjobs=preview_object, verbosity_level="debug" + output = preview_renderer.render( + exp, preview_manifest, 
active_fsjobs=preview_object, verbosity_level="debug" ) assert "Active Infrastructure" in output - assert "Database Identifier" in output + assert "Feature Store Identifier" in output assert "Shards" in output assert "Network Interface" in output assert "Type" in output assert "TCP/IP" in output -def test_preview_multidb_active_infrastructure( +def test_preview_multifs_active_infrastructure( - wlmutils, test_dir, choose_host, preview_object_multidb + wlmutils, test_dir, choose_host, preview_object_multifs ): - """multiple started databases active infrastructure""" + """multiple started feature stores active infrastructure""" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -947,32 +949,32 @@ # start a new Experiment for this section exp = Experiment( - "test_preview_multidb_active_infrastructure", + "test_preview_multifs_active_infrastructure", exp_path=test_dir, launcher=test_launcher, ) # Execute method for template rendering - output = previewrenderer.render( - exp, active_dbjobs=preview_object_multidb, verbosity_level="debug" + output = preview_renderer.render( + exp, active_fsjobs=preview_object_multifs, verbosity_level="debug" ) assert "Active Infrastructure" in output - assert "Database Identifier" in output + assert "Feature Store Identifier" in output assert "Shards" in output assert "Network Interface" in output assert "Type" in output assert "TCP/IP" in output - assert "testdb_reg" in output - assert "testdb_reg2" in output - assert "Ochestrators" not in output + assert "testfs_reg" in output + assert "testfs_reg2" in output + assert "Feature Stores" not in output -def test_preview_active_infrastructure_orchestrator_error( +def test_preview_active_infrastructure_feature_store_error( wlmutils, test_dir, choose_host, monkeypatch: pytest.MonkeyPatch ): - """Demo error when trying to preview a started orchestrator""" + """Demo error when trying to preview a started feature store""" # Prepare entities test_launcher =
wlmutils.get_test_launcher() test_interface = wlmutils.get_test_interface() @@ -981,56 +983,56 @@ exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) monkeypatch.setattr( - smartsim.database.orchestrator.Orchestrator, "is_active", lambda x: True + smartsim.database.orchestrator.FeatureStore, "is_active", lambda x: True ) - orc = exp.create_database( + orc = exp.create_feature_store( port=test_port, interface=test_interface, hosts=choose_host(wlmutils), - db_identifier="orc_1", + fs_identifier="orc_1", ) # Retrieve any active jobs - active_dbjobs = exp._control.active_orchestrator_jobs + active_fsjobs = exp._control.active_feature_store_jobs preview_manifest = Manifest(orc) # Execute method for template rendering - output = previewrenderer.render( - exp, preview_manifest, active_dbjobs=active_dbjobs, verbosity_level="debug" + output = preview_renderer.render( + exp, preview_manifest, active_fsjobs=active_fsjobs, verbosity_level="debug" ) assert "WARNING: Cannot preview orc_1, because it is already started" in output -def test_active_orchestrator_jobs_property( +def test_active_feature_store_jobs_property( wlmutils, test_dir, preview_object, ): - """Ensure db_jobs remaines unchanged after deletion - of active_orchestrator_jobs property stays intact when retrieving db_jobs""" + """Ensure fs_jobs remains unchanged after deletion + of active_feature_store_jobs property stays intact when retrieving fs_jobs""" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() # start a new Experiment for this section exp = Experiment( - "test-active_orchestrator_jobs-property", + "test-active_feature_store_jobs-property", exp_path=test_dir, launcher=test_launcher, ) controller = Controller() - controller._jobs.db_jobs = preview_object + controller._jobs.fs_jobs = preview_object # Modify the returned job collection - active_orchestrator_jobs =
exp._control.active_orchestrator_jobs - active_orchestrator_jobs["test"] = "test_value" + active_feature_store_jobs = exp._control.active_feature_store_jobs + active_feature_store_jobs["test"] = "test_value" # Verify original collection is not also modified - assert not exp._control.active_orchestrator_jobs.get("test", None) + assert not exp._control.active_feature_store_jobs.get("test", None) def test_verbosity_info_ensemble(test_dir, wlmutils): @@ -1050,14 +1052,14 @@ def test_verbosity_info_ensemble(test_dir, wlmutils): hw_name = "echo-hello" se_name = "echo-spam" ens_name = "echo-ensemble" - hello_world_model = exp.create_model(hw_name, run_settings=rs1) - spam_eggs_model = exp.create_model(se_name, run_settings=rs2) + hello_world_model = exp.create_application(hw_name, run_settings=rs1) + spam_eggs_model = exp.create_application(se_name, run_settings=rs2) hello_ensemble = exp.create_ensemble(ens_name, run_settings=rs1, replicas=3) exp.generate(hello_world_model, spam_eggs_model, hello_ensemble) preview_manifest = Manifest(hello_world_model, spam_eggs_model, hello_ensemble) - output = previewrenderer.render(exp, preview_manifest, verbosity_level="info") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="info") assert "Executable" not in output assert "Executable Arguments" not in output @@ -1065,14 +1067,14 @@ def test_verbosity_info_ensemble(test_dir, wlmutils): assert "echo_ensemble_1" not in output -def test_verbosity_info_colocated_db_model_ensemble( +def test_verbosity_info_colocated_fs_model_ensemble( fileutils, test_dir, wlmutils, mlutils ): - """Test preview of DBModel on colocated ensembles, first adding the DBModel to the - ensemble, then colocating DB. + """Test preview of FSModel on colocated ensembles, first adding the FSModel to the + ensemble, then colocating FS. 
""" - exp_name = "test-colocated-db-model-ensemble-reordered" + exp_name = "test-colocated-fs-model-ensemble-reordered" test_launcher = wlmutils.get_test_launcher() test_interface = wlmutils.get_test_interface() test_port = wlmutils.get_test_port() @@ -1092,7 +1094,7 @@ def test_verbosity_info_colocated_db_model_ensemble( ) # Create colocated SmartSim Model - colo_model = exp.create_model("colocated_model", colo_settings) + colo_model = exp.create_application("colocated_model", colo_settings) # Create and save ML model to filesystem content = "empty test" @@ -1111,10 +1113,10 @@ def test_verbosity_info_colocated_db_model_ensemble( outputs="Identity", ) - # Colocate a database with the first ensemble members + # Colocate a feature store with the first ensemble members for i, entity in enumerate(colo_ensemble): - entity.colocate_db_tcp( - port=test_port + i, db_cpus=1, debug=True, ifname=test_interface + entity.colocate_fs_tcp( + port=test_port + i, fs_cpus=1, debug=True, ifname=test_interface ) # Add ML models to each ensemble member to make sure they # do not conflict with other ML models @@ -1133,10 +1135,10 @@ def test_verbosity_info_colocated_db_model_ensemble( # Add another ensemble member colo_ensemble.add_model(colo_model) - # Colocate a database with the new ensemble member - colo_model.colocate_db_tcp( + # Colocate a feature store with the new ensemble member + colo_model.colocate_fs_tcp( port=test_port + len(colo_ensemble) - 1, - db_cpus=1, + fs_cpus=1, debug=True, ifname=test_interface, ) @@ -1161,30 +1163,30 @@ def test_verbosity_info_colocated_db_model_ensemble( preview_manifest = Manifest(colo_ensemble) # Execute preview method - output = previewrenderer.render(exp, preview_manifest, verbosity_level="info") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="info") assert "Outgoing Key Collision Prevention (Key Prefixing)" not in output assert "Devices Per Node" not in output -def test_verbosity_info_orchestrator(test_dir, 
wlmutils, choose_host): - """Test correct preview output properties for Orchestrator preview""" +def test_verbosity_info_feature_store(test_dir, wlmutils, choose_host): + """Test correct preview output properties for feature store preview""" # Prepare entities test_launcher = wlmutils.get_test_launcher() test_interface = wlmutils.get_test_interface() test_port = wlmutils.get_test_port() - exp_name = "test_orchestrator_preview_properties" + exp_name = "test_feature_store_preview_properties" exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) - # create regular database - orc = exp.create_database( + # create regular feature store + feature_store = exp.create_feature_store( port=test_port, interface=test_interface, hosts=choose_host(wlmutils), ) - preview_manifest = Manifest(orc) + preview_manifest = Manifest(feature_store) # Execute method for template rendering - output = previewrenderer.render(exp, preview_manifest, verbosity_level="info") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="info") # Evaluate output assert "Executable" not in output @@ -1198,9 +1200,9 @@ def test_verbosity_info_ensemble(test_dir, wlmutils): # Prepare entities test_launcher = wlmutils.get_test_launcher() exp = Experiment("key_prefix_test", exp_path=test_dir, launcher=test_launcher) - # Create Orchestrator - db = exp.create_database(port=6780, interface="lo") - exp.generate(db, overwrite=True) + # Create feature store + fs = exp.create_feature_store(port=6780, interface="lo") + exp.generate(fs, overwrite=True) rs1 = exp.create_run_settings("echo", ["hello", "world"]) # Create ensemble ensemble = exp.create_ensemble("fd_simulation", run_settings=rs1, replicas=2) @@ -1209,16 +1211,16 @@ def test_verbosity_info_ensemble(test_dir, wlmutils): exp.generate(ensemble, overwrite=True) rs2 = exp.create_run_settings("echo", ["spam", "eggs"]) # Create model - ml_model = exp.create_model("tf_training", rs2) + ml_model = 
exp.create_application("tf_training", rs2) for sim in ensemble.entities: ml_model.register_incoming_entity(sim) exp.generate(ml_model, overwrite=True) - preview_manifest = Manifest(db, ml_model, ensemble) + preview_manifest = Manifest(fs, ml_model, ensemble) # Call preview renderer for testing output - output = previewrenderer.render(exp, preview_manifest, verbosity_level="info") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="info") # Evaluate output assert "Outgoing Key Collision Prevention (Key Prefixing)" in output @@ -1266,8 +1268,8 @@ def test_check_verbosity_level(): exp.preview(verbosity_level="info") -def test_preview_colocated_db_singular_model(wlmutils, test_dir): - """Test preview behavior when a colocated db is only added to +def test_preview_colocated_fs_singular_model(wlmutils, test_dir): + """Test preview behavior when a colocated fs is only added to one model. The expected behviour is that both models are colocated """ @@ -1277,24 +1279,24 @@ def test_preview_colocated_db_singular_model(wlmutils, test_dir): rs = exp.create_run_settings("sleep", ["100"]) - model_1 = exp.create_model("model_1", run_settings=rs) - model_2 = exp.create_model("model_2", run_settings=rs) + model_1 = exp.create_application("model_1", run_settings=rs) + model_2 = exp.create_application("model_2", run_settings=rs) - model_1.colocate_db() + model_1.colocate_fs() exp.generate(model_1, model_2, overwrite=True) preview_manifest = Manifest(model_1, model_2) # Call preview renderer for testing output - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") assert "model_1" in output assert "model_2" in output assert "Client Configuration" in output -def test_preview_db_script(wlmutils, test_dir): +def test_preview_fs_script(wlmutils, test_dir): """ Test preview of model instance with a torch script. 
""" @@ -1307,8 +1309,8 @@ def test_preview_db_script(wlmutils, test_dir): model_settings = exp.create_run_settings(exe="python", exe_args="params.py") # Initialize a Model object - model_instance = exp.create_model("model_name", model_settings) - model_instance.colocate_db_tcp() + model_instance = exp.create_application("model_name", model_settings) + model_instance.colocate_fs_tcp() # TorchScript string torch_script_str = "def negate(x):\n\treturn torch.neg(x)\n" @@ -1324,7 +1326,7 @@ def test_preview_db_script(wlmutils, test_dir): preview_manifest = Manifest(model_instance) # Call preview renderer for testing output - output = previewrenderer.render(exp, preview_manifest, verbosity_level="debug") + output = preview_renderer.render(exp, preview_manifest, verbosity_level="debug") # Evaluate output assert "Torch Script" in output diff --git a/tests/test_reconnect_orchestrator.py b/tests/_legacy/test_reconnect_orchestrator.py similarity index 68% rename from tests/test_reconnect_orchestrator.py rename to tests/_legacy/test_reconnect_orchestrator.py index 6ce93c6f93..715c977ec1 100644 --- a/tests/test_reconnect_orchestrator.py +++ b/tests/_legacy/test_reconnect_orchestrator.py @@ -30,8 +30,8 @@ import pytest from smartsim import Experiment -from smartsim.database import Orchestrator -from smartsim.status import SmartSimStatus +from smartsim.database import FeatureStore +from smartsim.status import JobStatus # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b @@ -39,45 +39,46 @@ first_dir = "" -# TODO ensure database is shutdown +# TODO ensure feature store is shutdown # use https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test -def test_local_orchestrator(test_dir, wlmutils): - """Test launching orchestrator locally""" +def test_local_feature_store(test_dir, wlmutils): + """Test launching feature store locally""" global first_dir - exp_name = "test-orc-launch-local" + exp_name = 
"test-feature-store-launch-local" exp = Experiment(exp_name, launcher="local", exp_path=test_dir) first_dir = test_dir - orc = Orchestrator(port=wlmutils.get_test_port()) - orc.set_path(osp.join(test_dir, "orchestrator")) + feature_store = FeatureStore(port=wlmutils.get_test_port()) + feature_store.set_path(osp.join(test_dir, "feature_store")) - exp.start(orc) - statuses = exp.get_status(orc) - assert [stat != SmartSimStatus.STATUS_FAILED for stat in statuses] + exp.start(feature_store) + statuses = exp.get_status(feature_store) + assert [stat != JobStatus.FAILED for stat in statuses] # simulate user shutting down main thread exp._control._jobs.actively_monitoring = False exp._control._launcher.task_manager.actively_monitoring = False -def test_reconnect_local_orc(test_dir): - """Test reconnecting to orchestrator from first experiment""" +def test_reconnect_local_feature_store(test_dir): + """Test reconnecting to feature store from first experiment""" global first_dir # start new experiment - exp_name = "test-orc-local-reconnect-2nd" + exp_name = "test-feature-store-local-reconnect-2nd" exp_2 = Experiment(exp_name, launcher="local", exp_path=test_dir) - checkpoint = osp.join(first_dir, "orchestrator", "smartsim_db.dat") - reloaded_orc = exp_2.reconnect_orchestrator(checkpoint) + checkpoint = osp.join(first_dir, "feature_store", "smartsim_db.dat") + + reloaded_feature_store = exp_2.reconnect_feature_store(checkpoint) # let statuses update once time.sleep(5) - statuses = exp_2.get_status(reloaded_orc) + statuses = exp_2.get_status(reloaded_feature_store) for stat in statuses: - if stat == SmartSimStatus.STATUS_FAILED: - exp_2.stop(reloaded_orc) + if stat == JobStatus.FAILED: + exp_2.stop(reloaded_feature_store) assert False - exp_2.stop(reloaded_orc) + exp_2.stop(reloaded_feature_store) diff --git a/tests/test_run_settings.py b/tests/_legacy/test_run_settings.py similarity index 90% rename from tests/test_run_settings.py rename to tests/_legacy/test_run_settings.py 
index 056dad64b7..8209334dcf 100644 --- a/tests/test_run_settings.py +++ b/tests/_legacy/test_run_settings.py @@ -31,6 +31,7 @@ import pytest +from smartsim import Experiment from smartsim.error.errors import SSUnsupportedError from smartsim.settings import ( MpiexecSettings, @@ -41,6 +42,7 @@ Singularity, ) from smartsim.settings.settings import create_run_settings +from smartsim.status import JobStatus # The tests in this file belong to the slow_tests group pytestmark = pytest.mark.slow_tests @@ -567,3 +569,55 @@ def test_update_env_null_valued(env_vars): with pytest.raises(TypeError) as ex: rs = RunSettings(sample_exe, run_command=cmd, env_vars=orig_env) rs.update_env(env_vars) + + +def test_create_run_settings_run_args_leading_dashes(test_dir, wlmutils): + """ + Test warning for leading `-` in run_args in `exp.create_run_settings` + """ + exp_name = "test-create-run_settings-run_args-leading-dashes" + exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) + + run_args = {"--nodes": 1} + settings = exp.create_run_settings( + "echo", exe_args=["hello", "world"], run_command="srun", run_args=run_args + ) + model = exp.create_model("sr_issue_model", run_settings=settings) + exp.start(model) + + statuses = exp.get_status(model) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) + + +def test_set_run_args_leading_dashes(test_dir, wlmutils): + """ + Test warning for leading `-` for run_args in `settings.set` + """ + exp_name = "test-set-run-args-leading-dashes" + exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) + settings = exp.create_run_settings( + "echo", exe_args=["hello", "world"], run_command="srun" + ) + settings.set("--nodes", "1") + + model = exp.create_model("sr_issue_model", run_settings=settings) + exp.start(model) + statuses = exp.get_status(model) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) + + +def test_run_args_integer(test_dir, wlmutils): + """ + Test 
that `setting.set` will take an integer as a run argument + """ + exp_name = "test-run-args-integer" + exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) + settings = exp.create_run_settings( + "echo", exe_args=["hello", "world"], run_command="srun" + ) + settings.set("--nodes", 1) + + model = exp.create_model("sr_issue_model", run_settings=settings) + exp.start(model) + statuses = exp.get_status(model) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) diff --git a/tests/test_schema_utils.py b/tests/_legacy/test_schema_utils.py similarity index 100% rename from tests/test_schema_utils.py rename to tests/_legacy/test_schema_utils.py diff --git a/tests/test_serialize.py b/tests/_legacy/test_serialize.py similarity index 87% rename from tests/test_serialize.py rename to tests/_legacy/test_serialize.py index b2dc0b7a70..eb56d75540 100644 --- a/tests/test_serialize.py +++ b/tests/_legacy/test_serialize.py @@ -36,7 +36,7 @@ from smartsim._core._cli import utils from smartsim._core.control.manifest import LaunchedManifestBuilder from smartsim._core.utils import serialize -from smartsim.database.orchestrator import Orchestrator +from smartsim.database.orchestrator import FeatureStore _CFG_TM_ENABLED_ATTR = "telemetry_enabled" @@ -123,31 +123,33 @@ def test_started_entities_are_serialized(test_dir, manifest_json): rs1 = exp.create_run_settings("echo", ["hello", "world"]) rs2 = exp.create_run_settings("echo", ["spam", "eggs"]) - hello_world_model = exp.create_model("echo-hello", run_settings=rs1) - spam_eggs_model = exp.create_model("echo-spam", run_settings=rs2) + hello_world_application = exp.create_application("echo-hello", run_settings=rs1) + spam_eggs_application = exp.create_application("echo-spam", run_settings=rs2) hello_ensemble = exp.create_ensemble("echo-ensemble", run_settings=rs1, replicas=3) - exp.generate(hello_world_model, spam_eggs_model, hello_ensemble) - exp.start(hello_world_model, spam_eggs_model, 
block=False) + exp.generate(hello_world_application, spam_eggs_application, hello_ensemble) + exp.start(hello_world_application, spam_eggs_application, block=False) exp.start(hello_ensemble, block=False) try: with open(manifest_json, "r") as f: manifest = json.load(f) assert len(manifest["runs"]) == 2 - assert len(manifest["runs"][0]["model"]) == 2 + assert len(manifest["runs"][0]["application"]) == 2 assert len(manifest["runs"][0]["ensemble"]) == 0 - assert len(manifest["runs"][1]["model"]) == 0 + assert len(manifest["runs"][1]["application"]) == 0 assert len(manifest["runs"][1]["ensemble"]) == 1 - assert len(manifest["runs"][1]["ensemble"][0]["models"]) == 3 + assert len(manifest["runs"][1]["ensemble"][0]["applications"]) == 3 finally: - exp.stop(hello_world_model, spam_eggs_model, hello_ensemble) + exp.stop(hello_world_application, spam_eggs_application, hello_ensemble) -def test_serialzed_database_does_not_break_if_using_a_non_standard_install(monkeypatch): - monkeypatch.setattr(utils, "get_db_path", lambda: None) - db = Orchestrator() - dict_ = serialize._dictify_db(db, []) +def test_serialzed_feature_store_does_not_break_if_using_a_non_standard_install( + monkeypatch, +): + monkeypatch.setattr(utils, "get_fs_path", lambda: None) + fs = FeatureStore() + dict_ = serialize._dictify_fs(fs, []) assert dict_["type"] == "Unknown" diff --git a/tests/test_sge_batch_settings.py b/tests/_legacy/test_sge_batch_settings.py similarity index 98% rename from tests/test_sge_batch_settings.py rename to tests/_legacy/test_sge_batch_settings.py index fa40b4b00e..f81bee1eab 100644 --- a/tests/test_sge_batch_settings.py +++ b/tests/_legacy/test_sge_batch_settings.py @@ -29,7 +29,7 @@ import pytest from smartsim import Experiment -from smartsim._core.launcher.sge.sgeParser import parse_qstat_jobid_xml +from smartsim._core.launcher.sge.sge_parser import parse_qstat_jobid_xml from smartsim.error import SSConfigError from smartsim.settings import SgeQsubBatchSettings from 
smartsim.settings.mpiSettings import _BaseMPISettings diff --git a/tests/test_shell_util.py b/tests/_legacy/test_shell_util.py similarity index 100% rename from tests/test_shell_util.py rename to tests/_legacy/test_shell_util.py diff --git a/tests/test_slurm_get_alloc.py b/tests/_legacy/test_slurm_get_alloc.py similarity index 100% rename from tests/test_slurm_get_alloc.py rename to tests/_legacy/test_slurm_get_alloc.py diff --git a/tests/test_slurm_parser.py b/tests/_legacy/test_slurm_parser.py similarity index 84% rename from tests/test_slurm_parser.py rename to tests/_legacy/test_slurm_parser.py index b5f7cf32ae..e73ec7ed7e 100644 --- a/tests/test_slurm_parser.py +++ b/tests/_legacy/test_slurm_parser.py @@ -26,7 +26,7 @@ import pytest -from smartsim._core.launcher.slurm import slurmParser +from smartsim._core.launcher.slurm import slurm_parser # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b @@ -41,7 +41,7 @@ def test_parse_salloc(): "salloc: Waiting for resource configuration\n" "salloc: Nodes nid00116 are ready for job" ) - alloc_id = slurmParser.parse_salloc(output) + alloc_id = slurm_parser.parse_salloc(output) assert alloc_id == "118568" @@ -54,7 +54,7 @@ def test_parse_salloc_extra(): "salloc: Waiting for resource configuration\n" "salloc: Nodes prod76-0006 are ready for job\n" ) - alloc_id = slurmParser.parse_salloc(output) + alloc_id = slurm_parser.parse_salloc(output) assert alloc_id == "22942" @@ -64,14 +64,14 @@ def test_parse_salloc_high(): "salloc: Waiting for resource configuration\n" "salloc: Nodes nid00034 are ready for job\n" ) - alloc_id = slurmParser.parse_salloc(output) + alloc_id = slurm_parser.parse_salloc(output) assert alloc_id == "29917893" def test_parse_salloc_error(): output = "salloc: error: Job submit/allocate failed: Job dependency problem" error = "Job submit/allocate failed: Job dependency problem" - parsed_error = slurmParser.parse_salloc_error(output) + parsed_error = 
slurm_parser.parse_salloc_error(output) assert error == parsed_error @@ -81,7 +81,7 @@ def test_parse_salloc_error_2(): "Try 'salloc --help' for more information\n" ) error = "unrecognized option '--no-a-option'" - parsed_error = slurmParser.parse_salloc_error(output) + parsed_error = slurm_parser.parse_salloc_error(output) assert error == parsed_error @@ -93,7 +93,7 @@ def test_parse_salloc_error_3(): "\nsalloc: error: Job submit/allocate failed: Invalid node name specified\n" ) error = "Job submit/allocate failed: Invalid node name specified" - parsed_error = slurmParser.parse_salloc_error(output) + parsed_error = slurm_parser.parse_salloc_error(output) assert error == parsed_error @@ -103,7 +103,7 @@ def test_parse_salloc_error_4(): "salloc: error: Job submit/allocate failed: Unspecified error\n" ) error = "No hardware architecture specified (-C)!" - parsed_error = slurmParser.parse_salloc_error(output) + parsed_error = slurm_parser.parse_salloc_error(output) assert error == parsed_error @@ -116,7 +116,7 @@ def test_parse_sstat_nodes(): """ output = "118594.extern|nid00028|38671|\n" "118594.0|nid00028|38703|" nodes = ["nid00028"] - parsed_nodes = slurmParser.parse_sstat_nodes(output, "118594") + parsed_nodes = slurm_parser.parse_sstat_nodes(output, "118594") assert nodes == parsed_nodes @@ -126,7 +126,7 @@ def test_parse_sstat_nodes_1(): """ output = "22942.0|prod76-0006|354345|" nodes = ["prod76-0006"] - parsed_nodes = slurmParser.parse_sstat_nodes(output, "22942.0") + parsed_nodes = slurm_parser.parse_sstat_nodes(output, "22942.0") assert nodes == parsed_nodes @@ -136,7 +136,7 @@ def test_parse_sstat_nodes_2(): """ output = "29917893.extern|nid00034|44860|\n" "29917893.0|nid00034|44887|\n" nodes = ["nid00034"] - parsed_nodes = slurmParser.parse_sstat_nodes(output, "29917893.0") + parsed_nodes = slurm_parser.parse_sstat_nodes(output, "29917893.0") assert nodes == parsed_nodes @@ -152,7 +152,7 @@ def test_parse_sstat_nodes_3(): "29917893.2|nid00034|45174|\n" ) 
nodes = ["nid00034"] - parsed_nodes = slurmParser.parse_sstat_nodes(output, "29917893.2") + parsed_nodes = slurm_parser.parse_sstat_nodes(output, "29917893.2") assert nodes == parsed_nodes @@ -171,7 +171,7 @@ def test_parse_sstat_nodes_4(): "30000.2|nid00036|45174,32435|\n" ) nodes = set(["nid00034", "nid00035", "nid00036"]) - parsed_nodes = set(slurmParser.parse_sstat_nodes(output, "30000")) + parsed_nodes = set(slurm_parser.parse_sstat_nodes(output, "30000")) assert nodes == parsed_nodes @@ -190,7 +190,7 @@ def test_parse_sstat_nodes_4(): "30000.2|nid00036|45174,32435|\n" ) nodes = set(["nid00034", "nid00035", "nid00036"]) - parsed_nodes = set(slurmParser.parse_sstat_nodes(output, "30000")) + parsed_nodes = set(slurm_parser.parse_sstat_nodes(output, "30000")) assert nodes == parsed_nodes @@ -206,7 +206,7 @@ def test_parse_sstat_nodes_5(): "29917893.2|nid00034|45174|\n" ) nodes = ["nid00034"] - parsed_nodes = slurmParser.parse_sstat_nodes(output, "29917893.2") + parsed_nodes = slurm_parser.parse_sstat_nodes(output, "29917893.2") assert nodes == parsed_nodes @@ -221,7 +221,7 @@ def test_parse_sacct_step_id(): "m2-119225.1|119225.1|" ) step_id = "119225.0" - parsed_step_id = slurmParser.parse_step_id_from_sacct(output, "m1-119225.0") + parsed_step_id = slurm_parser.parse_step_id_from_sacct(output, "m1-119225.0") assert step_id == parsed_step_id @@ -231,12 +231,12 @@ def test_parse_sacct_step_id_2(): "extern|119225.extern|\n" "m1-119225.0|119225.0|\n" "m2-119225.1|119225.1|\n" - "orchestrator_0-119225.2|119225.2|\n" + "featurestore_0-119225.2|119225.2|\n" "n1-119225.3|119225.3|" ) step_id = "119225.2" - parsed_step_id = slurmParser.parse_step_id_from_sacct( - output, "orchestrator_0-119225.2" + parsed_step_id = slurm_parser.parse_step_id_from_sacct( + output, "featurestore_0-119225.2" ) assert step_id == parsed_step_id @@ -251,7 +251,7 @@ def test_parse_sacct_step_id_2(): "cti_dlaunch1.0|962333.3|" ) step_id = "962333.1" - parsed_step_id = 
slurmParser.parse_step_id_from_sacct(output, "python-962333.1") + parsed_step_id = slurm_parser.parse_step_id_from_sacct(output, "python-962333.1") assert step_id == parsed_step_id @@ -261,7 +261,7 @@ def test_parse_sacct_status(): """ output = "29917893.2|COMPLETED|0:0|\n" status = ("COMPLETED", "0") - parsed_status = slurmParser.parse_sacct(output, "29917893.2") + parsed_status = slurm_parser.parse_sacct(output, "29917893.2") assert status == parsed_status @@ -271,7 +271,7 @@ def test_parse_sacct_status_1(): """ output = "22999.0|FAILED|1:0|\n" status = ("FAILED", "1") - parsed_status = slurmParser.parse_sacct(output, "22999.0") + parsed_status = slurm_parser.parse_sacct(output, "22999.0") assert status == parsed_status @@ -281,5 +281,5 @@ def test_parse_sacct_status_2(): """ output = "22999.10|COMPLETED|0:0|\n22999.1|FAILED|1:0|\n" status = ("FAILED", "1") - parsed_status = slurmParser.parse_sacct(output, "22999.1") + parsed_status = slurm_parser.parse_sacct(output, "22999.1") assert status == parsed_status diff --git a/tests/test_slurm_settings.py b/tests/_legacy/test_slurm_settings.py similarity index 97% rename from tests/test_slurm_settings.py rename to tests/_legacy/test_slurm_settings.py index d9d820244e..9fd0f5e82b 100644 --- a/tests/test_slurm_settings.py +++ b/tests/_legacy/test_slurm_settings.py @@ -79,7 +79,7 @@ def test_update_env(): def test_catch_colo_mpmd(): srun = SrunSettings("python") - srun.colocated_db_settings = {"port": 6379, "cpus": 1} + srun.colocated_fs_settings = {"port": 6379, "cpus": 1} srun_2 = SrunSettings("python") # should catch the user trying to make rs mpmd that already are colocated @@ -100,7 +100,7 @@ def test_mpmd_compound_env_exports(): srun_2.env_vars = {"cmp2": "222,333", "norm2": "pqr"} srun.make_mpmd(srun_2) - from smartsim._core.launcher.step.slurmStep import SbatchStep, SrunStep + from smartsim._core.launcher.step.slurm_step import SbatchStep, SrunStep from smartsim.settings.slurmSettings import SbatchSettings step = 
SrunStep("teststep", "./", srun) @@ -160,7 +160,7 @@ def test_mpmd_non_compound_env_exports(): srun_2.env_vars = {"cmp2": "222", "norm2": "pqr"} srun.make_mpmd(srun_2) - from smartsim._core.launcher.step.slurmStep import SbatchStep, SrunStep + from smartsim._core.launcher.step.slurm_step import SbatchStep, SrunStep from smartsim.settings.slurmSettings import SbatchSettings step = SrunStep("teststep", "./", srun) @@ -220,7 +220,7 @@ def test_mpmd_non_compound_no_exports(): srun_2.env_vars = {} srun.make_mpmd(srun_2) - from smartsim._core.launcher.step.slurmStep import SbatchStep, SrunStep + from smartsim._core.launcher.step.slurm_step import SbatchStep, SrunStep from smartsim.settings.slurmSettings import SbatchSettings step = SrunStep("teststep", "./", srun) diff --git a/tests/test_slurm_validation.py b/tests/_legacy/test_slurm_validation.py similarity index 100% rename from tests/test_slurm_validation.py rename to tests/_legacy/test_slurm_validation.py diff --git a/tests/test_smartredis.py b/tests/_legacy/test_smartredis.py similarity index 76% rename from tests/test_smartredis.py rename to tests/_legacy/test_smartredis.py index 6f7b199340..d4ac0ceebc 100644 --- a/tests/test_smartredis.py +++ b/tests/_legacy/test_smartredis.py @@ -27,11 +27,9 @@ import pytest -from smartsim import Experiment -from smartsim._core.utils import installed_redisai_backends -from smartsim.database import Orchestrator -from smartsim.entity import Ensemble, Model -from smartsim.status import SmartSimStatus +from smartsim.builders import Ensemble +from smartsim.entity import Application +from smartsim.status import JobStatus # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b @@ -50,7 +48,9 @@ except ImportError: shouldrun = False -torch_available = "torch" in installed_redisai_backends() +torch_available = ( + "torch" in [] +) # todo: update test to replace installed_redisai_backends() shouldrun &= torch_available @@ -60,15 +60,15 @@ ) -def 
test_exchange(local_experiment, local_db, prepare_db, fileutils): +def test_exchange(local_experiment, local_fs, prepare_fs, fileutils): """Run two processes, each process puts a tensor on - the DB, then accesses the other process's tensor. - Finally, the tensor is used to run a model. + the FS, then accesses the other process's tensor. + Finally, the tensor is used to run a application. """ - db = prepare_db(local_db).orchestrator - # create and start a database - local_experiment.reconnect_orchestrator(db.checkpoint_file) + fs = prepare_fs(local_fs).featurestore + # create and start a feature store + local_experiment.reconnect_feature_store(fs.checkpoint_file) rs = local_experiment.create_run_settings("python", "producer.py --exchange") params = {"mult": [1, -10]} @@ -87,24 +87,24 @@ def test_exchange(local_experiment, local_db, prepare_db, fileutils): local_experiment.generate(ensemble) - # start the models + # start the applications local_experiment.start(ensemble, summary=False) # get and confirm statuses statuses = local_experiment.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) -def test_consumer(local_experiment, local_db, prepare_db, fileutils): +def test_consumer(local_experiment, local_fs, prepare_fs, fileutils): """Run three processes, each one of the first two processes - puts a tensor on the DB; the third process accesses the + puts a tensor on the FS; the third process accesses the tensors put by the two producers. - Finally, the tensor is used to run a model by each producer + Finally, the tensor is used to run a application by each producer and the consumer accesses the two results. 
""" - db = prepare_db(local_db).orchestrator - local_experiment.reconnect_orchestrator(db.checkpoint_file) + fs = prepare_fs(local_fs).featurestore + local_experiment.reconnect_feature_store(fs.checkpoint_file) rs_prod = local_experiment.create_run_settings("python", "producer.py") rs_consumer = local_experiment.create_run_settings("python", "consumer.py") @@ -113,10 +113,10 @@ def test_consumer(local_experiment, local_db, prepare_db, fileutils): name="producer", params=params, run_settings=rs_prod, perm_strat="step" ) - consumer = Model( + consumer = Application( "consumer", params={}, path=ensemble.path, run_settings=rs_consumer ) - ensemble.add_model(consumer) + ensemble.add_application(consumer) ensemble.register_incoming_entity(ensemble["producer_0"]) ensemble.register_incoming_entity(ensemble["producer_1"]) @@ -126,9 +126,9 @@ def test_consumer(local_experiment, local_db, prepare_db, fileutils): local_experiment.generate(ensemble) - # start the models + # start the applications local_experiment.start(ensemble, summary=False) # get and confirm statuses statuses = local_experiment.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) + assert all([stat == JobStatus.COMPLETED for stat in statuses]) diff --git a/tests/test_step_info.py b/tests/_legacy/test_step_info.py similarity index 89% rename from tests/test_step_info.py rename to tests/_legacy/test_step_info.py index fcccaa9cd4..06e914b0a8 100644 --- a/tests/test_step_info.py +++ b/tests/_legacy/test_step_info.py @@ -26,8 +26,8 @@ import pytest -from smartsim._core.launcher.stepInfo import * -from smartsim.status import SmartSimStatus +from smartsim._core.launcher.step_info import * +from smartsim.status import JobStatus # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b @@ -35,7 +35,7 @@ def test_str(): step_info = StepInfo( - status=SmartSimStatus.STATUS_COMPLETED, + status=JobStatus.COMPLETED, launcher_status="COMPLETED", 
returncode=0, ) @@ -47,4 +47,4 @@ def test_str(): def test_default(): step_info = UnmanagedStepInfo() - assert step_info._get_smartsim_status(None) == SmartSimStatus.STATUS_FAILED + assert step_info._get_smartsim_status(None) == JobStatus.FAILED diff --git a/tests/test_symlinking.py b/tests/_legacy/test_symlinking.py similarity index 75% rename from tests/test_symlinking.py rename to tests/_legacy/test_symlinking.py index 2b70e3e9f9..95aa187e6b 100644 --- a/tests/test_symlinking.py +++ b/tests/_legacy/test_symlinking.py @@ -32,9 +32,9 @@ from smartsim import Experiment from smartsim._core.config import CONFIG from smartsim._core.control.controller import Controller, _AnonymousBatchJob -from smartsim.database.orchestrator import Orchestrator -from smartsim.entity.ensemble import Ensemble -from smartsim.entity.model import Model +from smartsim.builders.ensemble import Ensemble +from smartsim.database.orchestrator import FeatureStore +from smartsim.entity.application import Application from smartsim.settings.base import RunSettings from smartsim.settings.slurmSettings import SbatchSettings, SrunSettings @@ -49,23 +49,29 @@ batch_rs = SrunSettings("echo", ["spam", "eggs"]) ens = Ensemble("ens", params={}, run_settings=rs, batch_settings=bs, replicas=3) -orc = Orchestrator(db_nodes=3, batch=True, launcher="slurm", run_command="srun") -model = Model("test_model", params={}, path="", run_settings=rs) -batch_model = Model( - "batch_test_model", params={}, path="", run_settings=batch_rs, batch_settings=bs +feature_store = FeatureStore( + fs_nodes=3, batch=True, launcher="slurm", run_command="srun" ) -anon_batch_model = _AnonymousBatchJob(batch_model) +application = Application("test_application", params={}, path="", run_settings=rs) +batch_application = Application( + "batch_test_application", + params={}, + path="", + run_settings=batch_rs, + batch_settings=bs, +) +anon_batch_application = _AnonymousBatchJob(batch_application) @pytest.mark.parametrize( "entity", - 
[pytest.param(ens, id="ensemble"), pytest.param(model, id="model")], + [pytest.param(ens, id="ensemble"), pytest.param(application, id="application")], ) def test_symlink(test_dir, entity): """Test symlinking historical output files""" entity.path = test_dir if entity.type == Ensemble: - for member in ens.models: + for member in ens.applications: symlink_with_create_job_step(test_dir, member) else: symlink_with_create_job_step(test_dir, entity) @@ -92,8 +98,8 @@ def symlink_with_create_job_step(test_dir, entity): "entity", [ pytest.param(ens, id="ensemble"), - pytest.param(orc, id="orchestrator"), - pytest.param(anon_batch_model, id="model"), + pytest.param(feature_store, id="featurestore"), + pytest.param(anon_batch_application, id="application"), ], ) def test_batch_symlink(entity, test_dir): @@ -116,31 +122,35 @@ def test_batch_symlink(entity, test_dir): def test_symlink_error(test_dir): """Ensure FileNotFoundError is thrown""" - bad_model = Model( - "bad_model", + bad_application = Application( + "bad_application", params={}, path=pathlib.Path(test_dir, "badpath"), run_settings=RunSettings("echo"), ) - telem_dir = pathlib.Path(test_dir, "bad_model_telemetry") - bad_step = controller._create_job_step(bad_model, telem_dir) + telem_dir = pathlib.Path(test_dir, "bad_application_telemetry") + bad_step = controller._create_job_step(bad_application, telem_dir) with pytest.raises(FileNotFoundError): - controller.symlink_output_files(bad_step, bad_model) + controller.symlink_output_files(bad_step, bad_application) -def test_failed_model_launch_symlinks(test_dir): +def test_failed_application_launch_symlinks(test_dir): exp_name = "failed-exp" exp = Experiment(exp_name, exp_path=test_dir) - test_model = exp.create_model( - "test_model", run_settings=batch_rs, batch_settings=bs + test_application = exp.create_application( + "test_application", run_settings=batch_rs, batch_settings=bs ) - exp.generate(test_model) + exp.generate(test_application) with 
pytest.raises(TypeError): - exp.start(test_model) + exp.start(test_application) - _should_not_be_symlinked(pathlib.Path(test_model.path)) - assert not pathlib.Path(test_model.path, f"{test_model.name}.out").is_symlink() - assert not pathlib.Path(test_model.path, f"{test_model.name}.err").is_symlink() + _should_not_be_symlinked(pathlib.Path(test_application.path)) + assert not pathlib.Path( + test_application.path, f"{test_application.name}.out" + ).is_symlink() + assert not pathlib.Path( + test_application.path, f"{test_application.name}.err" + ).is_symlink() def test_failed_ensemble_launch_symlinks(test_dir): @@ -161,7 +171,7 @@ def test_failed_ensemble_launch_symlinks(test_dir): test_ensemble.path, f"{test_ensemble.name}.err" ).is_symlink() - for i in range(len(test_ensemble.models)): + for i in range(len(test_ensemble.applications)): assert not pathlib.Path( test_ensemble.path, f"{test_ensemble.name}_{i}", @@ -184,7 +194,7 @@ def test_non_batch_ensemble_symlinks(test_dir): exp.generate(test_ensemble) exp.start(test_ensemble, block=True) - for i in range(len(test_ensemble.models)): + for i in range(len(test_ensemble.applications)): _should_be_symlinked( pathlib.Path( test_ensemble.path, @@ -205,31 +215,37 @@ def test_non_batch_ensemble_symlinks(test_dir): _should_not_be_symlinked(pathlib.Path(exp.exp_path, "smartsim_params.txt")) -def test_non_batch_model_symlinks(test_dir): - exp_name = "test-non-batch-model" +def test_non_batch_application_symlinks(test_dir): + exp_name = "test-non-batch-application" exp = Experiment(exp_name, exp_path=test_dir) rs = RunSettings("echo", ["spam", "eggs"]) - test_model = exp.create_model("test_model", path=test_dir, run_settings=rs) - exp.generate(test_model) - exp.start(test_model, block=True) + test_application = exp.create_application( + "test_application", path=test_dir, run_settings=rs + ) + exp.generate(test_application) + exp.start(test_application, block=True) - assert pathlib.Path(test_model.path).exists() + assert 
pathlib.Path(test_application.path).exists() - _should_be_symlinked(pathlib.Path(test_model.path, f"{test_model.name}.out"), True) - _should_be_symlinked(pathlib.Path(test_model.path, f"{test_model.name}.err"), False) + _should_be_symlinked( + pathlib.Path(test_application.path, f"{test_application.name}.out"), True + ) + _should_be_symlinked( + pathlib.Path(test_application.path, f"{test_application.name}.err"), False + ) _should_not_be_symlinked(pathlib.Path(exp.exp_path, "smartsim_params.txt")) -def test_non_batch_orchestrator_symlinks(test_dir): - exp = Experiment("test-non-batch-orc", exp_path=test_dir) +def test_non_batch_feature_store_symlinks(test_dir): + exp = Experiment("test-non-batch-feature-store", exp_path=test_dir) - db = exp.create_database(interface="lo") + db = exp.create_feature_store(interface="lo") exp.generate(db) exp.start(db, block=True) exp.stop(db) - for i in range(db.db_nodes): + for i in range(db.fs_nodes): _should_be_symlinked(pathlib.Path(db.path, f"{db.name}_{i}.out"), False) _should_be_symlinked(pathlib.Path(db.path, f"{db.name}_{i}.err"), False) diff --git a/tests/test_telemetry_monitor.py b/tests/_legacy/test_telemetry_monitor.py similarity index 81% rename from tests/test_telemetry_monitor.py rename to tests/_legacy/test_telemetry_monitor.py index c1bfe27199..262f07e1e6 100644 --- a/tests/test_telemetry_monitor.py +++ b/tests/_legacy/test_telemetry_monitor.py @@ -39,12 +39,12 @@ from conftest import FileUtils, WLMUtils from smartsim import Experiment from smartsim._core.control.job import Job, JobEntity -from smartsim._core.control.jobmanager import JobManager -from smartsim._core.entrypoints.telemetrymonitor import get_parser +from smartsim._core.control.job_manager import JobManager +from smartsim._core.entrypoints.telemetry_monitor import get_parser from smartsim._core.launcher.launcher import WLMLauncher -from smartsim._core.launcher.slurm.slurmLauncher import SlurmLauncher +from smartsim._core.launcher.slurm.slurm_launcher 
import SlurmLauncher from smartsim._core.launcher.step.step import Step, proxyable_launch_cmd -from smartsim._core.launcher.stepInfo import StepInfo +from smartsim._core.launcher.step_info import StepInfo from smartsim._core.utils import serialize from smartsim._core.utils.helpers import get_ts_ms from smartsim._core.utils.telemetry.manifest import Run, RuntimeManifest @@ -56,7 +56,7 @@ from smartsim._core.utils.telemetry.util import map_return_code, write_event from smartsim.error.errors import UnproxyableStepError from smartsim.settings.base import RunSettings -from smartsim.status import SmartSimStatus +from smartsim.status import JobStatus ALL_ARGS = {"-exp_dir", "-frequency"} PROXY_ENTRY_POINT = "smartsim._core.entrypoints.indirect" @@ -296,14 +296,14 @@ def test_load_manifest(fileutils: FileUtils, test_dir: str, config: cfg.Config): assert manifest.launcher == "Slurm" assert len(manifest.runs) == 6 - assert len(manifest.runs[0].models) == 1 - assert len(manifest.runs[2].models) == 8 # 8 models in ensemble - assert len(manifest.runs[0].orchestrators) == 0 - assert len(manifest.runs[1].orchestrators) == 3 # 3 shards in db + assert len(manifest.runs[0].applications) == 1 + assert len(manifest.runs[2].applications) == 8 # 8 applications in ensemble + assert len(manifest.runs[0].featurestores) == 0 + assert len(manifest.runs[1].featurestores) == 3 # 3 shards in fs -def test_load_manifest_colo_model(fileutils: FileUtils): - """Ensure that the runtime manifest loads correctly when containing a colocated model""" +def test_load_manifest_colo_application(fileutils: FileUtils): + """Ensure that the runtime manifest loads correctly when containing a colocated application""" # NOTE: for regeneration, this manifest can use `test_telemetry_colo` sample_manifest_path = fileutils.get_test_conf_path("telemetry/colocatedmodel.json") sample_manifest = pathlib.Path(sample_manifest_path) @@ -315,11 +315,11 @@ def test_load_manifest_colo_model(fileutils: FileUtils): assert 
manifest.launcher == "Slurm" assert len(manifest.runs) == 1 - assert len(manifest.runs[0].models) == 1 + assert len(manifest.runs[0].applications) == 1 -def test_load_manifest_serial_models(fileutils: FileUtils): - """Ensure that the runtime manifest loads correctly when containing multiple models""" +def test_load_manifest_serial_applications(fileutils: FileUtils): + """Ensure that the runtime manifest loads correctly when containing multiple applications""" # NOTE: for regeneration, this manifest can use `test_telemetry_colo` sample_manifest_path = fileutils.get_test_conf_path("telemetry/serialmodels.json") sample_manifest = pathlib.Path(sample_manifest_path) @@ -331,12 +331,12 @@ def test_load_manifest_serial_models(fileutils: FileUtils): assert manifest.launcher == "Slurm" assert len(manifest.runs) == 1 - assert len(manifest.runs[0].models) == 5 + assert len(manifest.runs[0].applications) == 5 -def test_load_manifest_db_and_models(fileutils: FileUtils): - """Ensure that the runtime manifest loads correctly when containing models & - orchestrator across 2 separate runs""" +def test_load_manifest_fs_and_applications(fileutils: FileUtils): + """Ensure that the runtime manifest loads correctly when containing applications & + feature store across 2 separate runs""" # NOTE: for regeneration, this manifest can use `test_telemetry_colo` sample_manifest_path = fileutils.get_test_conf_path("telemetry/db_and_model.json") sample_manifest = pathlib.Path(sample_manifest_path) @@ -348,19 +348,19 @@ def test_load_manifest_db_and_models(fileutils: FileUtils): assert manifest.launcher == "Slurm" assert len(manifest.runs) == 2 - assert len(manifest.runs[0].orchestrators) == 1 - assert len(manifest.runs[1].models) == 1 + assert len(manifest.runs[0].featurestores) == 1 + assert len(manifest.runs[1].applications) == 1 # verify collector paths from manifest are deserialized to collector config - assert manifest.runs[0].orchestrators[0].collectors["client"] - assert 
manifest.runs[0].orchestrators[0].collectors["memory"] + assert manifest.runs[0].featurestores[0].collectors["client"] + assert manifest.runs[0].featurestores[0].collectors["memory"] # verify collector paths missing from manifest are empty - assert not manifest.runs[0].orchestrators[0].collectors["client_count"] + assert not manifest.runs[0].featurestores[0].collectors["client_count"] -def test_load_manifest_db_and_models_1run(fileutils: FileUtils): - """Ensure that the runtime manifest loads correctly when containing models & - orchestrator in a single run""" +def test_load_manifest_fs_and_applications_1run(fileutils: FileUtils): + """Ensure that the runtime manifest loads correctly when containing applications & + featurestore in a single run""" # NOTE: for regeneration, this manifest can use `test_telemetry_colo` sample_manifest_path = fileutils.get_test_conf_path( "telemetry/db_and_model_1run.json" @@ -374,21 +374,33 @@ def test_load_manifest_db_and_models_1run(fileutils: FileUtils): assert manifest.launcher == "Slurm" assert len(manifest.runs) == 1 - assert len(manifest.runs[0].orchestrators) == 1 - assert len(manifest.runs[0].models) == 1 + assert len(manifest.runs[0].featurestores) == 1 + assert len(manifest.runs[0].applications) == 1 @pytest.mark.parametrize( - ["task_id", "step_id", "etype", "exp_isorch", "exp_ismanaged"], + ["task_id", "step_id", "etype", "exp_isfeature_store", "exp_ismanaged"], [ - pytest.param("123", "", "model", False, False, id="unmanaged, non-orch"), - pytest.param("456", "123", "ensemble", False, True, id="managed, non-orch"), - pytest.param("789", "987", "orchestrator", True, True, id="managed, orch"), - pytest.param("987", "", "orchestrator", True, False, id="unmanaged, orch"), + pytest.param( + "123", "", "application", False, False, id="unmanaged, non-feature_store" + ), + pytest.param( + "456", "123", "ensemble", False, True, id="managed, non-feature_store" + ), + pytest.param( + "789", "987", "featurestore", True, True, 
id="managed, feature_store" + ), + pytest.param( + "987", "", "featurestore", True, False, id="unmanaged, feature_store" + ), ], ) def test_persistable_computed_properties( - task_id: str, step_id: str, etype: str, exp_isorch: bool, exp_ismanaged: bool + task_id: str, + step_id: str, + etype: str, + exp_isfeature_store: bool, + exp_ismanaged: bool, ): name = f"test-{etype}-{uuid.uuid4()}" timestamp = get_ts_ms() @@ -407,12 +419,12 @@ def test_persistable_computed_properties( persistable = persistables[0] if persistables else None assert persistable.is_managed == exp_ismanaged - assert persistable.is_db == exp_isorch + assert persistable.is_fs == exp_isfeature_store def test_deserialize_ensemble(fileutils: FileUtils): - """Ensure that the children of ensembles (models) are correctly - placed in the models collection""" + """Ensure that the children of ensembles (applications) are correctly + placed in the applications collection""" sample_manifest_path = fileutils.get_test_conf_path("telemetry/ensembles.json") sample_manifest = pathlib.Path(sample_manifest_path) assert sample_manifest.exists() @@ -424,7 +436,7 @@ def test_deserialize_ensemble(fileutils: FileUtils): # NOTE: no longer returning ensembles, only children... 
# assert len(manifest.runs[0].ensembles) == 1 - assert len(manifest.runs[0].models) == 8 + assert len(manifest.runs[0].applications) == 8 def test_shutdown_conditions__no_monitored_jobs(test_dir: str): @@ -459,17 +471,17 @@ def test_shutdown_conditions__has_monitored_job(test_dir: str): telmon._action_handler = mani_handler assert not telmon._can_shutdown() - assert not bool(mani_handler.job_manager.db_jobs) + assert not bool(mani_handler.job_manager.fs_jobs) assert bool(mani_handler.job_manager.jobs) -def test_shutdown_conditions__has_db(test_dir: str): - """Show that an event handler w/a monitored db cannot shutdown""" +def test_shutdown_conditions__has_fs(test_dir: str): + """Show that an event handler w/a monitored fs cannot shutdown""" job_entity1 = JobEntity() job_entity1.name = "xyz" job_entity1.step_id = "123" job_entity1.task_id = "" - job_entity1.type = "orchestrator" # <---- make entity appear as db + job_entity1.type = "featurestore" # <---- make entity appear as fs mani_handler = ManifestEventHandler("xyz") ## TODO: see next comment and combine an add_job method on manieventhandler @@ -486,7 +498,7 @@ def test_shutdown_conditions__has_db(test_dir: str): telmon._action_handler = mani_handler # replace w/mock handler assert not telmon._can_shutdown() - assert bool([j for j in mani_handler._tracked_jobs.values() if j.is_db]) + assert bool([j for j in mani_handler._tracked_jobs.values() if j.is_fs]) assert not bool(mani_handler.job_manager.jobs) @@ -554,10 +566,10 @@ def is_alive(self) -> bool: ], ) @pytest.mark.asyncio -async def test_auto_shutdown__has_db( +async def test_auto_shutdown__has_fs( test_dir: str, cooldown_ms: int, task_duration_ms: int ): - """Ensure that the cooldown timer is respected with a running db""" + """Ensure that the cooldown timer is respected with a running fs""" class FauxObserver: """Mock for the watchdog file system event listener""" @@ -575,10 +587,10 @@ def is_alive(self) -> bool: return True entity = JobEntity() - 
entity.name = "db_0" + entity.name = "fs_0" entity.step_id = "123" entity.task_id = "" - entity.type = "orchestrator" + entity.type = "featurestore" entity.telemetry_on = True entity.status_dir = test_dir @@ -611,12 +623,12 @@ def is_alive(self) -> bool: assert observer.stop_count == 1 -def test_telemetry_single_model(fileutils, test_dir, wlmutils, config): - """Test that it is possible to create_database then colocate_db_uds/colocate_db_tcp - with unique db_identifiers""" +def test_telemetry_single_application(fileutils, test_dir, wlmutils, config): + """Test that it is possible to create_database then colocate_fs_uds/colocate_fs_tcp + with unique fs_identifiers""" # Set experiment name - exp_name = "telemetry_single_model" + exp_name = "telemetry_single_application" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -630,11 +642,11 @@ def test_telemetry_single_model(fileutils, test_dir, wlmutils, config): app_settings.set_nodes(1) app_settings.set_tasks_per_node(1) - # Create the SmartSim Model - smartsim_model = exp.create_model("perroquet", app_settings) - exp.generate(smartsim_model) - exp.start(smartsim_model, block=True) - assert exp.get_status(smartsim_model)[0] == SmartSimStatus.STATUS_COMPLETED + # Create the SmartSim Aapplication + smartsim_application = exp.create_application("perroquet", app_settings) + exp.generate(smartsim_application) + exp.start(smartsim_application, block=True) + assert exp.get_status(smartsim_application)[0] == JobStatus.COMPLETED telemetry_output_path = pathlib.Path(test_dir) / config.telemetry_subdir start_events = list(telemetry_output_path.rglob("start.json")) @@ -644,7 +656,7 @@ def test_telemetry_single_model(fileutils, test_dir, wlmutils, config): assert len(stop_events) == 1 -def test_telemetry_single_model_nonblocking( +def test_telemetry_single_application_nonblocking( fileutils, test_dir, wlmutils, monkeypatch, config ): """Ensure that the telemetry monitor logs exist when 
the experiment @@ -653,7 +665,7 @@ def test_telemetry_single_model_nonblocking( ctx.setattr(cfg.Config, "telemetry_frequency", 1) # Set experiment name - exp_name = "test_telemetry_single_model_nonblocking" + exp_name = "test_telemetry_single_application_nonblocking" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -667,15 +679,15 @@ def test_telemetry_single_model_nonblocking( app_settings.set_nodes(1) app_settings.set_tasks_per_node(1) - # Create the SmartSim Model - smartsim_model = exp.create_model("perroquet", app_settings) - exp.generate(smartsim_model) - exp.start(smartsim_model) + # Create the SmartSim Application + smartsim_application = exp.create_application("perroquet", app_settings) + exp.generate(smartsim_application) + exp.start(smartsim_application) telemetry_output_path = pathlib.Path(test_dir) / config.telemetry_subdir snooze_blocking(telemetry_output_path, max_delay=10, post_data_delay=1) - assert exp.get_status(smartsim_model)[0] == SmartSimStatus.STATUS_COMPLETED + assert exp.get_status(smartsim_application)[0] == JobStatus.COMPLETED start_events = list(telemetry_output_path.rglob("start.json")) stop_events = list(telemetry_output_path.rglob("stop.json")) @@ -684,15 +696,17 @@ def test_telemetry_single_model_nonblocking( assert len(stop_events) == 1 -def test_telemetry_serial_models(fileutils, test_dir, wlmutils, monkeypatch, config): +def test_telemetry_serial_applications( + fileutils, test_dir, wlmutils, monkeypatch, config +): """ - Test telemetry with models being run in serial (one after each other) + Test telemetry with applications being run in serial (one after each other) """ with monkeypatch.context() as ctx: ctx.setattr(cfg.Config, "telemetry_frequency", 1) # Set experiment name - exp_name = "telemetry_serial_models" + exp_name = "telemetry_serial_applications" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -706,16 +720,16 @@ def 
test_telemetry_serial_models(fileutils, test_dir, wlmutils, monkeypatch, con app_settings.set_nodes(1) app_settings.set_tasks_per_node(1) - # Create the SmartSim Model - smartsim_models = [ - exp.create_model(f"perroquet_{i}", app_settings) for i in range(5) + # Create the SmartSim Aapplication + smartsim_applications = [ + exp.create_application(f"perroquet_{i}", app_settings) for i in range(5) ] - exp.generate(*smartsim_models) - exp.start(*smartsim_models, block=True) + exp.generate(*smartsim_applications) + exp.start(*smartsim_applications, block=True) assert all( [ - status == SmartSimStatus.STATUS_COMPLETED - for status in exp.get_status(*smartsim_models) + status == JobStatus.COMPLETED + for status in exp.get_status(*smartsim_applications) ] ) @@ -727,18 +741,18 @@ def test_telemetry_serial_models(fileutils, test_dir, wlmutils, monkeypatch, con assert len(stop_events) == 5 -def test_telemetry_serial_models_nonblocking( +def test_telemetry_serial_applications_nonblocking( fileutils, test_dir, wlmutils, monkeypatch, config ): """ - Test telemetry with models being run in serial (one after each other) + Test telemetry with applications being run in serial (one after each other) in a non-blocking experiment """ with monkeypatch.context() as ctx: ctx.setattr(cfg.Config, "telemetry_frequency", 1) # Set experiment name - exp_name = "telemetry_serial_models" + exp_name = "telemetry_serial_applications" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -752,20 +766,20 @@ def test_telemetry_serial_models_nonblocking( app_settings.set_nodes(1) app_settings.set_tasks_per_node(1) - # Create the SmartSim Model - smartsim_models = [ - exp.create_model(f"perroquet_{i}", app_settings) for i in range(5) + # Create the SmartSim Aapplication + smartsim_applications = [ + exp.create_application(f"perroquet_{i}", app_settings) for i in range(5) ] - exp.generate(*smartsim_models) - exp.start(*smartsim_models) + 
exp.generate(*smartsim_applications) + exp.start(*smartsim_applications) telemetry_output_path = pathlib.Path(test_dir) / config.telemetry_subdir snooze_blocking(telemetry_output_path, max_delay=10, post_data_delay=1) assert all( [ - status == SmartSimStatus.STATUS_COMPLETED - for status in exp.get_status(*smartsim_models) + status == JobStatus.COMPLETED + for status in exp.get_status(*smartsim_applications) ] ) @@ -776,15 +790,15 @@ def test_telemetry_serial_models_nonblocking( assert len(stop_events) == 5 -def test_telemetry_db_only_with_generate(test_dir, wlmutils, monkeypatch, config): +def test_telemetry_fs_only_with_generate(test_dir, wlmutils, monkeypatch, config): """ - Test telemetry with only a database running + Test telemetry with only a feature store running """ with monkeypatch.context() as ctx: ctx.setattr(cfg.Config, "telemetry_frequency", 1) # Set experiment name - exp_name = "telemetry_db_with_generate" + exp_name = "telemetry_fs_with_generate" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -794,14 +808,16 @@ def test_telemetry_db_only_with_generate(test_dir, wlmutils, monkeypatch, config # Create SmartSim Experiment exp = Experiment(exp_name, launcher=test_launcher, exp_path=test_dir) - # create regular database - orc = exp.create_database(port=test_port, interface=test_interface) - exp.generate(orc) + # create regular feature store + feature_store = exp.create_feature_store( + port=test_port, interface=test_interface + ) + exp.generate(feature_store) telemetry_output_path = pathlib.Path(test_dir) / config.telemetry_subdir try: - exp.start(orc, block=True) + exp.start(feature_store, block=True) snooze_blocking(telemetry_output_path, max_delay=10, post_data_delay=1) @@ -811,24 +827,24 @@ def test_telemetry_db_only_with_generate(test_dir, wlmutils, monkeypatch, config assert len(start_events) == 1 assert len(stop_events) <= 1 finally: - exp.stop(orc) + exp.stop(feature_store) 
snooze_blocking(telemetry_output_path, max_delay=10, post_data_delay=1) - assert exp.get_status(orc)[0] == SmartSimStatus.STATUS_CANCELLED + assert exp.get_status(feature_store)[0] == JobStatus.CANCELLED stop_events = list(telemetry_output_path.rglob("stop.json")) assert len(stop_events) == 1 -def test_telemetry_db_only_without_generate(test_dir, wlmutils, monkeypatch, config): +def test_telemetry_fs_only_without_generate(test_dir, wlmutils, monkeypatch, config): """ - Test telemetry with only a non-generated database running + Test telemetry with only a non-generated feature store running """ with monkeypatch.context() as ctx: ctx.setattr(cfg.Config, "telemetry_frequency", 1) # Set experiment name - exp_name = "telemetry_db_only_without_generate" + exp_name = "telemetry_fs_only_without_generate" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -838,12 +854,14 @@ def test_telemetry_db_only_without_generate(test_dir, wlmutils, monkeypatch, con # Create SmartSim Experiment exp = Experiment(exp_name, launcher=test_launcher, exp_path=test_dir) - # create regular database - orc = exp.create_database(port=test_port, interface=test_interface) + # create regular feature store + feature_store = exp.create_feature_store( + port=test_port, interface=test_interface + ) telemetry_output_path = pathlib.Path(test_dir) / config.telemetry_subdir try: - exp.start(orc) + exp.start(feature_store) snooze_blocking(telemetry_output_path, max_delay=10, post_data_delay=1) @@ -853,25 +871,27 @@ def test_telemetry_db_only_without_generate(test_dir, wlmutils, monkeypatch, con assert len(start_events) == 1 assert len(stop_events) == 0 finally: - exp.stop(orc) + exp.stop(feature_store) snooze_blocking(telemetry_output_path, max_delay=10, post_data_delay=1) - assert exp.get_status(orc)[0] == SmartSimStatus.STATUS_CANCELLED + assert exp.get_status(feature_store)[0] == JobStatus.CANCELLED stop_events = list(telemetry_output_path.rglob("stop.json")) 
assert len(stop_events) == 1 -def test_telemetry_db_and_model(fileutils, test_dir, wlmutils, monkeypatch, config): +def test_telemetry_fs_and_application( + fileutils, test_dir, wlmutils, monkeypatch, config +): """ - Test telemetry with only a database and a model running + Test telemetry with only a feature store and a application running """ with monkeypatch.context() as ctx: ctx.setattr(cfg.Config, "telemetry_frequency", 1) # Set experiment name - exp_name = "telemetry_db_and_model" + exp_name = "telemetry_fs_and_application" # Retrieve parameters from testing environment test_launcher = wlmutils.get_test_launcher() @@ -882,29 +902,31 @@ def test_telemetry_db_and_model(fileutils, test_dir, wlmutils, monkeypatch, conf # Create SmartSim Experiment exp = Experiment(exp_name, launcher=test_launcher, exp_path=test_dir) - # create regular database - orc = exp.create_database(port=test_port, interface=test_interface) - exp.generate(orc) + # create regular feature store + feature_store = exp.create_feature_store( + port=test_port, interface=test_interface + ) + exp.generate(feature_store) try: - exp.start(orc) + exp.start(feature_store) # create run settings app_settings = exp.create_run_settings(sys.executable, test_script) app_settings.set_nodes(1) app_settings.set_tasks_per_node(1) - # Create the SmartSim Model - smartsim_model = exp.create_model("perroquet", app_settings) - exp.generate(smartsim_model) - exp.start(smartsim_model, block=True) + # Create the SmartSim Aapplication + smartsim_application = exp.create_application("perroquet", app_settings) + exp.generate(smartsim_application) + exp.start(smartsim_application, block=True) finally: - exp.stop(orc) + exp.stop(feature_store) telemetry_output_path = pathlib.Path(test_dir) / config.telemetry_subdir snooze_blocking(telemetry_output_path, max_delay=10, post_data_delay=1) - assert exp.get_status(orc)[0] == SmartSimStatus.STATUS_CANCELLED - assert exp.get_status(smartsim_model)[0] == 
SmartSimStatus.STATUS_COMPLETED + assert exp.get_status(feature_store)[0] == JobStatus.CANCELLED + assert exp.get_status(smartsim_application)[0] == JobStatus.COMPLETED start_events = list(telemetry_output_path.rglob("database/**/start.json")) stop_events = list(telemetry_output_path.rglob("database/**/stop.json")) @@ -912,8 +934,8 @@ def test_telemetry_db_and_model(fileutils, test_dir, wlmutils, monkeypatch, conf assert len(start_events) == 1 assert len(stop_events) == 1 - start_events = list(telemetry_output_path.rglob("model/**/start.json")) - stop_events = list(telemetry_output_path.rglob("model/**/stop.json")) + start_events = list(telemetry_output_path.rglob("application/**/start.json")) + stop_events = list(telemetry_output_path.rglob("application/**/stop.json")) assert len(start_events) == 1 assert len(stop_events) == 1 @@ -943,12 +965,7 @@ def test_telemetry_ensemble(fileutils, test_dir, wlmutils, monkeypatch, config): ens = exp.create_ensemble("troupeau", run_settings=app_settings, replicas=5) exp.generate(ens) exp.start(ens, block=True) - assert all( - [ - status == SmartSimStatus.STATUS_COMPLETED - for status in exp.get_status(ens) - ] - ) + assert all([status == JobStatus.COMPLETED for status in exp.get_status(ens)]) telemetry_output_path = pathlib.Path(test_dir) / config.telemetry_subdir snooze_blocking(telemetry_output_path, max_delay=10, post_data_delay=1) @@ -961,7 +978,7 @@ def test_telemetry_ensemble(fileutils, test_dir, wlmutils, monkeypatch, config): def test_telemetry_colo(fileutils, test_dir, wlmutils, coloutils, monkeypatch, config): """ - Test telemetry with only a colocated model running + Test telemetry with only a colocated application running """ with monkeypatch.context() as ctx: @@ -976,7 +993,7 @@ def test_telemetry_colo(fileutils, test_dir, wlmutils, coloutils, monkeypatch, c # Create SmartSim Experiment exp = Experiment(exp_name, launcher=test_launcher, exp_path=test_dir) - smartsim_model = coloutils.setup_test_colo( + 
smartsim_application = coloutils.setup_test_colo( fileutils, "uds", exp, @@ -984,12 +1001,12 @@ def test_telemetry_colo(fileutils, test_dir, wlmutils, coloutils, monkeypatch, c {}, ) - exp.generate(smartsim_model) - exp.start(smartsim_model, block=True) + exp.generate(smartsim_application) + exp.start(smartsim_application, block=True) assert all( [ - status == SmartSimStatus.STATUS_COMPLETED - for status in exp.get_status(smartsim_model) + status == JobStatus.COMPLETED + for status in exp.get_status(smartsim_application) ] ) @@ -997,7 +1014,7 @@ def test_telemetry_colo(fileutils, test_dir, wlmutils, coloutils, monkeypatch, c start_events = list(telemetry_output_path.rglob("start.json")) stop_events = list(telemetry_output_path.rglob("stop.json")) - # the colodb does NOT show up as a unique entity in the telemetry + # the colofs does NOT show up as a unique entity in the telemetry assert len(start_events) == 1 assert len(stop_events) == 1 @@ -1039,10 +1056,10 @@ def test_telemetry_autoshutdown( exp = Experiment(exp_name, launcher=test_launcher, exp_path=test_dir) rs = RunSettings("python", exe_args=["sleep.py", "1"]) - model = exp.create_model("model", run_settings=rs) + application = exp.create_application("application", run_settings=rs) start_time = get_ts_ms() - exp.start(model, block=True) + exp.start(application, block=True) telemetry_output_path = pathlib.Path(test_dir) / config.telemetry_subdir empty_mani = list(telemetry_output_path.rglob("manifest.json")) @@ -1197,39 +1214,39 @@ def test_multistart_experiment( rs_m = exp.create_run_settings("echo", ["hello", "world"], run_command=run_command) rs_m.set_nodes(1) rs_m.set_tasks(1) - model = exp.create_model("my-model", run_settings=rs_m) + application = exp.create_application("my-application", run_settings=rs_m) - db = exp.create_database( - db_nodes=1, + fs = exp.create_feature_store( + fs_nodes=1, port=wlmutils.get_test_port(), interface=wlmutils.get_test_interface(), ) - exp.generate(db, ens, model, 
overwrite=True) + exp.generate(fs, ens, application, overwrite=True) with monkeypatch.context() as ctx: ctx.setattr(cfg.Config, "telemetry_frequency", 1) ctx.setattr(cfg.Config, "telemetry_cooldown", 45) - exp.start(model, block=False) + exp.start(application, block=False) # track PID to see that telmon cooldown avoids restarting process tm_pid = exp._control._telemetry_monitor.pid - exp.start(db, block=False) + exp.start(fs, block=False) # check that same TM proc is active assert tm_pid == exp._control._telemetry_monitor.pid try: exp.start(ens, block=True, summary=True) finally: - exp.stop(db) + exp.stop(fs) assert tm_pid == exp._control._telemetry_monitor.pid - time.sleep(3) # time for telmon to write db stop event + time.sleep(3) # time for telmon to write fs stop event telemetry_output_path = pathlib.Path(test_dir) / config.telemetry_subdir - db_start_events = list(telemetry_output_path.rglob("database/**/start.json")) - assert len(db_start_events) == 1 + fs_start_events = list(telemetry_output_path.rglob("database/**/start.json")) + assert len(fs_start_events) == 1 m_start_events = list(telemetry_output_path.rglob("model/**/start.json")) assert len(m_start_events) == 1 @@ -1241,12 +1258,12 @@ def test_multistart_experiment( @pytest.mark.parametrize( "status_in, expected_out", [ - pytest.param(SmartSimStatus.STATUS_CANCELLED, 1, id="failure on cancellation"), - pytest.param(SmartSimStatus.STATUS_COMPLETED, 0, id="success on completion"), - pytest.param(SmartSimStatus.STATUS_FAILED, 1, id="failure on failed"), - pytest.param(SmartSimStatus.STATUS_NEW, None, id="failure on new"), - pytest.param(SmartSimStatus.STATUS_PAUSED, None, id="failure on paused"), - pytest.param(SmartSimStatus.STATUS_RUNNING, None, id="failure on running"), + pytest.param(JobStatus.CANCELLED, 1, id="failure on cancellation"), + pytest.param(JobStatus.COMPLETED, 0, id="success on completion"), + pytest.param(JobStatus.FAILED, 1, id="failure on failed"), + pytest.param(JobStatus.NEW, None, 
id="failure on new"), + pytest.param(JobStatus.PAUSED, None, id="failure on paused"), + pytest.param(JobStatus.RUNNING, None, id="failure on running"), ], ) def test_faux_rc(status_in: str, expected_out: t.Optional[int]): @@ -1260,18 +1277,12 @@ def test_faux_rc(status_in: str, expected_out: t.Optional[int]): @pytest.mark.parametrize( "status_in, expected_out, expected_has_jobs", [ - pytest.param( - SmartSimStatus.STATUS_CANCELLED, 1, False, id="failure on cancellation" - ), - pytest.param( - SmartSimStatus.STATUS_COMPLETED, 0, False, id="success on completion" - ), - pytest.param(SmartSimStatus.STATUS_FAILED, 1, False, id="failure on failed"), - pytest.param(SmartSimStatus.STATUS_NEW, None, True, id="failure on new"), - pytest.param(SmartSimStatus.STATUS_PAUSED, None, True, id="failure on paused"), - pytest.param( - SmartSimStatus.STATUS_RUNNING, None, True, id="failure on running" - ), + pytest.param(JobStatus.CANCELLED, 1, False, id="failure on cancellation"), + pytest.param(JobStatus.COMPLETED, 0, False, id="success on completion"), + pytest.param(JobStatus.FAILED, 1, False, id="failure on failed"), + pytest.param(JobStatus.NEW, None, True, id="failure on new"), + pytest.param(JobStatus.PAUSED, None, True, id="failure on paused"), + pytest.param(JobStatus.RUNNING, None, True, id="failure on running"), ], ) @pytest.mark.asyncio @@ -1303,7 +1314,7 @@ def _faux_updates(_self: WLMLauncher, _names: t.List[str]) -> t.List[StepInfo]: job_entity.step_id = "faux-step-id" job_entity.task_id = 1234 job_entity.status_dir = test_dir - job_entity.type = "orchestrator" + job_entity.type = "featurestore" job = Job(job_entity.name, job_entity.step_id, job_entity, "slurm", True) diff --git a/tests/utils/test_network.py b/tests/_legacy/utils/test_network.py similarity index 100% rename from tests/utils/test_network.py rename to tests/_legacy/utils/test_network.py diff --git a/tests/utils/test_security.py b/tests/_legacy/utils/test_security.py similarity index 100% rename from 
tests/utils/test_security.py rename to tests/_legacy/utils/test_security.py diff --git a/tests/backends/test_ml_init.py b/tests/backends/test_ml_init.py new file mode 100644 index 0000000000..7f5c6f9864 --- /dev/null +++ b/tests/backends/test_ml_init.py @@ -0,0 +1,48 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import tempfile + +import pytest + +try: + import tensorflow + import torch +except: + pytestmark = pytest.mark.skip("tensorflow or torch were not availalble") +else: + pytestmark = [pytest.mark.group_a, pytest.mark.group_b, pytest.mark.slow_tests] + + +def test_import_ss_ml(monkeypatch): + with tempfile.TemporaryDirectory() as empty_dir: + # Move to an empty directory so `smartsim` dir is not in cwd + monkeypatch.chdir(empty_dir) + + # Make sure SmartSim ML modules are importable + import smartsim.ml + import smartsim.ml.tf + import smartsim.ml.torch diff --git a/tests/dragon/__init__.py b/tests/dragon_wlm/__init__.py similarity index 100% rename from tests/dragon/__init__.py rename to tests/dragon_wlm/__init__.py diff --git a/tests/dragon/channel.py b/tests/dragon_wlm/channel.py similarity index 100% rename from tests/dragon/channel.py rename to tests/dragon_wlm/channel.py diff --git a/tests/dragon/conftest.py b/tests/dragon_wlm/conftest.py similarity index 99% rename from tests/dragon/conftest.py rename to tests/dragon_wlm/conftest.py index d542700175..bdec40b7e5 100644 --- a/tests/dragon/conftest.py +++ b/tests/dragon_wlm/conftest.py @@ -27,10 +27,7 @@ from __future__ import annotations import os -import pathlib import socket -import subprocess -import sys import typing as t import pytest diff --git a/tests/dragon/feature_store.py b/tests/dragon_wlm/feature_store.py similarity index 100% rename from tests/dragon/feature_store.py rename to tests/dragon_wlm/feature_store.py diff --git a/tests/dragon/test_core_machine_learning_worker.py b/tests/dragon_wlm/test_core_machine_learning_worker.py similarity index 98% rename from tests/dragon/test_core_machine_learning_worker.py rename to tests/dragon_wlm/test_core_machine_learning_worker.py index e9c356b4e0..f9295d9e86 100644 --- a/tests/dragon/test_core_machine_learning_worker.py +++ b/tests/dragon_wlm/test_core_machine_learning_worker.py @@ -39,10 +39,8 @@ InferenceRequest, MachineLearningWorkerCore, RequestBatch, 
- TransformInputResult, TransformOutputResult, ) -from smartsim._core.utils import installed_redisai_backends from .feature_store import FileSystemFeatureStore, MemoryFeatureStore @@ -53,7 +51,9 @@ is_dragon = ( pytest.test_launcher == "dragon" if hasattr(pytest, "test_launcher") else False ) -torch_available = "torch" in installed_redisai_backends() +torch_available = ( + "torch" in [] +) # todo: update test to replace installed_redisai_backends() @pytest.fixture diff --git a/tests/dragon/test_device_manager.py b/tests/dragon_wlm/test_device_manager.py similarity index 100% rename from tests/dragon/test_device_manager.py rename to tests/dragon_wlm/test_device_manager.py diff --git a/tests/dragon/test_dragon_backend.py b/tests/dragon_wlm/test_dragon_backend.py similarity index 99% rename from tests/dragon/test_dragon_backend.py rename to tests/dragon_wlm/test_dragon_backend.py index 0e64c358df..dc98f5de75 100644 --- a/tests/dragon/test_dragon_backend.py +++ b/tests/dragon_wlm/test_dragon_backend.py @@ -33,7 +33,7 @@ dragon = pytest.importorskip("dragon") -from smartsim._core.launcher.dragon.dragonBackend import DragonBackend +from smartsim._core.launcher.dragon.dragon_backend import DragonBackend from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel from smartsim._core.mli.infrastructure.comm.event import ( OnCreateConsumer, diff --git a/tests/dragon_wlm/test_dragon_comm_utils.py b/tests/dragon_wlm/test_dragon_comm_utils.py new file mode 100644 index 0000000000..a6f9c206a4 --- /dev/null +++ b/tests/dragon_wlm/test_dragon_comm_utils.py @@ -0,0 +1,257 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. 
+# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import base64 +import pathlib +import uuid + +import pytest + +from smartsim.error.errors import SmartSimError + +dragon = pytest.importorskip("dragon") + +# isort: off +import dragon.channels as dch +import dragon.infrastructure.parameters as dp +import dragon.managed_memory as dm +import dragon.fli as fli + +# isort: on + +from smartsim._core.mli.comm.channel import dragon_util +from smartsim.log import get_logger + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +logger = get_logger(__name__) + + +@pytest.fixture(scope="function") +def the_pool() -> dm.MemoryPool: + """Creates a memory pool.""" + raw_pool_descriptor = dp.this_process.default_pd + descriptor_ = base64.b64decode(raw_pool_descriptor) + + pool = dm.MemoryPool.attach(descriptor_) + return pool + + +@pytest.fixture(scope="function") +def the_channel() -> dch.Channel: + """Creates a Channel attached to the local memory pool.""" + channel = dch.Channel.make_process_local() + return channel + + +@pytest.fixture(scope="function") +def the_fli(the_channel) -> fli.FLInterface: + """Creates an FLI attached to the local memory pool.""" + fli_ = fli.FLInterface(main_ch=the_channel, manager_ch=None) + return fli_ + + +def test_descriptor_to_channel_empty() -> None: + """Verify that `descriptor_to_channel` raises an exception when + provided with an empty descriptor.""" + descriptor = "" + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "empty" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()], +) +def test_descriptor_to_channel_b64fail(descriptor: str) -> None: + """Verify that `descriptor_to_channel` raises an exception when + provided with an incorrectly encoded descriptor. 
+ + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "base64" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + [str(uuid.uuid4())], +) +def test_descriptor_to_channel_channel_fail(descriptor: str) -> None: + """Verify that `descriptor_to_channel` raises an exception when a correctly + formatted descriptor that does not describe a real channel is passed. + + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + # ensure we're receiving the right exception + assert "address" in ex.value.args[0] + assert "channel" in ex.value.args[0] + + +def test_descriptor_to_channel_channel_not_available(the_channel: dch.Channel) -> None: + """Verify that `descriptor_to_channel` raises an exception when a channel + is no longer available. + + :param the_channel: A dragon channel + """ + + # get a good descriptor & wipe out the channel so it can't be attached + descriptor = dragon_util.channel_to_descriptor(the_channel) + the_channel.destroy() + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "address" in ex.value.args[0] + + +def test_descriptor_to_channel_happy_path(the_channel: dch.Channel) -> None: + """Verify that `descriptor_to_channel` works as expected when provided + a valid descriptor + + :param the_channel: A dragon channel + """ + + # get a good descriptor + descriptor = dragon_util.channel_to_descriptor(the_channel) + + reattached = dragon_util.descriptor_to_channel(descriptor) + assert reattached + + # and just make sure creation of the descriptor is transitive + assert dragon_util.channel_to_descriptor(reattached) == descriptor + + +def test_descriptor_to_fli_empty() -> None: + """Verify that `descriptor_to_fli` raises an exception when + provided with an 
empty descriptor.""" + descriptor = "" + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + assert "empty" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()], +) +def test_descriptor_to_fli_b64fail(descriptor: str) -> None: + """Verify that `descriptor_to_fli` raises an exception when + provided with an incorrectly encoded descriptor. + + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + assert "base64" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + [str(uuid.uuid4())], +) +def test_descriptor_to_fli_fli_fail(descriptor: str) -> None: + """Verify that `descriptor_to_fli` raises an exception when a correctly + formatted descriptor that does not describe a real FLI is passed. + + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + # ensure we're receiving the right exception + assert "address" in ex.value.args[0] + assert "fli" in ex.value.args[0].lower() + + +def test_descriptor_to_fli_fli_not_available( + the_fli: fli.FLInterface, the_channel: dch.Channel +) -> None: + """Verify that `descriptor_to_fli` raises an exception when a channel + is no longer available. 
+ + :param the_fli: A dragon FLInterface + :param the_channel: A dragon channel + """ + + # get a good descriptor & wipe out the FLI so it can't be attached + descriptor = dragon_util.channel_to_descriptor(the_fli) + the_fli.destroy() + the_channel.destroy() + + with pytest.raises(SmartSimError) as ex: + dragon_util.descriptor_to_fli(descriptor) + + # ensure we're receiving the right exception + assert "address" in ex.value.args[0] + + +def test_descriptor_to_fli_happy_path(the_fli: dch.Channel) -> None: + """Verify that `descriptor_to_fli` works as expected when provided + a valid descriptor + + :param the_fli: A dragon FLInterface + """ + + # get a good descriptor + descriptor = dragon_util.channel_to_descriptor(the_fli) + + reattached = dragon_util.descriptor_to_fli(descriptor) + assert reattached + + # and just make sure creation of the descriptor is transitive + assert dragon_util.channel_to_descriptor(reattached) == descriptor + + +def test_pool_to_descriptor_empty() -> None: + """Verify that `pool_to_descriptor` raises an exception when + provided with a null pool.""" + + with pytest.raises(ValueError) as ex: + dragon_util.pool_to_descriptor(None) + + +def test_pool_to_happy_path(the_pool) -> None: + """Verify that `pool_to_descriptor` creates a descriptor + when supplied with a valid memory pool.""" + + descriptor = dragon_util.pool_to_descriptor(the_pool) + assert descriptor diff --git a/tests/dragon/test_dragon_ddict_utils.py b/tests/dragon_wlm/test_dragon_ddict_utils.py similarity index 100% rename from tests/dragon/test_dragon_ddict_utils.py rename to tests/dragon_wlm/test_dragon_ddict_utils.py diff --git a/tests/dragon/test_environment_loader.py b/tests/dragon_wlm/test_environment_loader.py similarity index 100% rename from tests/dragon/test_environment_loader.py rename to tests/dragon_wlm/test_environment_loader.py diff --git a/tests/dragon/test_error_handling.py b/tests/dragon_wlm/test_error_handling.py similarity index 100% rename from 
tests/dragon/test_error_handling.py rename to tests/dragon_wlm/test_error_handling.py diff --git a/tests/dragon/test_event_consumer.py b/tests/dragon_wlm/test_event_consumer.py similarity index 100% rename from tests/dragon/test_event_consumer.py rename to tests/dragon_wlm/test_event_consumer.py diff --git a/tests/dragon/test_featurestore.py b/tests/dragon_wlm/test_featurestore.py similarity index 100% rename from tests/dragon/test_featurestore.py rename to tests/dragon_wlm/test_featurestore.py diff --git a/tests/dragon/test_featurestore_base.py b/tests/dragon_wlm/test_featurestore_base.py similarity index 100% rename from tests/dragon/test_featurestore_base.py rename to tests/dragon_wlm/test_featurestore_base.py diff --git a/tests/dragon/test_featurestore_integration.py b/tests/dragon_wlm/test_featurestore_integration.py similarity index 100% rename from tests/dragon/test_featurestore_integration.py rename to tests/dragon_wlm/test_featurestore_integration.py diff --git a/tests/dragon/test_inference_reply.py b/tests/dragon_wlm/test_inference_reply.py similarity index 100% rename from tests/dragon/test_inference_reply.py rename to tests/dragon_wlm/test_inference_reply.py diff --git a/tests/dragon/test_inference_request.py b/tests/dragon_wlm/test_inference_request.py similarity index 100% rename from tests/dragon/test_inference_request.py rename to tests/dragon_wlm/test_inference_request.py diff --git a/tests/dragon/test_protoclient.py b/tests/dragon_wlm/test_protoclient.py similarity index 100% rename from tests/dragon/test_protoclient.py rename to tests/dragon_wlm/test_protoclient.py diff --git a/tests/dragon/test_reply_building.py b/tests/dragon_wlm/test_reply_building.py similarity index 100% rename from tests/dragon/test_reply_building.py rename to tests/dragon_wlm/test_reply_building.py diff --git a/tests/dragon/test_request_dispatcher.py b/tests/dragon_wlm/test_request_dispatcher.py similarity index 95% rename from tests/dragon/test_request_dispatcher.py 
rename to tests/dragon_wlm/test_request_dispatcher.py index 70d73e243f..8dc0f67a31 100644 --- a/tests/dragon/test_request_dispatcher.py +++ b/tests/dragon_wlm/test_request_dispatcher.py @@ -26,7 +26,6 @@ import gc import os -import subprocess as sp import time import typing as t from queue import Empty @@ -34,27 +33,29 @@ import numpy as np import pytest -from . import conftest -from .utils import msg_pump - pytest.importorskip("dragon") # isort: off import dragon + +from dragon.fli import FLInterface +from dragon.data.ddict.ddict import DDict +from dragon.managed_memory import MemoryAlloc + import multiprocessing as mp import torch # isort: on -from dragon import fli -from dragon.data.ddict.ddict import DDict -from dragon.managed_memory import MemoryAlloc from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.control.dragon_util import ( + function_as_dragon_proc, +) from smartsim._core.mli.infrastructure.control.request_dispatcher import ( RequestBatch, RequestDispatcher, @@ -71,6 +72,8 @@ from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker from smartsim.log import get_logger +from .utils.msg_pump import mock_messages + logger = get_logger(__name__) # The tests in this file belong to the dragon group @@ -83,6 +86,7 @@ pass +@pytest.mark.skip("TODO: Fix issue unpickling messages") @pytest.mark.parametrize("num_iterations", [4]) def test_request_dispatcher( num_iterations: int, @@ -96,7 +100,7 @@ def test_request_dispatcher( """ to_worker_channel = create_local() - to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + to_worker_fli = FLInterface(main_ch=to_worker_channel, manager_ch=None) to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli) backbone_fs = BackboneFeatureStore(the_storage, 
allow_reserved_writes=True) @@ -143,8 +147,8 @@ def test_request_dispatcher( callback_channel = DragonCommChannel.from_local() channels.append(callback_channel) - process = conftest.function_as_dragon_proc( - msg_pump.mock_messages, + process = function_as_dragon_proc( + mock_messages, [ worker_queue.descriptor, backbone_fs.descriptor, diff --git a/tests/dragon/test_torch_worker.py b/tests/dragon_wlm/test_torch_worker.py similarity index 100% rename from tests/dragon/test_torch_worker.py rename to tests/dragon_wlm/test_torch_worker.py diff --git a/tests/dragon/test_worker_manager.py b/tests/dragon_wlm/test_worker_manager.py similarity index 98% rename from tests/dragon/test_worker_manager.py rename to tests/dragon_wlm/test_worker_manager.py index 4047a731fc..20370bea7e 100644 --- a/tests/dragon/test_worker_manager.py +++ b/tests/dragon_wlm/test_worker_manager.py @@ -195,9 +195,8 @@ def mock_messages( request_bytes = MessageHandler.serialize_request(request) fli: DragonFLIChannel = worker_queue - with fli._fli.sendh(timeout=None, stream_channel=fli._channel) as sendh: - sendh.send_bytes(request_bytes) - sendh.send_bytes(batch_bytes) + multipart_message = [request_bytes, batch_bytes] + fli.send_multiple(multipart_message) logger.info("published message") diff --git a/tests/dragon/utils/__init__.py b/tests/dragon_wlm/utils/__init__.py similarity index 100% rename from tests/dragon/utils/__init__.py rename to tests/dragon_wlm/utils/__init__.py diff --git a/tests/dragon/utils/channel.py b/tests/dragon_wlm/utils/channel.py similarity index 100% rename from tests/dragon/utils/channel.py rename to tests/dragon_wlm/utils/channel.py diff --git a/tests/dragon/utils/msg_pump.py b/tests/dragon_wlm/utils/msg_pump.py similarity index 100% rename from tests/dragon/utils/msg_pump.py rename to tests/dragon_wlm/utils/msg_pump.py diff --git a/tests/dragon/utils/worker.py b/tests/dragon_wlm/utils/worker.py similarity index 100% rename from tests/dragon/utils/worker.py rename to 
tests/dragon_wlm/utils/worker.py diff --git a/tests/mli/test_integrated_torch_worker.py b/tests/mli/test_integrated_torch_worker.py index 60f1f0c6b9..4d93358bfb 100644 --- a/tests/mli/test_integrated_torch_worker.py +++ b/tests/mli/test_integrated_torch_worker.py @@ -25,22 +25,18 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import pathlib -import typing as t import pytest import torch -# import smartsim.error as sse -# from smartsim._core.mli.infrastructure.control import workermanager as mli -# from smartsim._core.mli.message_handler import MessageHandler -from smartsim._core.utils import installed_redisai_backends - # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b # retrieved from pytest fixtures is_dragon = pytest.test_launcher == "dragon" -torch_available = "torch" in installed_redisai_backends() +torch_available = ( + "torch" in [] +) # todo: update test to replace installed_redisai_backends() @pytest.fixture diff --git a/tests/mli/test_service.py b/tests/mli/test_service.py index 3635f6ff78..41595ca80b 100644 --- a/tests/mli/test_service.py +++ b/tests/mli/test_service.py @@ -255,7 +255,7 @@ def test_service_health_check_freq(health_check_freq: float, run_for: float) -> expected_hc_count = run_for // health_check_freq # allow some wiggle room for frequency comparison - assert expected_hc_count - 1 <= service.num_health_checks <= expected_hc_count + 1 + assert expected_hc_count - 2 <= service.num_health_checks <= expected_hc_count + 2 assert service.num_cooldowns == 0 assert service.num_shutdowns == 1 diff --git a/tests/on_wlm/test_colocated_model.py b/tests/on_wlm/test_colocated_model.py deleted file mode 100644 index 97a47542d7..0000000000 --- a/tests/on_wlm/test_colocated_model.py +++ /dev/null @@ -1,194 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. 
-# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -import sys - -import pytest - -from smartsim import Experiment -from smartsim.entity import Model -from smartsim.status import SmartSimStatus - -if sys.platform == "darwin": - supported_dbs = ["tcp", "deprecated"] -else: - supported_dbs = ["uds", "tcp", "deprecated"] - -# Set to true if DB logs should be generated for debugging -DEBUG_DB = False - -# retrieved from pytest fixtures -launcher = pytest.test_launcher -if launcher not in pytest.wlm_options: - pytestmark = pytest.mark.skip(reason="Not testing WLM integrations") - - -@pytest.mark.parametrize("db_type", supported_dbs) -def test_launch_colocated_model_defaults(fileutils, test_dir, coloutils, db_type): - """Test the launch of a model with a colocated database and local launcher""" - - db_args = {"debug": DEBUG_DB} - - exp = Experiment("colocated_model_defaults", launcher=launcher, exp_path=test_dir) - colo_model = coloutils.setup_test_colo( - fileutils, db_type, exp, "send_data_local_smartredis.py", db_args, on_wlm=True - ) - exp.generate(colo_model) - assert colo_model.run_settings.colocated_db_settings["custom_pinning"] == "0" - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses: {statuses}" - - # test restarting the colocated model - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses: {statuses}" - - -@pytest.mark.parametrize("db_type", supported_dbs) -def test_colocated_model_disable_pinning(fileutils, test_dir, coloutils, db_type): - exp = Experiment( - "colocated_model_pinning_auto_1cpu", launcher=launcher, exp_path=test_dir - ) - db_args = { - "db_cpus": 1, - "custom_pinning": [], - "debug": DEBUG_DB, - } - - # Check to make sure that the CPU mask was correctly generated - colo_model = coloutils.setup_test_colo( - fileutils, db_type, exp, "send_data_local_smartredis.py", 
db_args, on_wlm=True - ) - assert colo_model.run_settings.colocated_db_settings["custom_pinning"] is None - exp.generate(colo_model) - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses: {statuses}" - - -@pytest.mark.parametrize("db_type", supported_dbs) -def test_colocated_model_pinning_auto_2cpu(fileutils, test_dir, coloutils, db_type): - exp = Experiment( - "colocated_model_pinning_auto_2cpu", - launcher=launcher, - exp_path=test_dir, - ) - - db_args = {"db_cpus": 2, "debug": DEBUG_DB} - - # Check to make sure that the CPU mask was correctly generated - colo_model = coloutils.setup_test_colo( - fileutils, db_type, exp, "send_data_local_smartredis.py", db_args, on_wlm=True - ) - assert colo_model.run_settings.colocated_db_settings["custom_pinning"] == "0,1" - exp.generate(colo_model) - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses: {statuses}" - - -@pytest.mark.parametrize("db_type", supported_dbs) -def test_colocated_model_pinning_range(fileutils, test_dir, coloutils, db_type): - # Check to make sure that the CPU mask was correctly generated - # Assume that there are at least 4 cpus on the node - - exp = Experiment( - "colocated_model_pinning_manual", - launcher=launcher, - exp_path=test_dir, - ) - - db_args = {"db_cpus": 4, "custom_pinning": range(4), "debug": DEBUG_DB} - - colo_model = coloutils.setup_test_colo( - fileutils, db_type, exp, "send_data_local_smartredis.py", db_args, on_wlm=True - ) - assert colo_model.run_settings.colocated_db_settings["custom_pinning"] == "0,1,2,3" - exp.generate(colo_model) - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses: {statuses}" - - -@pytest.mark.parametrize("db_type", 
supported_dbs) -def test_colocated_model_pinning_list(fileutils, test_dir, coloutils, db_type): - # Check to make sure that the CPU mask was correctly generated - # note we presume that this has more than 2 CPUs on the supercomputer node - - exp = Experiment( - "colocated_model_pinning_manual", - launcher=launcher, - exp_path=test_dir, - ) - - db_args = {"db_cpus": 2, "custom_pinning": [0, 2]} - - colo_model = coloutils.setup_test_colo( - fileutils, db_type, exp, "send_data_local_smartredis.py", db_args, on_wlm=True - ) - assert colo_model.run_settings.colocated_db_settings["custom_pinning"] == "0,2" - exp.generate(colo_model) - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses: {statuses}" - - -@pytest.mark.parametrize("db_type", supported_dbs) -def test_colocated_model_pinning_mixed(fileutils, test_dir, coloutils, db_type): - # Check to make sure that the CPU mask was correctly generated - # note we presume that this at least 4 CPUs on the supercomputer node - - exp = Experiment( - "colocated_model_pinning_manual", - launcher=launcher, - exp_path=test_dir, - ) - - db_args = {"db_cpus": 2, "custom_pinning": [range(2), 3]} - - colo_model = coloutils.setup_test_colo( - fileutils, db_type, exp, "send_data_local_smartredis.py", db_args, on_wlm=True - ) - assert colo_model.run_settings.colocated_db_settings["custom_pinning"] == "0,1,3" - exp.generate(colo_model) - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses: {statuses}" diff --git a/tests/temp_tests/steps_tests.py b/tests/temp_tests/steps_tests.py new file mode 100644 index 0000000000..bd20607f32 --- /dev/null +++ b/tests/temp_tests/steps_tests.py @@ -0,0 +1,139 @@ +import pytest + +from smartsim._core.launcher.step import ( + AprunStep, + BsubBatchStep, + JsrunStep, + LocalStep, + 
MpiexecStep, + MpirunStep, + OrterunStep, + QsubBatchStep, + SbatchStep, + SrunStep, +) +from smartsim.entity import Model +from smartsim.settings import ( + AprunSettings, + BsubBatchSettings, + JsrunSettings, + MpirunSettings, + OrterunSettings, + QsubBatchSettings, + RunSettings, + SbatchSettings, + SrunSettings, +) + + +# Test creating a job step +@pytest.mark.parametrize( + "settings_type, step_type", + [ + pytest.param( + AprunSettings, + AprunStep, + id=f"aprun", + ), + pytest.param( + JsrunSettings, + JsrunStep, + id=f"jsrun", + ), + pytest.param( + SrunSettings, + SrunStep, + id="srun", + ), + pytest.param( + RunSettings, + LocalStep, + id="local", + ), + ], +) +def test_instantiate_run_settings(settings_type, step_type): + run_settings = settings_type() + run_settings.in_batch = True + model = Model( + exe="echo", exe_args="hello", name="model_name", run_settings=run_settings + ) + jobStep = step_type(entity=model, run_settings=model.run_settings) + assert jobStep.run_settings == run_settings + assert jobStep.entity == model + assert jobStep.entity_name == model.name + assert jobStep.cwd == model.path + assert jobStep.step_settings == model.run_settings + + +# Test creating a mpi job step +@pytest.mark.parametrize( + "settings_type, step_type", + [ + pytest.param( + OrterunSettings, + OrterunStep, + id="orterun", + ), + pytest.param( + MpirunSettings, + MpirunStep, + id="mpirun", + ), + ], +) +def test_instantiate_mpi_run_settings(settings_type, step_type): + run_settings = settings_type(fail_if_missing_exec=False) + run_settings.in_batch = True + model = Model( + exe="echo", exe_args="hello", name="model_name", run_settings=run_settings + ) + jobStep = step_type(entity=model, run_settings=model.run_settings) + assert jobStep.run_settings == run_settings + assert jobStep.entity == model + assert jobStep.entity_name == model.name + assert jobStep.cwd == model.path + assert jobStep.step_settings == model.run_settings + + +# Test creating a batch job step 
+@pytest.mark.parametrize( + "settings_type, batch_settings_type, step_type", + [ + pytest.param( + JsrunSettings, + BsubBatchSettings, + BsubBatchStep, + id=f"bsub", + ), + pytest.param( + SrunSettings, + SbatchSettings, + SbatchStep, + id="sbatch", + ), + pytest.param( + RunSettings, + QsubBatchSettings, + QsubBatchStep, + id="qsub", + ), + ], +) +def test_instantiate_batch_settings(settings_type, batch_settings_type, step_type): + run_settings = settings_type() + run_settings.in_batch = True + batch_settings = batch_settings_type() + model = Application( + exe="echo", + exe_args="hello", + name="model_name", + run_settings=run_settings, + batch_settings=batch_settings, + ) + jobStep = step_type(entity=model, batch_settings=model.batch_settings) + assert jobStep.batch_settings == batch_settings + assert jobStep.entity == model + assert jobStep.entity_name == model.name + assert jobStep.cwd == model.path + assert jobStep.step_settings == model.batch_settings diff --git a/tests/temp_tests/test_colocatedJobGroup.py b/tests/temp_tests/test_colocatedJobGroup.py new file mode 100644 index 0000000000..d6d17fc8ae --- /dev/null +++ b/tests/temp_tests/test_colocatedJobGroup.py @@ -0,0 +1,95 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim.entity.application import Application +from smartsim.launchable.base_job import BaseJob +from smartsim.launchable.colocated_job_group import ColocatedJobGroup +from smartsim.launchable.job import Job +from smartsim.settings import LaunchSettings + +pytestmark = pytest.mark.group_a + +app_1 = Application("app_1", "python") +app_2 = Application("app_2", "python") +app_3 = Application("app_3", "python") + + +class MockJob(BaseJob): + def get_launch_steps(self): + raise NotImplementedError + + +def test_create_ColocatedJobGroup(): + job_1 = MockJob() + job_group = ColocatedJobGroup([job_1]) + assert len(job_group) == 1 + + +def test_getitem_ColocatedJobGroup(): + job_1 = Job(app_1, LaunchSettings("slurm")) + job_2 = Job(app_2, LaunchSettings("slurm")) + job_group = ColocatedJobGroup([job_1, job_2]) + get_value = job_group[0].entity.name + assert get_value == job_1.entity.name + + +def test_setitem_JobGroup(): + job_1 = Job(app_1, LaunchSettings("slurm")) + job_2 = Job(app_2, LaunchSettings("slurm")) + job_group = ColocatedJobGroup([job_1, job_2]) + job_3 = Job(app_3, LaunchSettings("slurm")) + job_group[1] = job_3 + assert len(job_group) == 2 + get_value = 
job_group[1].entity.name + assert get_value == job_3.entity.name + + +def test_delitem_ColocatedJobGroup(): + job_1 = MockJob() + job_2 = MockJob() + job_group = ColocatedJobGroup([job_1, job_2]) + assert len(job_group) == 2 + del job_group[1] + assert len(job_group) == 1 + + +def test_len_ColocatedJobGroup(): + job_1 = MockJob() + job_2 = MockJob() + job_group = ColocatedJobGroup([job_1, job_2]) + assert len(job_group) == 2 + + +def test_insert_ColocatedJobGroup(): + job_1 = MockJob() + job_2 = MockJob() + job_group = ColocatedJobGroup([job_1, job_2]) + job_3 = MockJob() + job_group.insert(0, job_3) + get_value = job_group[0] + assert get_value == job_3 diff --git a/tests/temp_tests/test_core/test_commands/test_command.py b/tests/temp_tests/test_core/test_commands/test_command.py new file mode 100644 index 0000000000..f3d6f6a2a3 --- /dev/null +++ b/tests/temp_tests/test_core/test_commands/test_command.py @@ -0,0 +1,95 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim._core.commands.command import Command + +pytestmark = pytest.mark.group_a + + +def test_command_init(): + cmd = Command(command=["salloc", "-N", "1"]) + assert cmd.command == ["salloc", "-N", "1"] + + +def test_command_invalid_init(): + cmd = Command(command=["salloc", "-N", "1"]) + assert cmd.command == ["salloc", "-N", "1"] + + +def test_command_getitem_int(): + with pytest.raises(TypeError): + _ = Command(command=[1]) + with pytest.raises(TypeError): + _ = Command(command=[]) + + +def test_command_getitem_slice(): + cmd = Command(command=["salloc", "-N", "1"]) + get_value = cmd[0:2] + assert get_value.command == ["salloc", "-N"] + + +def test_command_setitem_int(): + cmd = Command(command=["salloc", "-N", "1"]) + cmd[0] = "srun" + cmd[1] = "-n" + assert cmd.command == ["srun", "-n", "1"] + + +def test_command_setitem_slice(): + cmd = Command(command=["salloc", "-N", "1"]) + cmd[0:2] = ["srun", "-n"] + assert cmd.command == ["srun", "-n", "1"] + + +def test_command_setitem_fail(): + cmd = Command(command=["salloc", "-N", "1"]) + with pytest.raises(TypeError): + cmd[0] = 1 + with pytest.raises(TypeError): + cmd[0:2] = [1, "-n"] + + +def test_command_delitem(): + cmd = Command( + command=["salloc", "-N", "1", "--constraint", "P100"], + ) + del cmd.command[3] + del cmd.command[3] + assert cmd.command == ["salloc", "-N", "1"] + + +def test_command_len(): + cmd = Command(command=["salloc", "-N", "1"]) + assert 
len(cmd) == 3
+ +import pytest + +from smartsim._core.commands.command import Command +from smartsim._core.commands.command_list import CommandList +from smartsim.settings.launch_command import LauncherType + +pytestmark = pytest.mark.group_a + +salloc_cmd = Command(command=["salloc", "-N", "1"]) +srun_cmd = Command(command=["srun", "-n", "1"]) +sacct_cmd = Command(command=["sacct", "--user"]) + + +def test_command_init(): + cmd_list = CommandList(commands=[salloc_cmd, srun_cmd]) + assert cmd_list.commands == [salloc_cmd, srun_cmd] + + +def test_command_getitem_int(): + cmd_list = CommandList(commands=[salloc_cmd, srun_cmd]) + get_value = cmd_list[0] + assert get_value == salloc_cmd + + +def test_command_getitem_slice(): + cmd_list = CommandList(commands=[salloc_cmd, srun_cmd]) + get_value = cmd_list[0:2] + assert get_value == [salloc_cmd, srun_cmd] + + +def test_command_setitem_idx(): + cmd_list = CommandList(commands=[salloc_cmd, srun_cmd]) + cmd_list[0] = sacct_cmd + for cmd in cmd_list.commands: + assert cmd.command in [sacct_cmd.command, srun_cmd.command] + + +def test_command_setitem_slice(): + cmd_list = CommandList(commands=[srun_cmd, srun_cmd]) + cmd_list[0:2] = [sacct_cmd, sacct_cmd] + for cmd in cmd_list.commands: + assert cmd.command == sacct_cmd.command + + +def test_command_setitem_fail(): + cmd_list = CommandList(commands=[srun_cmd, srun_cmd]) + with pytest.raises(TypeError): + cmd_list[0] = "fail" + with pytest.raises(TypeError): + cmd_list[0:1] = "fail" + with pytest.raises(TypeError): + cmd_list[0:1] = "fail" + with pytest.raises(TypeError): + _ = Command(command=["salloc", "-N", 1]) + with pytest.raises(TypeError): + cmd_list[0:1] = [Command(command=["salloc", "-N", "1"]), Command(command=1)] + + +def test_command_delitem(): + cmd_list = CommandList(commands=[salloc_cmd, srun_cmd]) + del cmd_list.commands[0] + assert cmd_list.commands == [srun_cmd] + + +def test_command_len(): + cmd_list = CommandList(commands=[salloc_cmd, srun_cmd]) + assert len(cmd_list) is 
2 + + +def test_command_insert(): + cmd_list = CommandList(commands=[salloc_cmd, srun_cmd]) + cmd_list.insert(0, sacct_cmd) + assert cmd_list.commands == [sacct_cmd, salloc_cmd, srun_cmd] diff --git a/tests/temp_tests/test_core/test_commands/test_launchCommands.py b/tests/temp_tests/test_core/test_commands/test_launchCommands.py new file mode 100644 index 0000000000..60bfe4b279 --- /dev/null +++ b/tests/temp_tests/test_core/test_commands/test_launchCommands.py @@ -0,0 +1,52 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +from smartsim._core.commands.command import Command +from smartsim._core.commands.command_list import CommandList +from smartsim._core.commands.launch_commands import LaunchCommands +from smartsim.settings.launch_command import LauncherType + +pytestmark = pytest.mark.group_a + +pre_cmd = Command(command=["pre", "cmd"]) +launch_cmd = Command(command=["launch", "cmd"]) +post_cmd = Command(command=["post", "cmd"]) +pre_commands_list = CommandList(commands=[pre_cmd]) +launch_command_list = CommandList(commands=[launch_cmd]) +post_command_list = CommandList(commands=[post_cmd]) + + +def test_launchCommand_init(): + launch_cmd = LaunchCommands( + prelaunch_commands=pre_commands_list, + launch_commands=launch_command_list, + postlaunch_commands=post_command_list, + ) + assert launch_cmd.prelaunch_command == pre_commands_list + assert launch_cmd.launch_command == launch_command_list + assert launch_cmd.postlaunch_command == post_command_list diff --git a/tests/temp_tests/test_jobGroup.py b/tests/temp_tests/test_jobGroup.py new file mode 100644 index 0000000000..f735162609 --- /dev/null +++ b/tests/temp_tests/test_jobGroup.py @@ -0,0 +1,110 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim.entity.application import Application +from smartsim.launchable.base_job import BaseJob +from smartsim.launchable.job import Job +from smartsim.launchable.job_group import JobGroup +from smartsim.settings.launch_settings import LaunchSettings + +pytestmark = pytest.mark.group_a + +app_1 = Application("app_1", "python") +app_2 = Application("app_2", "python") +app_3 = Application("app_3", "python") + + +class MockJob(BaseJob): + def get_launch_steps(self): + raise NotImplementedError + + +def test_invalid_job_name(wlmutils): + job_1 = Job(app_1, LaunchSettings("slurm")) + job_2 = Job(app_2, LaunchSettings("slurm")) + with pytest.raises(ValueError): + _ = JobGroup([job_1, job_2], name="name/not/allowed") + + +def test_create_JobGroup(): + job_1 = MockJob() + job_group = JobGroup([job_1]) + assert len(job_group) == 1 + + +def test_name_setter(wlmutils): + job_1 = Job(app_1, LaunchSettings("slurm")) + job_2 = Job(app_2, LaunchSettings("slurm")) + job_group = JobGroup([job_1, job_2]) + job_group.name = "new_name" + assert job_group.name == "new_name" + + +def test_getitem_JobGroup(wlmutils): + job_1 = Job(app_1, LaunchSettings("slurm")) + job_2 = Job(app_2, 
LaunchSettings("slurm")) + job_group = JobGroup([job_1, job_2]) + get_value = job_group[0].entity.name + assert get_value == job_1.entity.name + + +def test_setitem_JobGroup(wlmutils): + job_1 = Job(app_1, LaunchSettings("slurm")) + job_2 = Job(app_2, LaunchSettings("slurm")) + job_group = JobGroup([job_1, job_2]) + job_3 = Job(app_3, LaunchSettings("slurm")) + job_group[1] = job_3 + assert len(job_group) == 2 + get_value = job_group[1] + assert get_value.entity.name == job_3.entity.name + + +def test_delitem_JobGroup(): + job_1 = MockJob() + job_2 = MockJob() + job_group = JobGroup([job_1, job_2]) + assert len(job_group) == 2 + del job_group[1] + assert len(job_group) == 1 + + +def test_len_JobGroup(): + job_1 = MockJob() + job_2 = MockJob() + job_group = JobGroup([job_1, job_2]) + assert len(job_group) == 2 + + +def test_insert_JobGroup(): + job_1 = MockJob() + job_2 = MockJob() + job_group = JobGroup([job_1, job_2]) + job_3 = MockJob() + job_group.insert(0, job_3) + get_value = job_group[0] + assert get_value == job_3 diff --git a/tests/temp_tests/test_launchable.py b/tests/temp_tests/test_launchable.py new file mode 100644 index 0000000000..de7d12e60e --- /dev/null +++ b/tests/temp_tests/test_launchable.py @@ -0,0 +1,306 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim import entity +from smartsim._core.utils import helpers +from smartsim.entity.application import Application +from smartsim.entity.entity import SmartSimEntity +from smartsim.error.errors import SSUnsupportedError +from smartsim.launchable import Job, Launchable +from smartsim.launchable.launchable import SmartSimObject +from smartsim.launchable.mpmd_job import MPMDJob +from smartsim.launchable.mpmd_pair import MPMDPair +from smartsim.settings import LaunchSettings + +pytestmark = pytest.mark.group_a + + +class EchoHelloWorldEntity(entity.SmartSimEntity): + """A simple smartsim entity""" + + def __init__(self): + super().__init__("test-entity") + + def __eq__(self, other): + if type(self) is not type(other): + return NotImplemented + return self.as_executable_sequence() == other.as_executable_sequence() + + def as_executable_sequence(self): + return (helpers.expand_exe_path("echo"), "Hello", "World!") + + +def test_smartsimobject_init(): + ss_object = SmartSimObject() + assert isinstance(ss_object, SmartSimObject) + + +def test_launchable_init(): + launchable = Launchable() + assert isinstance(launchable, Launchable) + + +def test_invalid_job_name(wlmutils): 
+ entity = Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ) + + settings = LaunchSettings(wlmutils.get_test_launcher()) + with pytest.raises(ValueError): + _ = Job(entity, settings, name="path/to/name") + + +def test_job_init(): + entity = Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ) + job = Job(entity, LaunchSettings("slurm")) + assert isinstance(job, Job) + assert job.entity.name == "test_name" + assert "echo" in job.entity.exe + assert "spam" in job.entity.exe_args + assert "eggs" in job.entity.exe_args + + +def test_name_setter(): + entity = Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ) + job = Job(entity, LaunchSettings("slurm")) + job.name = "new_name" + assert job.name == "new_name" + + +def test_job_init_deepcopy(): + entity = Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ) + settings = LaunchSettings("slurm") + job = Job(entity, settings) + test = job.launch_settings.launcher + test = "test_change" + assert job.launch_settings.launcher is not test + + +def test_job_type_entity(): + entity = "invalid" + settings = LaunchSettings("slurm") + with pytest.raises( + TypeError, + match="entity argument was not of type SmartSimEntity", + ): + Job(entity, settings) + + +def test_job_type_launch_settings(): + entity = Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ) + settings = "invalid" + + with pytest.raises( + TypeError, + match="launch_settings argument was not of type LaunchSettings", + ): + Job(entity, settings) + + +def test_add_mpmd_pair(): + entity = EchoHelloWorldEntity() + + mpmd_job = MPMDJob() + mpmd_job.add_mpmd_pair(entity, LaunchSettings("slurm")) + mpmd_pair = MPMDPair(entity, LaunchSettings("slurm")) + + assert len(mpmd_job.mpmd_pairs) == 1 + assert str(mpmd_pair.entity) == str(mpmd_job.mpmd_pairs[0].entity) + assert str(mpmd_pair.launch_settings) == str(mpmd_job.mpmd_pairs[0].launch_settings) + + +def 
test_mpmdpair_init(): + """Test the creation of an MPMDPair""" + entity = Application( + "test_name", + "echo", + exe_args=["spam", "eggs"], + ) + mpmd_pair = MPMDPair(entity, LaunchSettings("slurm")) + assert isinstance(mpmd_pair, MPMDPair) + assert mpmd_pair.entity.name == "test_name" + assert "echo" in mpmd_pair.entity.exe + assert "spam" in mpmd_pair.entity.exe_args + assert "eggs" in mpmd_pair.entity.exe_args + + +def test_mpmdpair_init_deepcopy(): + """Test the creation of an MPMDPair""" + entity = Application( + "test_name", + "echo", + exe_args=["spam", "eggs"], + ) + settings = LaunchSettings("slurm") + mpmd_pair = MPMDPair(entity, settings) + test = mpmd_pair.launch_settings.launcher + test = "change" + assert test not in mpmd_pair.launch_settings.launcher + + +def test_check_launcher(): + """Test that mpmd pairs that have the same launcher type can be added to an MPMD Job""" + + entity1 = Application( + "entity1", + "echo", + exe_args=["hello", "world"], + ) + launch_settings1 = LaunchSettings("slurm") + entity2 = Application( + "entity2", + "echo", + exe_args=["hello", "world"], + ) + launch_settings2 = LaunchSettings("slurm") + mpmd_pairs = [] + + pair1 = MPMDPair(entity1, launch_settings1) + mpmd_pairs.append(pair1) + mpmd_job = MPMDJob(mpmd_pairs) + # Add a second mpmd pair to the mpmd job + mpmd_job.add_mpmd_pair(entity2, launch_settings2) + + assert str(mpmd_job.mpmd_pairs[0].entity.name) == "entity1" + assert str(mpmd_job.mpmd_pairs[1].entity.name) == "entity2" + + +def test_add_mpmd_pair_check_launcher_error(): + """Test that an error is raised when a pairs is added to an mpmd + job using add_mpmd_pair that does not have the same launcher type""" + mpmd_pairs = [] + entity1 = EchoHelloWorldEntity() + launch_settings1 = LaunchSettings("slurm") + + entity2 = EchoHelloWorldEntity() + launch_settings2 = LaunchSettings("pals") + + pair1 = MPMDPair(entity1, launch_settings1) + mpmd_pairs.append(pair1) + mpmd_job = MPMDJob(mpmd_pairs) + + # Add a second 
mpmd pair to the mpmd job with a different launcher + with pytest.raises(SSUnsupportedError): + mpmd_job.add_mpmd_pair(entity2, launch_settings2) + + +def test_add_mpmd_pair_check_entity(): + """Test that mpmd pairs that have the same entity type can be added to an MPMD Job""" + mpmd_pairs = [] + entity1 = Application("entity1", "python") + launch_settings1 = LaunchSettings("slurm") + + entity2 = Application("entity2", "python") + launch_settings2 = LaunchSettings("slurm") + + pair1 = MPMDPair(entity1, launch_settings1) + mpmd_pairs.append(pair1) + mpmd_job = MPMDJob(mpmd_pairs) + + # Add a second mpmd pair to the mpmd job + mpmd_job.add_mpmd_pair(entity2, launch_settings2) + + assert isinstance(mpmd_job, MPMDJob) + + +def test_add_mpmd_pair_check_entity_error(): + """Test that an error is raised when a pairs is added to an mpmd job + using add_mpmd_pair that does not have the same entity type""" + mpmd_pairs = [] + entity1 = Application("entity1", "python") + launch_settings1 = LaunchSettings("slurm") + + entity2 = Application("entity2", "python") + launch_settings2 = LaunchSettings("pals") + + pair1 = MPMDPair(entity1, launch_settings1) + mpmd_pairs.append(pair1) + mpmd_job = MPMDJob(mpmd_pairs) + + with pytest.raises(SSUnsupportedError) as ex: + mpmd_job.add_mpmd_pair(entity2, launch_settings2) + assert "MPMD pairs must all share the same entity type." 
in ex.value.args[0] + + +def test_create_mpmdjob_invalid_mpmdpairs(): + """Test that an error is raised when a pairs is added to an mpmd job that + does not have the same launcher type""" + + mpmd_pairs = [] + entity1 = Application("entity1", "python") + launch_settings1 = LaunchSettings("slurm") + + entity1 = Application("entity1", "python") + launch_settings2 = LaunchSettings("pals") + + pair1 = MPMDPair(entity1, launch_settings1) + pair2 = MPMDPair(entity1, launch_settings2) + + mpmd_pairs.append(pair1) + mpmd_pairs.append(pair2) + + with pytest.raises(SSUnsupportedError) as ex: + MPMDJob(mpmd_pairs) + assert "MPMD pairs must all share the same launcher." in ex.value.args[0] + + +def test_create_mpmdjob_valid_mpmdpairs(): + """Test that all pairs have the same entity type is enforced when creating an MPMDJob""" + + mpmd_pairs = [] + entity1 = Application("entity1", "python") + launch_settings1 = LaunchSettings("slurm") + entity1 = Application("entity1", "python") + launch_settings2 = LaunchSettings("slurm") + + pair1 = MPMDPair(entity1, launch_settings1) + pair2 = MPMDPair(entity1, launch_settings2) + + mpmd_pairs.append(pair1) + mpmd_pairs.append(pair2) + mpmd_job = MPMDJob(mpmd_pairs) + + assert isinstance(mpmd_job, MPMDJob) diff --git a/tests/temp_tests/test_settings/conftest.py b/tests/temp_tests/test_settings/conftest.py new file mode 100644 index 0000000000..8697b15108 --- /dev/null +++ b/tests/temp_tests/test_settings/conftest.py @@ -0,0 +1,61 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from smartsim._core.utils.launcher import LauncherProtocol, create_job_id +from smartsim.settings.arguments import launch_arguments as launch + + +@pytest.fixture +def mock_launch_args(): + class _MockLaunchArgs(launch.LaunchArguments): + def set(self, arg, val): ... 
+ def launcher_str(self): + return "mock-laucnh-args" + + yield _MockLaunchArgs({}) + + +@pytest.fixture +def mock_launcher(): + class _MockLauncher(LauncherProtocol): + __hash__ = object.__hash__ + + def start(self, launchable): + return create_job_id() + + @classmethod + def create(cls, exp): + return cls() + + def get_status(self, *ids): + raise NotImplementedError + + def stop_jobs(self, *ids): + raise NotImplementedError + + yield _MockLauncher() diff --git a/tests/temp_tests/test_settings/test_alpsLauncher.py b/tests/temp_tests/test_settings/test_alpsLauncher.py new file mode 100644 index 0000000000..5abfbb9c76 --- /dev/null +++ b/tests/temp_tests/test_settings/test_alpsLauncher.py @@ -0,0 +1,232 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import io +import os +import pathlib + +import pytest + +from smartsim._core.shell.shell_launcher import ShellLauncherCommand +from smartsim.settings import LaunchSettings +from smartsim.settings.arguments.launch.alps import ( + AprunLaunchArguments, + _as_aprun_command, +) +from smartsim.settings.launch_command import LauncherType + +pytestmark = pytest.mark.group_a + + +def test_launcher_str(): + """Ensure launcher_str returns appropriate value""" + alpsLauncher = LaunchSettings(launcher=LauncherType.Alps) + assert alpsLauncher.launch_args.launcher_str() == LauncherType.Alps.value + + +@pytest.mark.parametrize( + "function,value,result,flag", + [ + pytest.param( + "set_cpus_per_task", (4,), "4", "cpus-per-pe", id="set_cpus_per_task" + ), + pytest.param("set_tasks", (4,), "4", "pes", id="set_tasks"), + pytest.param( + "set_tasks_per_node", (4,), "4", "pes-per-node", id="set_tasks_per_node" + ), + pytest.param( + "set_hostlist", ("host_A",), "host_A", "node-list", id="set_hostlist_str" + ), + pytest.param( + "set_hostlist", + (["host_A", "host_B"],), + "host_A,host_B", + "node-list", + id="set_hostlist_list[str]", + ), + pytest.param( + "set_hostlist_from_file", + ("./path/to/hostfile",), + "./path/to/hostfile", + "node-list-file", + id="set_hostlist_from_file", + ), + pytest.param( + "set_excluded_hosts", + ("host_A",), + "host_A", + "exclude-node-list", + id="set_excluded_hosts_str", + ), + pytest.param( + "set_excluded_hosts", + 
(["host_A", "host_B"],), + "host_A,host_B", + "exclude-node-list", + id="set_excluded_hosts_list[str]", + ), + pytest.param( + "set_cpu_bindings", (4,), "4", "cpu-binding", id="set_cpu_bindings" + ), + pytest.param( + "set_cpu_bindings", + ([4, 4],), + "4,4", + "cpu-binding", + id="set_cpu_bindings_list[str]", + ), + pytest.param( + "set_memory_per_node", + (8000,), + "8000", + "memory-per-pe", + id="set_memory_per_node", + ), + pytest.param( + "set_walltime", + ("10:00:00",), + "10:00:00", + "cpu-time-limit", + id="set_walltime", + ), + pytest.param( + "set_verbose_launch", (True,), "7", "debug", id="set_verbose_launch" + ), + pytest.param("set_quiet_launch", (True,), None, "quiet", id="set_quiet_launch"), + ], +) +def test_alps_class_methods(function, value, flag, result): + alpsLauncher = LaunchSettings(launcher=LauncherType.Alps) + assert isinstance(alpsLauncher._arguments, AprunLaunchArguments) + getattr(alpsLauncher.launch_args, function)(*value) + assert alpsLauncher.launch_args._launch_args[flag] == result + + +def test_set_verbose_launch(): + alpsLauncher = LaunchSettings(launcher=LauncherType.Alps) + assert isinstance(alpsLauncher._arguments, AprunLaunchArguments) + alpsLauncher.launch_args.set_verbose_launch(True) + assert alpsLauncher.launch_args._launch_args == {"debug": "7"} + alpsLauncher.launch_args.set_verbose_launch(False) + assert alpsLauncher.launch_args._launch_args == {} + + +def test_set_quiet_launch(): + aprunLauncher = LaunchSettings(launcher=LauncherType.Alps) + assert isinstance(aprunLauncher._arguments, AprunLaunchArguments) + aprunLauncher.launch_args.set_quiet_launch(True) + assert aprunLauncher.launch_args._launch_args == {"quiet": None} + aprunLauncher.launch_args.set_quiet_launch(False) + assert aprunLauncher.launch_args._launch_args == {} + + +def test_format_env_vars(): + env_vars = {"OMP_NUM_THREADS": "20", "LOGGING": "verbose"} + aprunLauncher = LaunchSettings(launcher=LauncherType.Alps, env_vars=env_vars) + assert 
isinstance(aprunLauncher._arguments, AprunLaunchArguments) + aprunLauncher.update_env({"OMP_NUM_THREADS": "10"}) + formatted = aprunLauncher._arguments.format_env_vars(aprunLauncher._env_vars) + result = ["-e", "OMP_NUM_THREADS=10", "-e", "LOGGING=verbose"] + assert formatted == result + + +def test_aprun_settings(): + aprunLauncher = LaunchSettings(launcher=LauncherType.Alps) + aprunLauncher.launch_args.set_cpus_per_task(2) + aprunLauncher.launch_args.set_tasks(100) + aprunLauncher.launch_args.set_tasks_per_node(20) + formatted = aprunLauncher._arguments.format_launch_args() + result = ["--cpus-per-pe=2", "--pes=100", "--pes-per-node=20"] + assert formatted == result + + +def test_invalid_hostlist_format(): + """Test invalid hostlist formats""" + alpsLauncher = LaunchSettings(launcher=LauncherType.Alps) + with pytest.raises(TypeError): + alpsLauncher.launch_args.set_hostlist(["test", 5]) + with pytest.raises(TypeError): + alpsLauncher.launch_args.set_hostlist([5]) + with pytest.raises(TypeError): + alpsLauncher.launch_args.set_hostlist(5) + + +def test_invalid_exclude_hostlist_format(): + """Test invalid hostlist formats""" + alpsLauncher = LaunchSettings(launcher=LauncherType.Alps) + with pytest.raises(TypeError): + alpsLauncher.launch_args.set_excluded_hosts(["test", 5]) + with pytest.raises(TypeError): + alpsLauncher.launch_args.set_excluded_hosts([5]) + with pytest.raises(TypeError): + alpsLauncher.launch_args.set_excluded_hosts(5) + + +@pytest.mark.parametrize( + "args, expected", + ( + pytest.param({}, ("aprun", "--", "echo", "hello", "world"), id="Empty Args"), + pytest.param( + {"N": "1"}, + ("aprun", "-N", "1", "--", "echo", "hello", "world"), + id="Short Arg", + ), + pytest.param( + {"cpus-per-pe": "1"}, + ("aprun", "--cpus-per-pe=1", "--", "echo", "hello", "world"), + id="Long Arg", + ), + pytest.param( + {"q": None}, + ("aprun", "-q", "--", "echo", "hello", "world"), + id="Short Arg (No Value)", + ), + pytest.param( + {"quiet": None}, + ("aprun", 
"--quiet", "--", "echo", "hello", "world"), + id="Long Arg (No Value)", + ), + pytest.param( + {"N": "1", "cpus-per-pe": "123"}, + ("aprun", "-N", "1", "--cpus-per-pe=123", "--", "echo", "hello", "world"), + id="Short and Long Args", + ), + ), +) +def test_formatting_launch_args(args, expected, test_dir): + out = os.path.join(test_dir, "out.txt") + err = os.path.join(test_dir, "err.txt") + open(out, "w"), open(err, "w") + shell_launch_cmd = _as_aprun_command( + AprunLaunchArguments(args), ("echo", "hello", "world"), test_dir, {}, out, err + ) + assert isinstance(shell_launch_cmd, ShellLauncherCommand) + assert shell_launch_cmd.command_tuple == expected + assert shell_launch_cmd.path == pathlib.Path(test_dir) + assert shell_launch_cmd.env == {} + assert isinstance(shell_launch_cmd.stdout, io.TextIOWrapper) + assert shell_launch_cmd.stdout.name == out + assert isinstance(shell_launch_cmd.stderr, io.TextIOWrapper) + assert shell_launch_cmd.stderr.name == err diff --git a/tests/temp_tests/test_settings/test_batchSettings.py b/tests/temp_tests/test_settings/test_batchSettings.py new file mode 100644 index 0000000000..37fd3a33f2 --- /dev/null +++ b/tests/temp_tests/test_settings/test_batchSettings.py @@ -0,0 +1,80 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pytest + +from smartsim.settings import BatchSettings +from smartsim.settings.batch_command import BatchSchedulerType + +pytestmark = pytest.mark.group_a + + +@pytest.mark.parametrize( + "scheduler_enum,formatted_batch_args", + [ + pytest.param( + BatchSchedulerType.Slurm, ["--launch=var", "--nodes=1"], id="slurm" + ), + pytest.param( + BatchSchedulerType.Pbs, ["-l", "nodes=1", "-launch", "var"], id="pbs" + ), + pytest.param( + BatchSchedulerType.Lsf, ["-launch", "var", "-nnodes", "1"], id="lsf" + ), + ], +) +def test_create_scheduler_settings(scheduler_enum, formatted_batch_args): + bs_str = BatchSettings( + batch_scheduler=scheduler_enum.value, + batch_args={"launch": "var"}, + env_vars={"ENV": "VAR"}, + ) + bs_str.batch_args.set_nodes(1) + assert bs_str._batch_scheduler == scheduler_enum + assert bs_str._env_vars == {"ENV": "VAR"} + print(bs_str.format_batch_args()) + assert bs_str.format_batch_args() == formatted_batch_args + + bs_enum = BatchSettings( + batch_scheduler=scheduler_enum, + batch_args={"launch": "var"}, + env_vars={"ENV": "VAR"}, + ) + bs_enum.batch_args.set_nodes(1) + assert bs_enum._batch_scheduler == scheduler_enum + assert bs_enum._env_vars == {"ENV": "VAR"} + 
assert bs_enum.format_batch_args() == formatted_batch_args + + +def test_launcher_property(): + bs = BatchSettings(batch_scheduler="slurm") + assert bs.batch_scheduler == "slurm" + + +def test_env_vars_property(): + bs = BatchSettings(batch_scheduler="slurm", env_vars={"ENV": "VAR"}) + assert bs.env_vars == {"ENV": "VAR"} + ref = bs.env_vars + assert ref is bs.env_vars diff --git a/tests/temp_tests/test_settings/test_common.py b/tests/temp_tests/test_settings/test_common.py new file mode 100644 index 0000000000..17ca66c040 --- /dev/null +++ b/tests/temp_tests/test_settings/test_common.py @@ -0,0 +1,39 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pytest + +from smartsim.settings.common import set_check_input + +pytestmark = pytest.mark.group_a + + +def test_check_set_raise_error(): + with pytest.raises(TypeError): + set_check_input(key="test", value=3) + with pytest.raises(TypeError): + set_check_input(key=3, value="str") + with pytest.raises(TypeError): + set_check_input(key=2, value=None) diff --git a/tests/temp_tests/test_settings/test_dispatch.py b/tests/temp_tests/test_settings/test_dispatch.py new file mode 100644 index 0000000000..89303b5a37 --- /dev/null +++ b/tests/temp_tests/test_settings/test_dispatch.py @@ -0,0 +1,419 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import abc +import contextlib +import dataclasses +import io +import sys + +import pytest + +from smartsim._core import dispatch +from smartsim._core.utils.launcher import LauncherProtocol, create_job_id +from smartsim.error import errors + +pytestmark = pytest.mark.group_a + +FORMATTED = object() + + +def format_fn(args, exe, env): + return FORMATTED + + +@pytest.fixture +def expected_dispatch_registry(mock_launcher, mock_launch_args): + yield { + type(mock_launch_args): dispatch._DispatchRegistration( + format_fn, type(mock_launcher) + ) + } + + +def test_declaritive_form_dispatch_declaration( + mock_launcher, mock_launch_args, expected_dispatch_registry +): + d = dispatch.Dispatcher() + assert type(mock_launch_args) == d.dispatch( + with_format=format_fn, to_launcher=type(mock_launcher) + )(type(mock_launch_args)) + assert d._dispatch_registry == expected_dispatch_registry + + +def test_imperative_form_dispatch_declaration( + mock_launcher, mock_launch_args, expected_dispatch_registry +): + d = dispatch.Dispatcher() + assert None == d.dispatch( + type(mock_launch_args), to_launcher=type(mock_launcher), with_format=format_fn + ) + assert d._dispatch_registry == 
expected_dispatch_registry + + +def test_dispatchers_from_same_registry_do_not_cross_polute( + mock_launcher, mock_launch_args, expected_dispatch_registry +): + some_starting_registry = {} + d1 = dispatch.Dispatcher(dispatch_registry=some_starting_registry) + d2 = dispatch.Dispatcher(dispatch_registry=some_starting_registry) + assert ( + d1._dispatch_registry == d2._dispatch_registry == some_starting_registry == {} + ) + assert ( + d1._dispatch_registry is not d2._dispatch_registry is not some_starting_registry + ) + + d2.dispatch( + type(mock_launch_args), with_format=format_fn, to_launcher=type(mock_launcher) + ) + assert d1._dispatch_registry == {} + assert d2._dispatch_registry == expected_dispatch_registry + + +def test_copied_dispatchers_do_not_cross_pollute( + mock_launcher, mock_launch_args, expected_dispatch_registry +): + some_starting_registry = {} + d1 = dispatch.Dispatcher(dispatch_registry=some_starting_registry) + d2 = d1.copy() + assert ( + d1._dispatch_registry == d2._dispatch_registry == some_starting_registry == {} + ) + assert ( + d1._dispatch_registry is not d2._dispatch_registry is not some_starting_registry + ) + + d2.dispatch( + type(mock_launch_args), to_launcher=type(mock_launcher), with_format=format_fn + ) + assert d1._dispatch_registry == {} + assert d2._dispatch_registry == expected_dispatch_registry + + +@pytest.mark.parametrize( + "add_dispatch, expected_ctx", + ( + pytest.param( + lambda d, s, l: d.dispatch(s, to_launcher=l, with_format=format_fn), + pytest.raises(TypeError, match="has already been registered"), + id="Imperative -- Disallowed implicitly", + ), + pytest.param( + lambda d, s, l: d.dispatch( + s, to_launcher=l, with_format=format_fn, allow_overwrite=True + ), + contextlib.nullcontext(), + id="Imperative -- Allowed with flag", + ), + pytest.param( + lambda d, s, l: d.dispatch(to_launcher=l, with_format=format_fn)(s), + pytest.raises(TypeError, match="has already been registered"), + id="Declarative -- Disallowed 
implicitly", + ), + pytest.param( + lambda d, s, l: d.dispatch( + to_launcher=l, with_format=format_fn, allow_overwrite=True + )(s), + contextlib.nullcontext(), + id="Declarative -- Allowed with flag", + ), + ), +) +def test_dispatch_overwriting( + add_dispatch, + expected_ctx, + mock_launcher, + mock_launch_args, + expected_dispatch_registry, +): + d = dispatch.Dispatcher(dispatch_registry=expected_dispatch_registry) + with expected_ctx: + add_dispatch(d, type(mock_launch_args), type(mock_launcher)) + + +@pytest.mark.parametrize( + "type_or_instance", + ( + pytest.param(type, id="type"), + pytest.param(lambda x: x, id="instance"), + ), +) +def test_dispatch_can_retrieve_dispatch_info_from_dispatch_registry( + expected_dispatch_registry, mock_launcher, mock_launch_args, type_or_instance +): + d = dispatch.Dispatcher(dispatch_registry=expected_dispatch_registry) + assert dispatch._DispatchRegistration( + format_fn, type(mock_launcher) + ) == d.get_dispatch(type_or_instance(mock_launch_args)) + + +@pytest.mark.parametrize( + "type_or_instance", + ( + pytest.param(type, id="type"), + pytest.param(lambda x: x, id="instance"), + ), +) +def test_dispatch_raises_if_settings_type_not_registered( + mock_launch_args, type_or_instance +): + d = dispatch.Dispatcher(dispatch_registry={}) + with pytest.raises( + TypeError, match="No dispatch for `.+?(?=`)` has been registered" + ): + d.get_dispatch(type_or_instance(mock_launch_args)) + + +class LauncherABC(abc.ABC): + @abc.abstractmethod + def start(self, launchable): ... + @classmethod + @abc.abstractmethod + def create(cls, exp): ... 
+ + +class PartImplLauncherABC(LauncherABC): + def start(self, launchable): + return create_job_id() + + +class FullImplLauncherABC(PartImplLauncherABC): + @classmethod + def create(cls, exp): + return cls() + + +@pytest.mark.parametrize( + "cls, ctx", + ( + pytest.param( + LauncherProtocol, + pytest.raises(TypeError, match="Cannot dispatch to protocol"), + id="Cannot dispatch to protocol class", + ), + pytest.param( + "mock_launcher", + contextlib.nullcontext(None), + id="Can dispatch to protocol implementation", + ), + pytest.param( + LauncherABC, + pytest.raises(TypeError, match="Cannot dispatch to abstract class"), + id="Cannot dispatch to abstract class", + ), + pytest.param( + PartImplLauncherABC, + pytest.raises(TypeError, match="Cannot dispatch to abstract class"), + id="Cannot dispatch to partially implemented abstract class", + ), + pytest.param( + FullImplLauncherABC, + contextlib.nullcontext(None), + id="Can dispatch to fully implemented abstract class", + ), + ), +) +def test_register_dispatch_to_launcher_types(request, cls, ctx): + if isinstance(cls, str): + cls = request.getfixturevalue(cls) + d = dispatch.Dispatcher() + with ctx: + d.dispatch(to_launcher=cls, with_format=format_fn) + + +@dataclasses.dataclass(frozen=True) +class BufferWriterLauncher(LauncherProtocol[list[str]]): + buf: io.StringIO + + if sys.version_info < (3, 10): + __hash__ = object.__hash__ + + @classmethod + def create(cls, exp): + return cls(io.StringIO()) + + def start(self, strs): + self.buf.writelines(f"{s}\n" for s in strs) + return create_job_id() + + def get_status(self, *ids): + raise NotImplementedError + + def stop_jobs(self, *ids): + raise NotImplementedError + + +class BufferWriterLauncherSubclass(BufferWriterLauncher): ... 
+ + +@pytest.fixture +def buffer_writer_dispatch(): + stub_format_fn = lambda *a, **kw: ["some", "strings"] + return dispatch._DispatchRegistration(stub_format_fn, BufferWriterLauncher) + + +@pytest.mark.parametrize( + "input_, map_, expected", + ( + pytest.param( + ["list", "of", "strings"], + lambda xs: xs, + ["list\n", "of\n", "strings\n"], + id="[str] -> [str]", + ), + pytest.param( + "words on new lines", + lambda x: x.split(), + ["words\n", "on\n", "new\n", "lines\n"], + id="str -> [str]", + ), + pytest.param( + range(1, 4), + lambda xs: [str(x) for x in xs], + ["1\n", "2\n", "3\n"], + id="[int] -> [str]", + ), + ), +) +def test_launcher_adapter_correctly_adapts_input_to_launcher(input_, map_, expected): + buf = io.StringIO() + adapter = dispatch._LauncherAdapter(BufferWriterLauncher(buf), map_) + adapter.start(input_) + buf.seek(0) + assert buf.readlines() == expected + + +@pytest.mark.parametrize( + "launcher_instance, ctx", + ( + pytest.param( + BufferWriterLauncher(io.StringIO()), + contextlib.nullcontext(None), + id="Correctly configures expected launcher", + ), + pytest.param( + BufferWriterLauncherSubclass(io.StringIO()), + pytest.raises( + TypeError, + match="^Cannot create launcher adapter.*expected launcher of type .+$", + ), + id="Errors if launcher types are disparate", + ), + pytest.param( + "mock_launcher", + pytest.raises( + TypeError, + match="^Cannot create launcher adapter.*expected launcher of type .+$", + ), + id="Errors if types are not an exact match", + ), + ), +) +def test_dispatch_registration_can_configure_adapter_for_existing_launcher_instance( + request, mock_launch_args, buffer_writer_dispatch, launcher_instance, ctx +): + if isinstance(launcher_instance, str): + launcher_instance = request.getfixturevalue(launcher_instance) + with ctx: + adapter = buffer_writer_dispatch.create_adapter_from_launcher( + launcher_instance, mock_launch_args + ) + assert adapter._adapted_launcher is launcher_instance + + +@pytest.mark.parametrize( + 
"launcher_instances, ctx", + ( + pytest.param( + (BufferWriterLauncher(io.StringIO()),), + contextlib.nullcontext(None), + id="Correctly configures expected launcher", + ), + pytest.param( + ( + "mock_launcher", + "mock_launcher", + BufferWriterLauncher(io.StringIO()), + "mock_launcher", + ), + contextlib.nullcontext(None), + id="Correctly ignores incompatible launchers instances", + ), + pytest.param( + (), + pytest.raises( + errors.LauncherNotFoundError, + match="^No launcher of exactly type.+could be found from provided launchers$", + ), + id="Errors if no launcher could be found", + ), + pytest.param( + ( + "mock_launcher", + BufferWriterLauncherSubclass(io.StringIO), + "mock_launcher", + ), + pytest.raises( + errors.LauncherNotFoundError, + match="^No launcher of exactly type.+could be found from provided launchers$", + ), + id="Errors if no launcher matches expected type exactly", + ), + ), +) +def test_dispatch_registration_configures_first_compatible_launcher_from_sequence_of_launchers( + request, mock_launch_args, buffer_writer_dispatch, launcher_instances, ctx +): + def resolve_instance(inst): + return request.getfixturevalue(inst) if isinstance(inst, str) else inst + + launcher_instances = tuple(map(resolve_instance, launcher_instances)) + + with ctx: + adapter = buffer_writer_dispatch.configure_first_compatible_launcher( + with_arguments=mock_launch_args, from_available_launchers=launcher_instances + ) + + +def test_dispatch_registration_can_create_a_laucher_for_an_experiment_and_can_reconfigure_it_later( + mock_launch_args, buffer_writer_dispatch +): + class MockExperiment: ... 
+ + exp = MockExperiment() + adapter_1 = buffer_writer_dispatch.create_new_launcher_configuration( + for_experiment=exp, with_arguments=mock_launch_args + ) + assert type(adapter_1._adapted_launcher) == buffer_writer_dispatch.launcher_type + existing_launcher = adapter_1._adapted_launcher + + adapter_2 = buffer_writer_dispatch.create_adapter_from_launcher( + existing_launcher, mock_launch_args + ) + assert type(adapter_2._adapted_launcher) == buffer_writer_dispatch.launcher_type + assert adapter_1._adapted_launcher is adapter_2._adapted_launcher + assert adapter_1 is not adapter_2 diff --git a/tests/temp_tests/test_settings/test_dragonLauncher.py b/tests/temp_tests/test_settings/test_dragonLauncher.py new file mode 100644 index 0000000000..a7685e18e7 --- /dev/null +++ b/tests/temp_tests/test_settings/test_dragonLauncher.py @@ -0,0 +1,116 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pytest + +from smartsim._core.launcher.dragon.dragon_launcher import ( + _as_run_request_args_and_policy, +) +from smartsim._core.schemas.dragon_requests import DragonRunPolicy, DragonRunRequestView +from smartsim.settings import LaunchSettings +from smartsim.settings.arguments.launch.dragon import DragonLaunchArguments +from smartsim.settings.launch_command import LauncherType + +pytestmark = pytest.mark.group_a + + +def test_launcher_str(): + """Ensure launcher_str returns appropriate value""" + ls = LaunchSettings(launcher=LauncherType.Dragon) + assert ls.launch_args.launcher_str() == LauncherType.Dragon.value + + +@pytest.mark.parametrize( + "function,value,result,flag", + [ + pytest.param("set_nodes", (2,), "2", "nodes", id="set_nodes"), + pytest.param( + "set_tasks_per_node", (2,), "2", "tasks_per_node", id="set_tasks_per_node" + ), + ], +) +def test_dragon_class_methods(function, value, flag, result): + dragonLauncher = LaunchSettings(launcher=LauncherType.Dragon) + assert isinstance(dragonLauncher._arguments, DragonLaunchArguments) + getattr(dragonLauncher.launch_args, function)(*value) + assert dragonLauncher.launch_args._launch_args[flag] == result + + +NOT_SET = object() + + +@pytest.mark.parametrize("nodes", (NOT_SET, 20, 40)) +@pytest.mark.parametrize("tasks_per_node", (NOT_SET, 1, 20)) +@pytest.mark.parametrize("cpu_affinity", (NOT_SET, [1], [1, 2, 3])) +@pytest.mark.parametrize("gpu_affinity", (NOT_SET, [1], [1, 2, 3])) 
+def test_formatting_launch_args_into_request( + nodes, tasks_per_node, cpu_affinity, gpu_affinity, test_dir +): + launch_args = DragonLaunchArguments({}) + if nodes is not NOT_SET: + launch_args.set_nodes(nodes) + if tasks_per_node is not NOT_SET: + launch_args.set_tasks_per_node(tasks_per_node) + if cpu_affinity is not NOT_SET: + launch_args.set_cpu_affinity(cpu_affinity) + if gpu_affinity is not NOT_SET: + launch_args.set_gpu_affinity(gpu_affinity) + req, policy = _as_run_request_args_and_policy( + launch_args, ("echo", "hello", "world"), test_dir, {}, "output.txt", "error.txt" + ) + + expected_args = { + k: v + for k, v in { + "nodes": nodes, + "tasks_per_node": tasks_per_node, + }.items() + if v is not NOT_SET + } + expected_run_req = DragonRunRequestView( + exe="echo", + exe_args=["hello", "world"], + path=test_dir, + env={}, + output_file="output.txt", + error_file="error.txt", + **expected_args, + ) + assert req.exe == expected_run_req.exe + assert req.exe_args == expected_run_req.exe_args + assert req.nodes == expected_run_req.nodes + assert req.tasks_per_node == expected_run_req.tasks_per_node + assert req.hostlist == expected_run_req.hostlist + assert req.pmi_enabled == expected_run_req.pmi_enabled + assert req.path == expected_run_req.path + assert req.output_file == expected_run_req.output_file + assert req.error_file == expected_run_req.error_file + + expected_run_policy_args = { + k: v + for k, v in {"cpu_affinity": cpu_affinity, "gpu_affinity": gpu_affinity}.items() + if v is not NOT_SET + } + assert policy == DragonRunPolicy(**expected_run_policy_args) diff --git a/tests/temp_tests/test_settings/test_launchSettings.py b/tests/temp_tests/test_settings/test_launchSettings.py new file mode 100644 index 0000000000..3fc5e544a9 --- /dev/null +++ b/tests/temp_tests/test_settings/test_launchSettings.py @@ -0,0 +1,89 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import logging + +import pytest + +from smartsim.settings import LaunchSettings +from smartsim.settings.launch_command import LauncherType + +pytestmark = pytest.mark.group_a + + +@pytest.mark.parametrize( + "launch_enum", + [pytest.param(type_, id=type_.value) for type_ in LauncherType], +) +def test_create_launch_settings(launch_enum): + ls_str = LaunchSettings( + launcher=launch_enum.value, + launch_args={"launch": "var"}, + env_vars={"ENV": "VAR"}, + ) + assert ls_str._launcher == launch_enum + # TODO need to test launch_args + assert ls_str._env_vars == {"ENV": "VAR"} + + ls_enum = LaunchSettings( + launcher=launch_enum, launch_args={"launch": "var"}, env_vars={"ENV": "VAR"} + ) + assert ls_enum._launcher == launch_enum + # TODO need to test launch_args + assert ls_enum._env_vars == {"ENV": "VAR"} + + +def test_launcher_property(): + ls = LaunchSettings(launcher="local") + assert ls.launcher == "local" + + +def test_env_vars_property(): + ls = LaunchSettings(launcher="local", env_vars={"ENV": "VAR"}) + assert ls.env_vars == {"ENV": "VAR"} + ref = ls.env_vars + assert ref is ls.env_vars + + +def test_update_env_vars(): + ls = LaunchSettings(launcher="local", env_vars={"ENV": "VAR"}) + ls.update_env({"test": "no_update"}) + assert ls.env_vars == {"ENV": "VAR", "test": "no_update"} + + +def test_update_env_vars_errors(): + ls = LaunchSettings(launcher="local", env_vars={"ENV": "VAR"}) + with pytest.raises(TypeError): + ls.update_env({"test": 1}) + with pytest.raises(TypeError): + ls.update_env({1: "test"}) + with pytest.raises(TypeError): + ls.update_env({1: 1}) + with pytest.raises(TypeError): + # Make sure the first key and value do not assign + # and that the function is atomic + ls.update_env({"test": "test", "test": 1}) + assert ls.env_vars == {"ENV": "VAR"} diff --git a/tests/temp_tests/test_settings/test_localLauncher.py b/tests/temp_tests/test_settings/test_localLauncher.py new file mode 100644 index 0000000000..6576b2249c --- /dev/null +++ 
b/tests/temp_tests/test_settings/test_localLauncher.py @@ -0,0 +1,169 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import io +import os +import pathlib + +import pytest + +from smartsim._core.shell.shell_launcher import ShellLauncherCommand +from smartsim.settings import LaunchSettings +from smartsim.settings.arguments.launch.local import ( + LocalLaunchArguments, + _as_local_command, +) +from smartsim.settings.launch_command import LauncherType + +pytestmark = pytest.mark.group_a + + +def test_launcher_str(): + """Ensure launcher_str returns appropriate value""" + ls = LaunchSettings(launcher=LauncherType.Local) + assert ls.launch_args.launcher_str() == LauncherType.Local.value + + +# TODO complete after launch args retrieval +def test_launch_args_input_mutation(): + # Tests that the run args passed in are not modified after initialization + key0, key1, key2 = "arg0", "arg1", "arg2" + val0, val1, val2 = "val0", "val1", "val2" + + default_launcher_args = { + key0: val0, + key1: val1, + key2: val2, + } + localLauncher = LaunchSettings( + launcher=LauncherType.Local, launch_args=default_launcher_args + ) + + # Confirm initial values are set + assert localLauncher.launch_args._launch_args[key0] == val0 + assert localLauncher.launch_args._launch_args[key1] == val1 + assert localLauncher.launch_args._launch_args[key2] == val2 + + # Update our common run arguments + val2_upd = f"not-{val2}" + default_launcher_args[key2] = val2_upd + + # Confirm previously created run settings are not changed + assert localLauncher.launch_args._launch_args[key2] == val2 + + +@pytest.mark.parametrize( + "env_vars", + [ + pytest.param({}, id="no env vars"), + pytest.param({"env1": "abc"}, id="normal var"), + pytest.param({"env1": "abc,def"}, id="compound var"), + pytest.param({"env1": "xyz", "env2": "pqr"}, id="multiple env vars"), + ], +) +def test_update_env(env_vars): + """Ensure non-initialized env vars update correctly""" + localLauncher = LaunchSettings(launcher=LauncherType.Local) + localLauncher.update_env(env_vars) + + assert len(localLauncher.env_vars) == len(env_vars.keys()) + + +def 
test_format_launch_args(): + localLauncher = LaunchSettings(launcher=LauncherType.Local, launch_args={"-np": 2}) + launch_args = localLauncher._arguments.format_launch_args() + assert launch_args == ["-np", "2"] + + +@pytest.mark.parametrize( + "env_vars", + [ + pytest.param({"env1": {"abc"}}, id="set value not allowed"), + pytest.param({"env1": {"abc": "def"}}, id="dict value not allowed"), + ], +) +def test_update_env_null_valued(env_vars): + """Ensure validation of env var in update""" + orig_env = {} + + with pytest.raises(TypeError) as ex: + localLauncher = LaunchSettings(launcher=LauncherType.Local, env_vars=orig_env) + localLauncher.update_env(env_vars) + + +@pytest.mark.parametrize( + "env_vars", + [ + pytest.param({}, id="no env vars"), + pytest.param({"env1": "abc"}, id="normal var"), + pytest.param({"env1": "abc,def"}, id="compound var"), + pytest.param({"env1": "xyz", "env2": "pqr"}, id="multiple env vars"), + ], +) +def test_update_env_initialized(env_vars): + """Ensure update of initialized env vars does not overwrite""" + orig_env = {"key": "value"} + localLauncher = LaunchSettings(launcher=LauncherType.Local, env_vars=orig_env) + localLauncher.update_env(env_vars) + + combined_keys = {k for k in env_vars.keys()} + combined_keys.update(k for k in orig_env.keys()) + + assert len(localLauncher.env_vars) == len(combined_keys) + assert {k for k in localLauncher.env_vars.keys()} == combined_keys + + +def test_format_env_vars(): + env_vars = { + "A": "a", + "B": None, + "C": "", + "D": "12", + } + localLauncher = LaunchSettings(launcher=LauncherType.Local, env_vars=env_vars) + assert isinstance(localLauncher._arguments, LocalLaunchArguments) + assert localLauncher._arguments.format_env_vars(env_vars) == [ + "A=a", + "B=", + "C=", + "D=12", + ] + + +def test_formatting_returns_original_exe(test_dir): + out = os.path.join(test_dir, "out.txt") + err = os.path.join(test_dir, "err.txt") + open(out, "w"), open(err, "w") + shell_launch_cmd = _as_local_command( + 
LocalLaunchArguments({}), ("echo", "hello", "world"), test_dir, {}, out, err + ) + assert isinstance(shell_launch_cmd, ShellLauncherCommand) + assert shell_launch_cmd.command_tuple == ("echo", "hello", "world") + assert shell_launch_cmd.path == pathlib.Path(test_dir) + assert shell_launch_cmd.env == {} + assert isinstance(shell_launch_cmd.stdout, io.TextIOWrapper) + assert shell_launch_cmd.stdout.name == out + assert isinstance(shell_launch_cmd.stderr, io.TextIOWrapper) + assert shell_launch_cmd.stderr.name == err diff --git a/tests/temp_tests/test_settings/test_lsfLauncher.py b/tests/temp_tests/test_settings/test_lsfLauncher.py new file mode 100644 index 0000000000..549c2483b4 --- /dev/null +++ b/tests/temp_tests/test_settings/test_lsfLauncher.py @@ -0,0 +1,199 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import subprocess + +import pytest + +from smartsim.settings import LaunchSettings +from smartsim.settings.arguments.launch.lsf import ( + JsrunLaunchArguments, + _as_jsrun_command, +) +from smartsim.settings.launch_command import LauncherType + +pytestmark = pytest.mark.group_a + + +def test_launcher_str(): + """Ensure launcher_str returns appropriate value""" + ls = LaunchSettings(launcher=LauncherType.Lsf) + assert ls.launch_args.launcher_str() == LauncherType.Lsf.value + + +@pytest.mark.parametrize( + "function,value,result,flag", + [ + pytest.param("set_tasks", (2,), "2", "np", id="set_tasks"), + pytest.param( + "set_binding", ("packed:21",), "packed:21", "bind", id="set_binding" + ), + ], +) +def test_lsf_class_methods(function, value, flag, result): + lsfLauncher = LaunchSettings(launcher=LauncherType.Lsf) + assert isinstance(lsfLauncher._arguments, JsrunLaunchArguments) + getattr(lsfLauncher.launch_args, function)(*value) + assert lsfLauncher.launch_args._launch_args[flag] == result + + +def test_format_env_vars(): + env_vars = {"OMP_NUM_THREADS": None, "LOGGING": "verbose"} + lsfLauncher = LaunchSettings(launcher=LauncherType.Lsf, env_vars=env_vars) + assert isinstance(lsfLauncher._arguments, JsrunLaunchArguments) + formatted = lsfLauncher._arguments.format_env_vars(env_vars) + assert formatted == ["-E", "OMP_NUM_THREADS", "-E", "LOGGING=verbose"] + + +def test_launch_args(): + """Test the possible user overrides through run_args""" 
+ launch_args = { + "latency_priority": "gpu-gpu", + "immediate": None, + "d": "packed", # test single letter variables + "nrs": 10, + "np": 100, + } + lsfLauncher = LaunchSettings(launcher=LauncherType.Lsf, launch_args=launch_args) + assert isinstance(lsfLauncher._arguments, JsrunLaunchArguments) + formatted = lsfLauncher._arguments.format_launch_args() + result = [ + "--latency_priority=gpu-gpu", + "--immediate", + "-d", + "packed", + "--nrs=10", + "--np=100", + ] + assert formatted == result + + +@pytest.mark.parametrize( + "args, expected", + ( + pytest.param( + {}, + ( + "jsrun", + "--stdio_stdout=output.txt", + "--stdio_stderr=error.txt", + "--", + "echo", + "hello", + "world", + ), + id="Empty Args", + ), + pytest.param( + {"n": "1"}, + ( + "jsrun", + "-n", + "1", + "--stdio_stdout=output.txt", + "--stdio_stderr=error.txt", + "--", + "echo", + "hello", + "world", + ), + id="Short Arg", + ), + pytest.param( + {"nrs": "1"}, + ( + "jsrun", + "--nrs=1", + "--stdio_stdout=output.txt", + "--stdio_stderr=error.txt", + "--", + "echo", + "hello", + "world", + ), + id="Long Arg", + ), + pytest.param( + {"v": None}, + ( + "jsrun", + "-v", + "--stdio_stdout=output.txt", + "--stdio_stderr=error.txt", + "--", + "echo", + "hello", + "world", + ), + id="Short Arg (No Value)", + ), + pytest.param( + {"verbose": None}, + ( + "jsrun", + "--verbose", + "--stdio_stdout=output.txt", + "--stdio_stderr=error.txt", + "--", + "echo", + "hello", + "world", + ), + id="Long Arg (No Value)", + ), + pytest.param( + {"tasks_per_rs": "1", "n": "123"}, + ( + "jsrun", + "--tasks_per_rs=1", + "-n", + "123", + "--stdio_stdout=output.txt", + "--stdio_stderr=error.txt", + "--", + "echo", + "hello", + "world", + ), + id="Short and Long Args", + ), + ), +) +def test_formatting_launch_args(args, expected, test_dir): + outfile = "output.txt" + errfile = "error.txt" + env, path, stdin, stdout, args = _as_jsrun_command( + JsrunLaunchArguments(args), + ("echo", "hello", "world"), + test_dir, + {}, + 
outfile, + errfile, + ) + assert tuple(args) == expected + assert path == test_dir + assert env == {} + assert stdin == subprocess.DEVNULL + assert stdout == subprocess.DEVNULL diff --git a/tests/temp_tests/test_settings/test_lsfScheduler.py b/tests/temp_tests/test_settings/test_lsfScheduler.py new file mode 100644 index 0000000000..5e6b7fd0c4 --- /dev/null +++ b/tests/temp_tests/test_settings/test_lsfScheduler.py @@ -0,0 +1,77 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import pytest + +from smartsim.settings import BatchSettings +from smartsim.settings.batch_command import BatchSchedulerType + +pytestmark = pytest.mark.group_a + + +def test_scheduler_str(): + """Ensure scheduler_str returns appropriate value""" + bs = BatchSettings(batch_scheduler=BatchSchedulerType.Lsf) + assert bs.batch_args.scheduler_str() == BatchSchedulerType.Lsf.value + + +@pytest.mark.parametrize( + "function,value,result,flag", + [ + pytest.param("set_nodes", (2,), "2", "nnodes", id="set_nodes"), + pytest.param("set_walltime", ("10:00:00",), "10:00", "W", id="set_walltime"), + pytest.param( + "set_hostlist", ("host_A",), "" '"host_A"' "", "m", id="set_hostlist_str" + ), + pytest.param( + "set_hostlist", + (["host_A", "host_B"],), + "" '"host_A host_B"' "", + "m", + id="set_hostlist_list[str]", + ), + pytest.param("set_smts", (1,), "1", "alloc_flags", id="set_smts"), + pytest.param("set_project", ("project",), "project", "P", id="set_project"), + pytest.param("set_account", ("project",), "project", "P", id="set_account"), + pytest.param("set_tasks", (2,), "2", "n", id="set_tasks"), + pytest.param("set_queue", ("queue",), "queue", "q", id="set_queue"), + ], +) +def test_update_env_initialized(function, value, flag, result): + lsfScheduler = BatchSettings(batch_scheduler=BatchSchedulerType.Lsf) + getattr(lsfScheduler.batch_args, function)(*value) + assert lsfScheduler.batch_args._batch_args[flag] == result + + +def test_create_bsub(): + batch_args = {"core_isolation": None} + lsfScheduler = BatchSettings( + batch_scheduler=BatchSchedulerType.Lsf, batch_args=batch_args + ) + lsfScheduler.batch_args.set_nodes(1) + lsfScheduler.batch_args.set_walltime("10:10:10") + lsfScheduler.batch_args.set_queue("default") + args = lsfScheduler.format_batch_args() + assert args == ["-core_isolation", "-nnodes", "1", "-W", "10:10", "-q", "default"] diff --git a/tests/temp_tests/test_settings/test_mpiLauncher.py b/tests/temp_tests/test_settings/test_mpiLauncher.py new file 
mode 100644 index 0000000000..57be23ee2b --- /dev/null +++ b/tests/temp_tests/test_settings/test_mpiLauncher.py @@ -0,0 +1,304 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import io +import itertools +import os +import pathlib + +import pytest + +from smartsim._core.shell.shell_launcher import ShellLauncherCommand +from smartsim.settings import LaunchSettings +from smartsim.settings.arguments.launch.mpi import ( + MpiexecLaunchArguments, + MpirunLaunchArguments, + OrterunLaunchArguments, + _as_mpiexec_command, + _as_mpirun_command, + _as_orterun_command, +) +from smartsim.settings.launch_command import LauncherType + +pytestmark = pytest.mark.group_a + + +@pytest.mark.parametrize( + "launcher", + [ + pytest.param(LauncherType.Mpirun, id="launcher_str_mpirun"), + pytest.param(LauncherType.Mpiexec, id="launcher_str_mpiexec"), + pytest.param(LauncherType.Orterun, id="launcher_str_orterun"), + ], +) +def test_launcher_str(launcher): + """Ensure launcher_str returns appropriate value""" + ls = LaunchSettings(launcher=launcher) + assert ls.launch_args.launcher_str() == launcher.value + + +@pytest.mark.parametrize( + "l,function,value,result,flag", + [ + # Use OpenMPI style settigs for all launchers + *itertools.chain.from_iterable( + ( + ( + pytest.param( + l, "set_walltime", ("100",), "100", "timeout", id="set_walltime" + ), + pytest.param( + l, + "set_task_map", + ("taskmap",), + "taskmap", + "map-by", + id="set_task_map", + ), + pytest.param( + l, + "set_cpus_per_task", + (2,), + "2", + "cpus-per-proc", + id="set_cpus_per_task", + ), + pytest.param( + l, + "set_cpu_binding_type", + ("4",), + "4", + "bind-to", + id="set_cpu_binding_type", + ), + pytest.param( + l, + "set_tasks_per_node", + (4,), + "4", + "npernode", + id="set_tasks_per_node", + ), + pytest.param(l, "set_tasks", (4,), "4", "n", id="set_tasks"), + pytest.param( + l, + "set_executable_broadcast", + ("broadcast",), + "broadcast", + "preload-binary", + id="set_executable_broadcast", + ), + pytest.param( + l, + "set_hostlist", + ("host_A",), + "host_A", + "host", + id="set_hostlist_str", + ), + pytest.param( + l, + "set_hostlist", + (["host_A", "host_B"],), + 
"host_A,host_B", + "host", + id="set_hostlist_list[str]", + ), + pytest.param( + l, + "set_hostlist_from_file", + ("./path/to/hostfile",), + "./path/to/hostfile", + "hostfile", + id="set_hostlist_from_file", + ), + ) + for l in ( + [LauncherType.Mpirun, MpirunLaunchArguments], + [LauncherType.Mpiexec, MpiexecLaunchArguments], + [LauncherType.Orterun, OrterunLaunchArguments], + ) + ) + ) + ], +) +def test_mpi_class_methods(l, function, value, flag, result): + mpiSettings = LaunchSettings(launcher=l[0]) + assert isinstance(mpiSettings._arguments, l[1]) + getattr(mpiSettings.launch_args, function)(*value) + assert mpiSettings.launch_args._launch_args[flag] == result + + +@pytest.mark.parametrize( + "launcher", + [ + pytest.param(LauncherType.Mpirun, id="format_env_mpirun"), + pytest.param(LauncherType.Mpiexec, id="format_env_mpiexec"), + pytest.param(LauncherType.Orterun, id="format_env_orterun"), + ], +) +def test_format_env_vars(launcher): + env_vars = {"OMP_NUM_THREADS": "20", "LOGGING": "verbose"} + mpiSettings = LaunchSettings(launcher=launcher, env_vars=env_vars) + formatted = mpiSettings._arguments.format_env_vars(env_vars) + result = [ + "-x", + "OMP_NUM_THREADS=20", + "-x", + "LOGGING=verbose", + ] + assert formatted == result + + +@pytest.mark.parametrize( + "launcher", + [ + pytest.param(LauncherType.Mpirun, id="format_launcher_args_mpirun"), + pytest.param(LauncherType.Mpiexec, id="format_launcher_args_mpiexec"), + pytest.param(LauncherType.Orterun, id="format_launcher_args_orterun"), + ], +) +def test_format_launcher_args(launcher): + mpiSettings = LaunchSettings(launcher=launcher) + mpiSettings.launch_args.set_cpus_per_task(1) + mpiSettings.launch_args.set_tasks(2) + mpiSettings.launch_args.set_hostlist(["node005", "node006"]) + formatted = mpiSettings._arguments.format_launch_args() + result = ["--cpus-per-proc", "1", "--n", "2", "--host", "node005,node006"] + assert formatted == result + + +@pytest.mark.parametrize( + "launcher", + [ + 
pytest.param(LauncherType.Mpirun, id="set_verbose_launch_mpirun"), + pytest.param(LauncherType.Mpiexec, id="set_verbose_launch_mpiexec"), + pytest.param(LauncherType.Orterun, id="set_verbose_launch_orterun"), + ], +) +def test_set_verbose_launch(launcher): + mpiSettings = LaunchSettings(launcher=launcher) + mpiSettings.launch_args.set_verbose_launch(True) + assert mpiSettings.launch_args._launch_args == {"verbose": None} + mpiSettings.launch_args.set_verbose_launch(False) + assert mpiSettings.launch_args._launch_args == {} + + +@pytest.mark.parametrize( + "launcher", + [ + pytest.param(LauncherType.Mpirun, id="set_quiet_launch_mpirun"), + pytest.param(LauncherType.Mpiexec, id="set_quiet_launch_mpiexec"), + pytest.param(LauncherType.Orterun, id="set_quiet_launch_orterun"), + ], +) +def test_set_quiet_launch(launcher): + mpiSettings = LaunchSettings(launcher=launcher) + mpiSettings.launch_args.set_quiet_launch(True) + assert mpiSettings.launch_args._launch_args == {"quiet": None} + mpiSettings.launch_args.set_quiet_launch(False) + assert mpiSettings.launch_args._launch_args == {} + + +@pytest.mark.parametrize( + "launcher", + [ + pytest.param(LauncherType.Mpirun, id="invalid_hostlist_mpirun"), + pytest.param(LauncherType.Mpiexec, id="invalid_hostlist_mpiexec"), + pytest.param(LauncherType.Orterun, id="invalid_hostlist_orterun"), + ], +) +def test_invalid_hostlist_format(launcher): + """Test invalid hostlist formats""" + mpiSettings = LaunchSettings(launcher=launcher) + with pytest.raises(TypeError): + mpiSettings.launch_args.set_hostlist(["test", 5]) + with pytest.raises(TypeError): + mpiSettings.launch_args.set_hostlist([5]) + with pytest.raises(TypeError): + mpiSettings.launch_args.set_hostlist(5) + + +@pytest.mark.parametrize( + "cls, fmt, cmd", + ( + pytest.param( + MpirunLaunchArguments, _as_mpirun_command, "mpirun", id="w/ mpirun" + ), + pytest.param( + MpiexecLaunchArguments, _as_mpiexec_command, "mpiexec", id="w/ mpiexec" + ), + pytest.param( + 
OrterunLaunchArguments, _as_orterun_command, "orterun", id="w/ orterun" + ), + ), +) +@pytest.mark.parametrize( + "args, expected", + ( + pytest.param({}, ("--", "echo", "hello", "world"), id="Empty Args"), + pytest.param( + {"n": "1"}, + ("--n", "1", "--", "echo", "hello", "world"), + id="Short Arg", + ), + pytest.param( + {"host": "myhost"}, + ("--host", "myhost", "--", "echo", "hello", "world"), + id="Long Arg", + ), + pytest.param( + {"v": None}, + ("--v", "--", "echo", "hello", "world"), + id="Short Arg (No Value)", + ), + pytest.param( + {"verbose": None}, + ("--verbose", "--", "echo", "hello", "world"), + id="Long Arg (No Value)", + ), + pytest.param( + {"n": "1", "host": "myhost"}, + ("--n", "1", "--host", "myhost", "--", "echo", "hello", "world"), + id="Short and Long Args", + ), + ), +) +def test_formatting_launch_args(cls, fmt, cmd, args, expected, test_dir): + out = os.path.join(test_dir, "out.txt") + err = os.path.join(test_dir, "err.txt") + open(out, "w"), open(err, "w") + shell_launch_cmd = fmt( + cls(args), ("echo", "hello", "world"), test_dir, {}, out, err + ) + assert isinstance(shell_launch_cmd, ShellLauncherCommand) + assert shell_launch_cmd.command_tuple == (cmd,) + expected + assert shell_launch_cmd.path == pathlib.Path(test_dir) + assert shell_launch_cmd.env == {} + assert isinstance(shell_launch_cmd.stdout, io.TextIOWrapper) + assert shell_launch_cmd.stdout.name == out + assert isinstance(shell_launch_cmd.stderr, io.TextIOWrapper) + assert shell_launch_cmd.stderr.name == err diff --git a/tests/temp_tests/test_settings/test_palsLauncher.py b/tests/temp_tests/test_settings/test_palsLauncher.py new file mode 100644 index 0000000000..d38d1842c6 --- /dev/null +++ b/tests/temp_tests/test_settings/test_palsLauncher.py @@ -0,0 +1,158 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import io +import os +import pathlib + +import pytest + +from smartsim._core.shell.shell_launcher import ShellLauncherCommand +from smartsim.settings import LaunchSettings +from smartsim.settings.arguments.launch.pals import ( + PalsMpiexecLaunchArguments, + _as_pals_command, +) +from smartsim.settings.launch_command import LauncherType + +pytestmark = pytest.mark.group_a + + +def test_launcher_str(): + """Ensure launcher_str returns appropriate value""" + ls = LaunchSettings(launcher=LauncherType.Pals) + assert ls.launch_args.launcher_str() == LauncherType.Pals.value + + +@pytest.mark.parametrize( + "function,value,result,flag", + [ + pytest.param( + "set_cpu_binding_type", + ("bind",), + "bind", + "bind-to", + id="set_cpu_binding_type", + ), + pytest.param("set_tasks", (2,), "2", "np", id="set_tasks"), + pytest.param("set_tasks_per_node", (2,), "2", "ppn", id="set_tasks_per_node"), + pytest.param( + "set_hostlist", ("host_A",), "host_A", "hosts", id="set_hostlist_str" + ), + pytest.param( + "set_hostlist", + (["host_A", "host_B"],), + "host_A,host_B", + "hosts", + id="set_hostlist_list[str]", + ), + pytest.param( + "set_executable_broadcast", + ("broadcast",), + "broadcast", + "transfer", + id="set_executable_broadcast", + ), + ], +) +def test_pals_class_methods(function, value, flag, result): + palsLauncher = LaunchSettings(launcher=LauncherType.Pals) + assert isinstance(palsLauncher.launch_args, PalsMpiexecLaunchArguments) + getattr(palsLauncher.launch_args, function)(*value) + assert palsLauncher.launch_args._launch_args[flag] == result + assert palsLauncher._arguments.format_launch_args() == ["--" + flag, str(result)] + + +def test_format_env_vars(): + env_vars = {"FOO_VERSION": "3.14", "PATH": None, "LD_LIBRARY_PATH": None} + palsLauncher = LaunchSettings(launcher=LauncherType.Pals, env_vars=env_vars) + formatted = " ".join(palsLauncher._arguments.format_env_vars(env_vars)) + expected = "--env FOO_VERSION=3.14 --envlist PATH,LD_LIBRARY_PATH" + assert 
formatted == expected + + +def test_invalid_hostlist_format(): + """Test invalid hostlist formats""" + palsLauncher = LaunchSettings(launcher=LauncherType.Pals) + with pytest.raises(TypeError): + palsLauncher.launch_args.set_hostlist(["test", 5]) + with pytest.raises(TypeError): + palsLauncher.launch_args.set_hostlist([5]) + with pytest.raises(TypeError): + palsLauncher.launch_args.set_hostlist(5) + + +@pytest.mark.parametrize( + "args, expected", + ( + pytest.param({}, ("mpiexec", "--", "echo", "hello", "world"), id="Empty Args"), + pytest.param( + {"n": "1"}, + ("mpiexec", "--n", "1", "--", "echo", "hello", "world"), + id="Short Arg", + ), + pytest.param( + {"host": "myhost"}, + ("mpiexec", "--host", "myhost", "--", "echo", "hello", "world"), + id="Long Arg", + ), + pytest.param( + {"v": None}, + ("mpiexec", "--v", "--", "echo", "hello", "world"), + id="Short Arg (No Value)", + ), + pytest.param( + {"verbose": None}, + ("mpiexec", "--verbose", "--", "echo", "hello", "world"), + id="Long Arg (No Value)", + ), + pytest.param( + {"n": "1", "host": "myhost"}, + ("mpiexec", "--n", "1", "--host", "myhost", "--", "echo", "hello", "world"), + id="Short and Long Args", + ), + ), +) +def test_formatting_launch_args(args, expected, test_dir): + out = os.path.join(test_dir, "out.txt") + err = os.path.join(test_dir, "err.txt") + open(out, "w"), open(err, "w") + shell_launch_cmd = _as_pals_command( + PalsMpiexecLaunchArguments(args), + ("echo", "hello", "world"), + test_dir, + {}, + out, + err, + ) + assert isinstance(shell_launch_cmd, ShellLauncherCommand) + assert shell_launch_cmd.command_tuple == expected + assert shell_launch_cmd.path == pathlib.Path(test_dir) + assert shell_launch_cmd.env == {} + assert isinstance(shell_launch_cmd.stdout, io.TextIOWrapper) + assert shell_launch_cmd.stdout.name == out + assert isinstance(shell_launch_cmd.stderr, io.TextIOWrapper) + assert shell_launch_cmd.stderr.name == err diff --git a/tests/temp_tests/test_settings/test_pbsScheduler.py 
b/tests/temp_tests/test_settings/test_pbsScheduler.py new file mode 100644 index 0000000000..36fde6776d --- /dev/null +++ b/tests/temp_tests/test_settings/test_pbsScheduler.py @@ -0,0 +1,88 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import pytest + +from smartsim.settings import BatchSettings +from smartsim.settings.arguments.batch.pbs import QsubBatchArguments +from smartsim.settings.batch_command import BatchSchedulerType + +pytestmark = pytest.mark.group_a + + +def test_scheduler_str(): + """Ensure scheduler_str returns appropriate value""" + bs = BatchSettings(batch_scheduler=BatchSchedulerType.Pbs) + assert bs.batch_args.scheduler_str() == BatchSchedulerType.Pbs.value + + +@pytest.mark.parametrize( + "function,value,result,flag", + [ + pytest.param("set_nodes", (2,), "2", "nodes", id="set_nodes"), + pytest.param( + "set_walltime", ("10:00:00",), "10:00:00", "walltime", id="set_walltime" + ), + pytest.param("set_account", ("account",), "account", "A", id="set_account"), + pytest.param("set_queue", ("queue",), "queue", "q", id="set_queue"), + pytest.param("set_ncpus", (2,), "2", "ppn", id="set_ncpus"), + pytest.param( + "set_hostlist", ("host_A",), "host_A", "hostname", id="set_hostlist_str" + ), + pytest.param( + "set_hostlist", + (["host_A", "host_B"],), + "host_A,host_B", + "hostname", + id="set_hostlist_list[str]", + ), + ], +) +def test_create_pbs_batch(function, value, flag, result): + pbsScheduler = BatchSettings(batch_scheduler=BatchSchedulerType.Pbs) + assert isinstance(pbsScheduler.batch_args, QsubBatchArguments) + getattr(pbsScheduler.batch_args, function)(*value) + assert pbsScheduler.batch_args._batch_args[flag] == result + + +def test_format_pbs_batch_args(): + pbsScheduler = BatchSettings(batch_scheduler=BatchSchedulerType.Pbs) + pbsScheduler.batch_args.set_nodes(1) + pbsScheduler.batch_args.set_walltime("10:00:00") + pbsScheduler.batch_args.set_queue("default") + pbsScheduler.batch_args.set_account("myproject") + pbsScheduler.batch_args.set_ncpus(10) + pbsScheduler.batch_args.set_hostlist(["host_a", "host_b", "host_c"]) + args = pbsScheduler.format_batch_args() + assert args == [ + "-l", + "nodes=1:ncpus=10:host=host_a+host=host_b+host=host_c", + "-l", + 
"walltime=10:00:00", + "-q", + "default", + "-A", + "myproject", + ] diff --git a/tests/temp_tests/test_settings/test_slurmLauncher.py b/tests/temp_tests/test_settings/test_slurmLauncher.py new file mode 100644 index 0000000000..6be9b5542a --- /dev/null +++ b/tests/temp_tests/test_settings/test_slurmLauncher.py @@ -0,0 +1,398 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import subprocess

import pytest

from smartsim._core.shell.shell_launcher import ShellLauncherCommand
from smartsim.settings import LaunchSettings
from smartsim.settings.arguments.launch.slurm import (
    SlurmLaunchArguments,
    _as_srun_command,
)
from smartsim.settings.launch_command import LauncherType

pytestmark = pytest.mark.group_a


def test_launcher_str():
    """Ensure launcher_str reports the Slurm launcher type."""
    settings = LaunchSettings(launcher=LauncherType.Slurm)
    assert settings.launch_args.launcher_str() == LauncherType.Slurm.value


@pytest.mark.parametrize(
    "function,value,result,flag",
    [
        pytest.param("set_nodes", (2,), "2", "nodes", id="set_nodes"),
        pytest.param(
            "set_hostlist", ("host_A",), "host_A", "nodelist", id="set_hostlist_str"
        ),
        pytest.param(
            "set_hostlist",
            (["host_A", "host_B"],),
            "host_A,host_B",
            "nodelist",
            id="set_hostlist_list[str]",
        ),
        pytest.param(
            "set_hostlist_from_file",
            ("./path/to/hostfile",),
            "./path/to/hostfile",
            "nodefile",
            id="set_hostlist_from_file",
        ),
        pytest.param(
            "set_excluded_hosts",
            ("host_A",),
            "host_A",
            "exclude",
            id="set_excluded_hosts_str",
        ),
        pytest.param(
            "set_excluded_hosts",
            (["host_A", "host_B"],),
            "host_A,host_B",
            "exclude",
            id="set_excluded_hosts_list[str]",
        ),
        pytest.param(
            "set_cpus_per_task", (4,), "4", "cpus-per-task", id="set_cpus_per_task"
        ),
        pytest.param("set_tasks", (4,), "4", "ntasks", id="set_tasks"),
        pytest.param(
            "set_tasks_per_node", (4,), "4", "ntasks-per-node", id="set_tasks_per_node"
        ),
        pytest.param(
            "set_cpu_bindings", (4,), "map_cpu:4", "cpu_bind", id="set_cpu_bindings"
        ),
        pytest.param(
            "set_cpu_bindings",
            ([4, 4],),
            "map_cpu:4,4",
            "cpu_bind",
            id="set_cpu_bindings_list[str]",
        ),
        pytest.param(
            "set_memory_per_node", (8000,), "8000M", "mem", id="set_memory_per_node"
        ),
        pytest.param(
            "set_executable_broadcast",
            ("/tmp/some/path",),
            "/tmp/some/path",
            "bcast",
            id="set_broadcast",
        ),
        pytest.param("set_node_feature", ("P100",), "P100", "C", id="set_node_feature"),
        pytest.param(
            "set_walltime", ("10:00:00",), "10:00:00", "time", id="set_walltime"
        ),
    ],
)
def test_slurm_class_methods(function, value, flag, result):
    """Each setter stores its value under the expected srun flag."""
    settings = LaunchSettings(launcher=LauncherType.Slurm)
    assert isinstance(settings.launch_args, SlurmLaunchArguments)
    # invoke the setter under test by name with the parametrized arguments
    getattr(settings.launch_args, function)(*value)
    assert settings.launch_args._launch_args[flag] == result


def test_set_verbose_launch():
    """Toggling verbose adds/removes the bare 'verbose' flag."""
    settings = LaunchSettings(launcher=LauncherType.Slurm)
    settings.launch_args.set_verbose_launch(True)
    assert settings.launch_args._launch_args == {"verbose": None}
    settings.launch_args.set_verbose_launch(False)
    assert settings.launch_args._launch_args == {}


def test_set_quiet_launch():
    """Toggling quiet adds/removes the bare 'quiet' flag."""
    settings = LaunchSettings(launcher=LauncherType.Slurm)
    settings.launch_args.set_quiet_launch(True)
    assert settings.launch_args._launch_args == {"quiet": None}
    settings.launch_args.set_quiet_launch(False)
    assert settings.launch_args._launch_args == {}


def test_format_env_vars():
    """format_env_vars emits KEY=VALUE pairs, excluding SSKEYIN."""
    env_vars = {
        "OMP_NUM_THREADS": "20",
        "LOGGING": "verbose",
        "SSKEYIN": "name_0,name_1",
    }
    settings = LaunchSettings(launcher=LauncherType.Slurm, env_vars=env_vars)
    formatted = settings._arguments.format_env_vars(env_vars)
    assert "OMP_NUM_THREADS=20" in formatted
    assert "LOGGING=verbose" in formatted
    # comma-separated vars must be excluded from the plain format
    assert all("SSKEYIN" not in entry for entry in formatted)


def test_catch_existing_env_var(caplog, monkeypatch):
    """A warning is logged when a requested env var is already set."""
    settings = LaunchSettings(
        launcher=LauncherType.Slurm,
        env_vars={
            "SMARTSIM_TEST_VAR": "B",
        },
    )
    monkeypatch.setenv("SMARTSIM_TEST_VAR", "A")
    monkeypatch.setenv("SMARTSIM_TEST_CSVAR", "A,B")
    caplog.clear()
    settings._arguments.format_env_vars(settings._env_vars)

    msg = "Variable SMARTSIM_TEST_VAR is set to A in current environment. "
    msg += "If the job is running in an interactive allocation, the value B will not be set. "
    msg += "Please consider removing the variable from the environment and re-running the experiment."

    for record in caplog.records:
        assert record.levelname == "WARNING"
        assert record.message == msg

    caplog.clear()

    # the comma-separated formatter should warn with the same message
    env_vars = {"SMARTSIM_TEST_VAR": "B", "SMARTSIM_TEST_CSVAR": "C,D"}
    settings = LaunchSettings(launcher=LauncherType.Slurm, env_vars=env_vars)
    settings._arguments.format_comma_sep_env_vars(env_vars)

    for record in caplog.records:
        assert record.levelname == "WARNING"
        assert record.message == msg


def test_format_comma_sep_env_vars():
    """Comma-containing values are routed to the export-name list."""
    env_vars = {
        "OMP_NUM_THREADS": "20",
        "LOGGING": "verbose",
        "SSKEYIN": "name_0,name_1",
    }
    settings = LaunchSettings(launcher=LauncherType.Slurm, env_vars=env_vars)
    formatted, comma_separated_formatted = (
        settings._arguments.format_comma_sep_env_vars(env_vars)
    )
    assert "OMP_NUM_THREADS" in formatted
    assert "LOGGING" in formatted
    assert "SSKEYIN" in formatted
    assert "name_0,name_1" not in formatted
    assert "SSKEYIN=name_0,name_1" in comma_separated_formatted


def test_slurmSettings_settings():
    """format_launch_args renders set options as long srun flags."""
    settings = LaunchSettings(launcher=LauncherType.Slurm)
    settings.launch_args.set_nodes(5)
    settings.launch_args.set_cpus_per_task(2)
    settings.launch_args.set_tasks(100)
    settings.launch_args.set_tasks_per_node(20)
    formatted = settings._arguments.format_launch_args()
    expected = ["--nodes=5", "--cpus-per-task=2", "--ntasks=100", "--ntasks-per-node=20"]
    assert formatted == expected


def test_slurmSettings_launch_args():
    """User-supplied launch_args override and format correctly."""
    launch_args = {
        "account": "A3123",
        "exclusive": None,
        "C": "P100",  # test single letter variables
        "nodes": 10,
        "ntasks": 100,
    }
    settings = LaunchSettings(launcher=LauncherType.Slurm, launch_args=launch_args)
    formatted = settings._arguments.format_launch_args()
    expected = [
        "--account=A3123",
        "--exclusive",
        "-C",
        "P100",
        "--nodes=10",
        "--ntasks=100",
    ]
    assert formatted == expected


def test_invalid_hostlist_format():
    """set_hostlist rejects non-string hosts."""
    settings = LaunchSettings(launcher=LauncherType.Slurm)
    with pytest.raises(TypeError):
        settings.launch_args.set_hostlist(["test", 5])
    with pytest.raises(TypeError):
        settings.launch_args.set_hostlist([5])
    with pytest.raises(TypeError):
        settings.launch_args.set_hostlist(5)


def test_invalid_exclude_hostlist_format():
    """set_excluded_hosts rejects non-string hosts."""
    settings = LaunchSettings(launcher=LauncherType.Slurm)
    with pytest.raises(TypeError):
        settings.launch_args.set_excluded_hosts(["test", 5])
    with pytest.raises(TypeError):
        settings.launch_args.set_excluded_hosts([5])
    with pytest.raises(TypeError):
        settings.launch_args.set_excluded_hosts(5)


def test_invalid_node_feature_format():
    """set_node_feature rejects non-string features."""
    settings = LaunchSettings(launcher=LauncherType.Slurm)
    with pytest.raises(TypeError):
        settings.launch_args.set_node_feature(["test", 5])
    with pytest.raises(TypeError):
        settings.launch_args.set_node_feature([5])
    with pytest.raises(TypeError):
        settings.launch_args.set_node_feature(5)


def test_invalid_walltime_format():
    """set_walltime rejects values not shaped like HH:MM:SS."""
    settings = LaunchSettings(launcher=LauncherType.Slurm)
    with pytest.raises(ValueError):
        settings.launch_args.set_walltime("11:11")
    with pytest.raises(ValueError):
        settings.launch_args.set_walltime("ss:ss:ss")
    with pytest.raises(ValueError):
        settings.launch_args.set_walltime("11:ss:ss")
    with pytest.raises(ValueError):
        settings.launch_args.set_walltime("0s:ss:ss")


def test_set_het_groups(monkeypatch):
    """Test ability to set one or more het groups to run setting"""
    monkeypatch.setenv("SLURM_HET_SIZE", "4")
    settings = LaunchSettings(launcher=LauncherType.Slurm)
    settings.launch_args.set_het_group([1])
    assert settings._arguments._launch_args["het-group"] == "1"
    settings.launch_args.set_het_group([3, 2])
    assert settings._arguments._launch_args["het-group"] == "3,2"
    # group index must be below SLURM_HET_SIZE
    with pytest.raises(ValueError):
        settings.launch_args.set_het_group([4])


@pytest.mark.parametrize(
    "args, expected",
    (
        pytest.param(
            {},
            (
                "srun",
                "--output=output.txt",
                "--error=error.txt",
                "--",
                "echo",
                "hello",
                "world",
            ),
            id="Empty Args",
        ),
        pytest.param(
            {"N": "1"},
            (
                "srun",
                "-N",
                "1",
                "--output=output.txt",
                "--error=error.txt",
                "--",
                "echo",
                "hello",
                "world",
            ),
            id="Short Arg",
        ),
        pytest.param(
            {"nodes": "1"},
            (
                "srun",
                "--nodes=1",
                "--output=output.txt",
                "--error=error.txt",
                "--",
                "echo",
                "hello",
                "world",
            ),
            id="Long Arg",
        ),
        pytest.param(
            {"v": None},
            (
                "srun",
                "-v",
                "--output=output.txt",
                "--error=error.txt",
                "--",
                "echo",
                "hello",
                "world",
            ),
            id="Short Arg (No Value)",
        ),
        pytest.param(
            {"verbose": None},
            (
                "srun",
                "--verbose",
                "--output=output.txt",
                "--error=error.txt",
                "--",
                "echo",
                "hello",
                "world",
            ),
            id="Long Arg (No Value)",
        ),
        pytest.param(
            {"nodes": "1", "n": "123"},
            (
                "srun",
                "--nodes=1",
                "-n",
                "123",
                "--output=output.txt",
                "--error=error.txt",
                "--",
                "echo",
                "hello",
                "world",
            ),
            id="Short and Long Args",
        ),
    ),
)
def test_formatting_launch_args(args, expected, test_dir):
    """_as_srun_command builds a complete srun command tuple."""
    shell_launch_cmd = _as_srun_command(
        args=SlurmLaunchArguments(args),
        exe=("echo", "hello", "world"),
        path=test_dir,
        env={},
        stdout_path="output.txt",
        stderr_path="error.txt",
    )
    assert isinstance(shell_launch_cmd, ShellLauncherCommand)
    assert shell_launch_cmd.command_tuple == expected
    assert shell_launch_cmd.path == test_dir
    assert shell_launch_cmd.env == {}
    # stdout/stderr are redirected via srun flags, not the process pipes
    assert shell_launch_cmd.stdout == subprocess.DEVNULL
    assert shell_launch_cmd.stderr == subprocess.DEVNULL
import pytest

from smartsim.settings import BatchSettings
from smartsim.settings.arguments.batch.slurm import SlurmBatchArguments
from smartsim.settings.batch_command import BatchSchedulerType

pytestmark = pytest.mark.group_a


def test_batch_scheduler_str():
    """Ensure scheduler_str reports the Slurm scheduler type."""
    settings = BatchSettings(batch_scheduler=BatchSchedulerType.Slurm)
    assert settings.batch_args.scheduler_str() == BatchSchedulerType.Slurm.value


@pytest.mark.parametrize(
    "function,value,result,flag",
    [
        pytest.param("set_nodes", (2,), "2", "nodes", id="set_nodes"),
        pytest.param(
            "set_walltime", ("10:00:00",), "10:00:00", "time", id="set_walltime"
        ),
        pytest.param(
            "set_account", ("account",), "account", "account", id="set_account"
        ),
        pytest.param(
            "set_partition",
            ("partition",),
            "partition",
            "partition",
            id="set_partition",
        ),
        pytest.param(
            "set_queue", ("partition",), "partition", "partition", id="set_queue"
        ),
        pytest.param(
            "set_cpus_per_task", (2,), "2", "cpus-per-task", id="set_cpus_per_task"
        ),
        pytest.param(
            "set_hostlist", ("host_A",), "host_A", "nodelist", id="set_hostlist_str"
        ),
        pytest.param(
            "set_hostlist",
            (["host_A", "host_B"],),
            "host_A,host_B",
            "nodelist",
            id="set_hostlist_list[str]",
        ),
    ],
)
def test_sbatch_class_methods(function, value, flag, result):
    """Each setter stores its value under the expected sbatch flag."""
    settings = BatchSettings(batch_scheduler=BatchSchedulerType.Slurm)
    # invoke the setter under test by name with the parametrized arguments
    getattr(settings.batch_args, function)(*value)
    assert settings.batch_args._batch_args[flag] == result


def test_create_sbatch():
    """Bare flags passed via batch_args format as long options."""
    batch_args = {"exclusive": None, "oversubscribe": None}
    settings = BatchSettings(
        batch_scheduler=BatchSchedulerType.Slurm, batch_args=batch_args
    )
    assert isinstance(settings._arguments, SlurmBatchArguments)
    formatted = settings.format_batch_args()
    assert formatted == ["--exclusive", "--oversubscribe"]


def test_launch_args_input_mutation():
    # Tests that the batch args passed in are not modified after initialization
    key0, key1, key2 = "arg0", "arg1", "arg2"
    val0, val1, val2 = "val0", "val1", "val2"

    default_batch_args = {
        key0: val0,
        key1: val1,
        key2: val2,
    }
    settings = BatchSettings(
        batch_scheduler=BatchSchedulerType.Slurm, batch_args=default_batch_args
    )

    # Confirm initial values are set
    assert settings.batch_args._batch_args[key0] == val0
    assert settings.batch_args._batch_args[key1] == val1
    assert settings.batch_args._batch_args[key2] == val2

    # Mutate the caller's dict after construction
    val2_upd = f"not-{val2}"
    default_batch_args[key2] = val2_upd

    # Confirm previously created batch settings are not changed
    assert settings.batch_args._batch_args[key2] == val2


def test_sbatch_settings():
    """batch_args supplied at construction format as long options."""
    batch_args = {"nodes": 1, "time": "10:00:00", "account": "A3123"}
    settings = BatchSettings(
        batch_scheduler=BatchSchedulerType.Slurm, batch_args=batch_args
    )
    formatted = settings.format_batch_args()
    expected = ["--nodes=1", "--time=10:00:00", "--account=A3123"]
    assert formatted == expected


def test_sbatch_manual():
    """Values set through the setters format as long options."""
    settings = BatchSettings(batch_scheduler=BatchSchedulerType.Slurm)
    settings.batch_args.set_nodes(5)
    settings.batch_args.set_account("A3531")
    settings.batch_args.set_walltime("10:00:00")
    formatted = settings.format_batch_args()
    expected = ["--nodes=5", "--account=A3531", "--time=10:00:00"]
    assert formatted == expected
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from glob import glob
from os import path as osp

import pytest

from smartsim.entity.application import Application
from smartsim.settings.launch_settings import LaunchSettings

pytestmark = pytest.mark.group_a


@pytest.fixture
def get_gen_configure_dir(fileutils):
    # path to the tagged-directory template used by generator tests
    yield fileutils.get_test_conf_path(osp.join("generator_files", "tag_dir_template"))


@pytest.fixture
def mock_launcher_settings(wlmutils):
    # minimal LaunchSettings for the configured test launcher
    return LaunchSettings(wlmutils.get_test_launcher(), {}, {})


def test_application_exe_property():
    """The exe property returns a stable object."""
    app = Application(
        "test_name",
        exe="echo",
        exe_args=["spam", "eggs"],
    )
    exe = app.exe
    assert exe is app.exe


def test_application_exe_args_property():
    """The exe_args property returns a stable object."""
    app = Application("test_name", exe="echo", exe_args=["spam", "eggs"])
    exe_args = app.exe_args
    assert exe_args is app.exe_args


def test_application_file_parameters_property():
    """The file_parameters property returns a stable object."""
    file_parameters = {"h": [5, 6, 7, 8]}
    app = Application(
        "test_name",
        exe="echo",
        file_parameters=file_parameters,
    )
    file_parameters = app.file_parameters

    assert file_parameters is app.file_parameters


def test_application_key_prefixing_property():
    """The key_prefixing_enabled property round-trips consistently."""
    key_prefixing_enabled = True
    app = Application("test_name", exe="echo", exe_args=["spam", "eggs"])
    key_prefixing_enabled = app.key_prefixing_enabled
    assert key_prefixing_enabled == app.key_prefixing_enabled


def test_empty_executable():
    """Test that an error is raised when the exe property is empty"""
    with pytest.raises(ValueError):
        Application(name="application", exe=None, exe_args=None)


def test_executable_is_not_empty_str():
    """Test that an error is raised when exe is set to an empty str"""
    app = Application(name="application", exe="echo", exe_args=None)
    with pytest.raises(ValueError):
        app.exe = ""


def test_type_exe():
    """A non-str exe at construction raises TypeError."""
    with pytest.raises(TypeError):
        Application(
            "test_name",
            exe=2,
            exe_args=["spam", "eggs"],
        )


def test_type_exe_args():
    """Non-str exe_args assigned after construction raise TypeError."""
    app = Application(
        "test_name",
        exe="echo",
    )
    with pytest.raises(TypeError):
        app.exe_args = [1, 2, 3]


def test_type_incoming_entities():
    """Non-entity incoming_entities raise TypeError."""
    app = Application(
        "test_name",
        exe="echo",
        exe_args=["spam", "eggs"],
    )
    with pytest.raises(TypeError):
        app.incoming_entities = [1, 2, 3]


# application type checks
def test_application_type_exe():
    """Assigning a non-str exe raises with the documented message."""
    app = Application(
        "test_name",
        exe="echo",
        exe_args=["spam", "eggs"],
    )
    with pytest.raises(TypeError, match="exe argument was not of type str"):
        app.exe = 2


def test_application_type_exe_args():
    """Assigning non-str exe_args raises with the documented message."""
    app = Application(
        "test_name",
        exe="echo",
        exe_args=["spam", "eggs"],
    )
    with pytest.raises(
        TypeError, match="Executable arguments were not a list of str or a str."
    ):
        app.exe_args = [1, 2, 3]


@pytest.mark.parametrize(
    "file_params",
    (
        pytest.param(["invalid"], id="Not a mapping"),
        pytest.param({"1": 2}, id="Value is not mapping of str and str"),
        pytest.param({1: "2"}, id="Key is not mapping of str and str"),
        pytest.param({1: 2}, id="Values not mapping of str and str"),
    ),
)
def test_application_type_file_parameters(file_params):
    """file_parameters must be a mapping of str to str."""
    app = Application(
        "test_name",
        exe="echo",
        exe_args=["spam", "eggs"],
    )
    with pytest.raises(
        TypeError,
        match="file_parameters argument was not of type mapping of str and str",
    ):
        app.file_parameters = file_params


def test_application_type_incoming_entities():
    """incoming_entities must be a list of SmartSimEntity."""
    app = Application(
        "test_name",
        exe="echo",
        exe_args=["spam", "eggs"],
    )
    with pytest.raises(
        TypeError,
        match="incoming_entities argument was not of type list of SmartSimEntity",
    ):
        app.incoming_entities = [1, 2, 3]


def test_application_type_key_prefixing_enabled():
    """key_prefixing_enabled must be a bool."""
    app = Application(
        "test_name",
        exe="echo",
        exe_args=["spam", "eggs"],
    )
    with pytest.raises(
        TypeError,
        match="key_prefixing_enabled argument was not of type bool",
    ):
        app.key_prefixing_enabled = "invalid"


def test_application_type_build_exe_args():
    """Internal exe_args validation rejects non-str list items."""
    app = Application(
        "test_name",
        exe="echo",
        exe_args=["spam", "eggs"],
    )
    with pytest.raises(
        TypeError, match="Executable arguments were not a list of str or a str."
    ):
        app.exe_args = [1, 2, 3]
- -import sys - -import pytest - -from smartsim import Experiment -from smartsim.entity import Model -from smartsim.error import SSUnsupportedError -from smartsim.status import SmartSimStatus - -# The tests in this file belong to the slow_tests group -pytestmark = pytest.mark.slow_tests - - -if sys.platform == "darwin": - supported_dbs = ["tcp", "deprecated"] -else: - supported_dbs = ["uds", "tcp", "deprecated"] - -is_mac = sys.platform == "darwin" - - -@pytest.mark.skipif(not is_mac, reason="MacOS-only test") -def test_macosx_warning(fileutils, test_dir, coloutils): - db_args = {"custom_pinning": [1]} - db_type = "uds" # Test is insensitive to choice of db - - exp = Experiment("colocated_model_defaults", launcher="local", exp_path=test_dir) - with pytest.warns( - RuntimeWarning, - match="CPU pinning is not supported on MacOSX. Ignoring pinning specification.", - ): - _ = coloutils.setup_test_colo( - fileutils, - db_type, - exp, - "send_data_local_smartredis.py", - db_args, - ) - - -def test_unsupported_limit_app(fileutils, test_dir, coloutils): - db_args = {"limit_app_cpus": True} - db_type = "uds" # Test is insensitive to choice of db - - exp = Experiment("colocated_model_defaults", launcher="local", exp_path=test_dir) - with pytest.raises(SSUnsupportedError): - coloutils.setup_test_colo( - fileutils, - db_type, - exp, - "send_data_local_smartredis.py", - db_args, - ) - - -@pytest.mark.skipif(is_mac, reason="Unsupported on MacOSX") -@pytest.mark.parametrize("custom_pinning", [1, "10", "#", 1.0, ["a"], [1.0]]) -def test_unsupported_custom_pinning(fileutils, test_dir, coloutils, custom_pinning): - db_type = "uds" # Test is insensitive to choice of db - db_args = {"custom_pinning": custom_pinning} - - exp = Experiment("colocated_model_defaults", launcher="local", exp_path=test_dir) - with pytest.raises(TypeError): - coloutils.setup_test_colo( - fileutils, - db_type, - exp, - "send_data_local_smartredis.py", - db_args, - ) - - -@pytest.mark.skipif(is_mac, 
reason="Unsupported on MacOSX") -@pytest.mark.parametrize( - "pin_list, num_cpus, expected", - [ - pytest.param(None, 2, "0,1", id="Automatic creation of pinned cpu list"), - pytest.param([1, 2], 2, "1,2", id="Individual ids only"), - pytest.param([range(2), 3], 3, "0,1,3", id="Mixed ranges and individual ids"), - pytest.param(range(3), 3, "0,1,2", id="Range only"), - pytest.param( - [range(8, 10), range(6, 1, -2)], 4, "2,4,6,8,9", id="Multiple ranges" - ), - ], -) -def test_create_pinning_string(pin_list, num_cpus, expected): - assert Model._create_pinning_string(pin_list, num_cpus) == expected - - -@pytest.mark.parametrize("db_type", supported_dbs) -def test_launch_colocated_model_defaults( - fileutils, test_dir, coloutils, db_type, launcher="local" -): - """Test the launch of a model with a colocated database and local launcher""" - - db_args = {} - - exp = Experiment("colocated_model_defaults", launcher=launcher, exp_path=test_dir) - colo_model = coloutils.setup_test_colo( - fileutils, - db_type, - exp, - "send_data_local_smartredis.py", - db_args, - ) - - if is_mac: - true_pinning = None - else: - true_pinning = "0" - assert ( - colo_model.run_settings.colocated_db_settings["custom_pinning"] == true_pinning - ) - exp.generate(colo_model) - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all(stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses) - - # test restarting the colocated model - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses {statuses}" - - -@pytest.mark.parametrize("db_type", supported_dbs) -def test_launch_multiple_colocated_models( - fileutils, test_dir, coloutils, wlmutils, db_type, launcher="local" -): - """Test the concurrent launch of two models with a colocated database and local launcher""" - - db_args = {} - - exp = Experiment("multi_colo_models", launcher=launcher, 
exp_path=test_dir) - colo_models = [ - coloutils.setup_test_colo( - fileutils, - db_type, - exp, - "send_data_local_smartredis.py", - db_args, - colo_model_name="colo0", - port=wlmutils.get_test_port(), - ), - coloutils.setup_test_colo( - fileutils, - db_type, - exp, - "send_data_local_smartredis.py", - db_args, - colo_model_name="colo1", - port=wlmutils.get_test_port() + 1, - ), - ] - exp.generate(*colo_models) - exp.start(*colo_models, block=True) - statuses = exp.get_status(*colo_models) - assert all(stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses) - - # test restarting the colocated model - exp.start(*colo_models, block=True) - statuses = exp.get_status(*colo_models) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) - - -@pytest.mark.parametrize("db_type", supported_dbs) -def test_colocated_model_disable_pinning( - fileutils, test_dir, coloutils, db_type, launcher="local" -): - exp = Experiment( - "colocated_model_pinning_auto_1cpu", launcher=launcher, exp_path=test_dir - ) - db_args = { - "db_cpus": 1, - "custom_pinning": [], - } - # Check to make sure that the CPU mask was correctly generated - colo_model = coloutils.setup_test_colo( - fileutils, - db_type, - exp, - "send_data_local_smartredis.py", - db_args, - ) - assert colo_model.run_settings.colocated_db_settings["custom_pinning"] is None - exp.generate(colo_model) - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) - - -@pytest.mark.parametrize("db_type", supported_dbs) -def test_colocated_model_pinning_auto_2cpu( - fileutils, test_dir, coloutils, db_type, launcher="local" -): - exp = Experiment( - "colocated_model_pinning_auto_2cpu", launcher=launcher, exp_path=test_dir - ) - - db_args = { - "db_cpus": 2, - } - - # Check to make sure that the CPU mask was correctly generated - colo_model = coloutils.setup_test_colo( - fileutils, - db_type, - exp, - 
"send_data_local_smartredis.py", - db_args, - ) - if is_mac: - true_pinning = None - else: - true_pinning = "0,1" - assert ( - colo_model.run_settings.colocated_db_settings["custom_pinning"] == true_pinning - ) - exp.generate(colo_model) - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) - - -@pytest.mark.skipif(is_mac, reason="unsupported on MacOSX") -@pytest.mark.parametrize("db_type", supported_dbs) -def test_colocated_model_pinning_range( - fileutils, test_dir, coloutils, db_type, launcher="local" -): - # Check to make sure that the CPU mask was correctly generated - - exp = Experiment( - "colocated_model_pinning_manual", launcher=launcher, exp_path=test_dir - ) - - db_args = {"db_cpus": 2, "custom_pinning": range(2)} - - colo_model = coloutils.setup_test_colo( - fileutils, - db_type, - exp, - "send_data_local_smartredis.py", - db_args, - ) - assert colo_model.run_settings.colocated_db_settings["custom_pinning"] == "0,1" - exp.generate(colo_model) - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) - - -@pytest.mark.skipif(is_mac, reason="unsupported on MacOSX") -@pytest.mark.parametrize("db_type", supported_dbs) -def test_colocated_model_pinning_list( - fileutils, test_dir, coloutils, db_type, launcher="local" -): - # Check to make sure that the CPU mask was correctly generated - - exp = Experiment( - "colocated_model_pinning_manual", launcher=launcher, exp_path=test_dir - ) - - db_args = {"db_cpus": 1, "custom_pinning": [1]} - - colo_model = coloutils.setup_test_colo( - fileutils, - db_type, - exp, - "send_data_local_smartredis.py", - db_args, - ) - assert colo_model.run_settings.colocated_db_settings["custom_pinning"] == "1" - exp.generate(colo_model) - exp.start(colo_model, block=True) - statuses = exp.get_status(colo_model) - assert all([stat == 
SmartSimStatus.STATUS_COMPLETED for stat in statuses]) - - -def test_colo_uds_verifies_socket_file_name(test_dir, launcher="local"): - exp = Experiment(f"colo_uds_wrong_name", launcher=launcher, exp_path=test_dir) - - colo_settings = exp.create_run_settings(exe=sys.executable, exe_args=["--version"]) - - colo_model = exp.create_model("wrong_uds_socket_name", colo_settings) - - with pytest.raises(ValueError): - colo_model.colocate_db_uds(unix_socket="this is not a valid name!") diff --git a/tests/test_configs/generator_files/easy/correct/invalidtag.txt b/tests/test_configs/generator_files/easy/correct/invalidtag.txt new file mode 100644 index 0000000000..2165ae8d1a --- /dev/null +++ b/tests/test_configs/generator_files/easy/correct/invalidtag.txt @@ -0,0 +1,3 @@ +some text before +some params are valid and others are ;INVALID; but we mostly encounter valid params +some text after diff --git a/tests/test_configs/generator_files/easy/marked/invalidtag.txt b/tests/test_configs/generator_files/easy/marked/invalidtag.txt new file mode 100644 index 0000000000..90a6253199 --- /dev/null +++ b/tests/test_configs/generator_files/easy/marked/invalidtag.txt @@ -0,0 +1,3 @@ +some text before +some params are ;VALID; and others are ;INVALID; but we mostly encounter ;VALID; params +some text after diff --git a/tests/test_configs/generator_files/log_params/dir_test/dir_test_0/smartsim_params.txt b/tests/test_configs/generator_files/log_params/dir_test/dir_test_0/smartsim_params.txt index 373cec87e0..d29f0741f4 100644 --- a/tests/test_configs/generator_files/log_params/dir_test/dir_test_0/smartsim_params.txt +++ b/tests/test_configs/generator_files/log_params/dir_test/dir_test_0/smartsim_params.txt @@ -1,4 +1,4 @@ -Model name: dir_test_0 +Application name: dir_test_0 File name Parameters -------------------------- --------------- dir_test/dir_test_0/in.atm Name Value diff --git a/tests/test_configs/generator_files/log_params/dir_test/dir_test_1/smartsim_params.txt 
b/tests/test_configs/generator_files/log_params/dir_test/dir_test_1/smartsim_params.txt index e45ebb6bf7..86cc2151b8 100644 --- a/tests/test_configs/generator_files/log_params/dir_test/dir_test_1/smartsim_params.txt +++ b/tests/test_configs/generator_files/log_params/dir_test/dir_test_1/smartsim_params.txt @@ -1,4 +1,4 @@ -Model name: dir_test_1 +Application name: dir_test_1 File name Parameters -------------------------- --------------- dir_test/dir_test_1/in.atm Name Value diff --git a/tests/test_configs/generator_files/log_params/dir_test/dir_test_2/smartsim_params.txt b/tests/test_configs/generator_files/log_params/dir_test/dir_test_2/smartsim_params.txt index 081dc56c67..ef4ea24736 100644 --- a/tests/test_configs/generator_files/log_params/dir_test/dir_test_2/smartsim_params.txt +++ b/tests/test_configs/generator_files/log_params/dir_test/dir_test_2/smartsim_params.txt @@ -1,4 +1,4 @@ -Model name: dir_test_2 +Application name: dir_test_2 File name Parameters -------------------------- --------------- dir_test/dir_test_2/in.atm Name Value diff --git a/tests/test_configs/generator_files/log_params/dir_test/dir_test_3/smartsim_params.txt b/tests/test_configs/generator_files/log_params/dir_test/dir_test_3/smartsim_params.txt index 3403f7c714..496e12e3bd 100644 --- a/tests/test_configs/generator_files/log_params/dir_test/dir_test_3/smartsim_params.txt +++ b/tests/test_configs/generator_files/log_params/dir_test/dir_test_3/smartsim_params.txt @@ -1,4 +1,4 @@ -Model name: dir_test_3 +Application name: dir_test_3 File name Parameters -------------------------- --------------- dir_test/dir_test_3/in.atm Name Value diff --git a/tests/test_configs/generator_files/log_params/smartsim_params.txt b/tests/test_configs/generator_files/log_params/smartsim_params.txt index 6ac92049fe..d3dcc5aac6 100644 --- a/tests/test_configs/generator_files/log_params/smartsim_params.txt +++ b/tests/test_configs/generator_files/log_params/smartsim_params.txt @@ -1,5 +1,5 @@ Generation start 
date and time: 08/09/2023 18:22:44 -Model name: dir_test_0 +Application name: dir_test_0 File name Parameters -------------------------- --------------- dir_test/dir_test_0/in.atm Name Value @@ -7,7 +7,7 @@ dir_test/dir_test_0/in.atm Name Value THERMO 10 STEPS 10 -Model name: dir_test_1 +Application name: dir_test_1 File name Parameters -------------------------- --------------- dir_test/dir_test_1/in.atm Name Value @@ -15,7 +15,7 @@ dir_test/dir_test_1/in.atm Name Value THERMO 10 STEPS 20 -Model name: dir_test_2 +Application name: dir_test_2 File name Parameters -------------------------- --------------- dir_test/dir_test_2/in.atm Name Value @@ -23,7 +23,7 @@ dir_test/dir_test_2/in.atm Name Value THERMO 20 STEPS 10 -Model name: dir_test_3 +Application name: dir_test_3 File name Parameters -------------------------- --------------- dir_test/dir_test_3/in.atm Name Value diff --git a/tests/test_configs/generator_files/to_copy_dir/mock.txt b/tests/test_configs/generator_files/to_copy_dir/mock_1.txt similarity index 100% rename from tests/test_configs/generator_files/to_copy_dir/mock.txt rename to tests/test_configs/generator_files/to_copy_dir/mock_1.txt diff --git a/tests/test_configs/generator_files/to_copy_dir/mock_2.txt b/tests/test_configs/generator_files/to_copy_dir/mock_2.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_configs/generator_files/to_copy_dir/mock_3.txt b/tests/test_configs/generator_files/to_copy_dir/mock_3.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_configs/generator_files/to_symlink_dir/mock_1.txt b/tests/test_configs/generator_files/to_symlink_dir/mock_1.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_configs/generator_files/to_symlink_dir/mock2.txt b/tests/test_configs/generator_files/to_symlink_dir/mock_2.txt similarity index 100% rename from tests/test_configs/generator_files/to_symlink_dir/mock2.txt rename to 
tests/test_configs/generator_files/to_symlink_dir/mock_2.txt diff --git a/tests/test_configs/generator_files/to_symlink_dir/mock_3.txt b/tests/test_configs/generator_files/to_symlink_dir/mock_3.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_configs/send_data.py b/tests/test_configs/send_data.py index f9b9440c47..7c8cc7c25b 100644 --- a/tests/test_configs/send_data.py +++ b/tests/test_configs/send_data.py @@ -42,7 +42,7 @@ def send_data(key): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--iters", type=int, default=10) - parser.add_argument("--name", type=str, default="model") + parser.add_argument("--name", type=str, default="application") args = parser.parse_args() # send data in iterations diff --git a/tests/test_configs/telemetry/colocatedmodel.json b/tests/test_configs/telemetry/colocatedmodel.json index f3e93ac762..77cf910fa7 100644 --- a/tests/test_configs/telemetry/colocatedmodel.json +++ b/tests/test_configs/telemetry/colocatedmodel.json @@ -12,10 +12,10 @@ { "run_id": "002816b", "timestamp": 1699037041106269774, - "model": [ + "application": [ { - "name": "colocated_model", - "path": "/tmp/my-exp/colocated_model", + "name": "colocated_application", + "path": "/tmp/my-exp/colocated_application", "exe_args": [ "/path/to/my/script.py" ], @@ -33,7 +33,7 @@ "Configure": [], "Copy": [] }, - "colocated_db": { + "colocated_fs": { "settings": { "unix_socket": "/tmp/redis.socket", "socket_permissions": 755, @@ -41,19 +41,19 @@ "cpus": 1, "custom_pinning": "0", "debug": false, - "db_identifier": "", + "fs_identifier": "", "rai_args": { "threads_per_queue": null, "inter_op_parallelism": null, "intra_op_parallelism": null }, - "extra_db_args": {} + "extra_fs_args": {} }, "scripts": [], "models": [] }, "telemetry_metadata": { - "status_dir": "/tmp/my-exp/.smartsim/telemetry/telemetry_ensemble/002816b/model/colocated_model", + "status_dir": 
"/tmp/my-exp/.smartsim/telemetry/telemetry_ensemble/002816b/application/colocated_application", "step_id": "4139111.21", "task_id": "21529", "managed": true @@ -62,8 +62,8 @@ "err_file": "/tmp/my-exp/colocated_model/colocated_model.err" } ], - "orchestrator": [], + "featurestore": [], "ensemble": [] } ] -} +} \ No newline at end of file diff --git a/tests/test_configs/telemetry/db_and_model.json b/tests/test_configs/telemetry/db_and_model.json index 36edc74868..3eebd6fbfe 100644 --- a/tests/test_configs/telemetry/db_and_model.json +++ b/tests/test_configs/telemetry/db_and_model.json @@ -12,17 +12,17 @@ { "run_id": "2ca19ad", "timestamp": 1699038647234488933, - "model": [], - "orchestrator": [ + "application": [], + "featurestore": [ { - "name": "orchestrator", + "name": "featurestore", "type": "redis", "interface": [ "ipogif0" ], "shards": [ { - "name": "orchestrator_0", + "name": "featurestore_0", "hostname": "10.128.0.4", "port": 6780, "cluster": false, @@ -33,7 +33,7 @@ "client_count_file": null, "memory_file": "/path/to/some/mem.log", "telemetry_metadata": { - "status_dir": "/tmp/my-exp/.smartsim/telemetry/telemetry_db_and_model/2ca19ad/database/orchestrator/orchestrator_0", + "status_dir": "/tmp/my-exp/.smartsim/telemetry/telemetry_db_and_model/2ca19ad/database/featurestore/featurestore_0", "step_id": "4139111.27", "task_id": "1452", "managed": true @@ -47,7 +47,7 @@ { "run_id": "4b5507a", "timestamp": 1699038661491043211, - "model": [ + "application": [ { "name": "perroquet", "path": "/tmp/my-exp/perroquet", @@ -71,7 +71,7 @@ "Configure": [], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/tmp/my-exp/.smartsim/telemetry/telemetry_db_and_model/4b5507a/model/perroquet", "step_id": "4139111.28", @@ -82,8 +82,8 @@ "err_file": "/tmp/my-exp/perroquet/perroquet.err" } ], - "orchestrator": [], + "featurestore": [], "ensemble": [] } ] -} +} \ No newline at end of file diff --git 
a/tests/test_configs/telemetry/db_and_model_1run.json b/tests/test_configs/telemetry/db_and_model_1run.json index 44e32bfe40..ec6be51f58 100644 --- a/tests/test_configs/telemetry/db_and_model_1run.json +++ b/tests/test_configs/telemetry/db_and_model_1run.json @@ -12,7 +12,7 @@ { "run_id": "4b5507a", "timestamp": 1699038661491043211, - "model": [ + "application": [ { "name": "perroquet", "path": "/tmp/my-exp/perroquet", @@ -36,7 +36,7 @@ "Configure": [], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/tmp/my-exp/.smartsim/telemetry/telemetry_db_and_model/4b5507a/model/perroquet", "step_id": "4139111.28", @@ -47,16 +47,16 @@ "err_file": "/tmp/my-exp/perroquet/perroquet.err" } ], - "orchestrator": [ + "featurestore": [ { - "name": "orchestrator", + "name": "featurestore", "type": "redis", "interface": [ "ipogif0" ], "shards": [ { - "name": "orchestrator_0", + "name": "featurestore_0", "hostname": "10.128.0.4", "port": 6780, "cluster": false, @@ -64,7 +64,7 @@ "out_file": "/path/to/some/file.out", "err_file": "/path/to/some/file.err", "telemetry_metadata": { - "status_dir": "/tmp/my-exp/.smartsim/telemetry/telemetry_db_and_model/2ca19ad/database/orchestrator/orchestrator_0", + "status_dir": "/tmp/my-exp/.smartsim/telemetry/telemetry_db_and_model/2ca19ad/database/featurestore/featurestore_0", "step_id": "4139111.27", "task_id": "1452", "managed": true @@ -76,4 +76,4 @@ "ensemble": [] } ] -} +} \ No newline at end of file diff --git a/tests/test_configs/telemetry/ensembles.json b/tests/test_configs/telemetry/ensembles.json index 67e53ca096..e8c4cfc32e 100644 --- a/tests/test_configs/telemetry/ensembles.json +++ b/tests/test_configs/telemetry/ensembles.json @@ -12,8 +12,8 @@ { "run_id": "d041b90", "timestamp": 1698679830384608928, - "model": [], - "orchestrator": [], + "application": [], + "featurestore": [], "ensemble": [ { "name": "my-ens", @@ -32,7 +32,7 @@ ] }, "batch_settings": {}, - "models": [ + "applications": [ { 
"name": "my-ens_0", "path": "/home/someuser/code/ss", @@ -326,4 +326,4 @@ ] } ] -} +} \ No newline at end of file diff --git a/tests/test_configs/telemetry/serialmodels.json b/tests/test_configs/telemetry/serialmodels.json index 40337ecebe..53c0d9cb8f 100644 --- a/tests/test_configs/telemetry/serialmodels.json +++ b/tests/test_configs/telemetry/serialmodels.json @@ -12,7 +12,7 @@ { "run_id": "8c0fbb1", "timestamp": 1699037881502730708, - "model": [ + "application": [ { "name": "perroquet_0", "path": "/tmp/my-exp/perroquet_0", @@ -179,8 +179,8 @@ "err_file": "/tmp/my-exp/perroquet_4/perroquet_4.err" } ], - "orchestrator": [], + "featurestore": [], "ensemble": [] } ] -} +} \ No newline at end of file diff --git a/tests/test_configs/telemetry/telemetry.json b/tests/test_configs/telemetry/telemetry.json index 916f5922b4..084cc18663 100644 --- a/tests/test_configs/telemetry/telemetry.json +++ b/tests/test_configs/telemetry/telemetry.json @@ -6,12 +6,12 @@ }, "runs": [ { - "run_id": "d999ad89-020f-4e6a-b834-dbd88658ce84", + "run_id": "d999ad89-020f-4e6a-b834-fsd88658ce84", "timestamp": 1697824072792854287, - "model": [ + "application": [ { - "name": "my-model", - "path": "/path/to/my-exp/my-model", + "name": "my-application", + "path": "/path/to/my-exp/my-application", "exe_args": [ "hello", "world" @@ -33,20 +33,20 @@ "Configure": [], "Copy": [] }, - "colocated_db": { + "colocated_fs": { "settings": { "port": 5757, "ifname": "lo", "cpus": 1, "custom_pinning": "0", "debug": false, - "db_identifier": "COLO", + "fs_identifier": "COLO", "rai_args": { "threads_per_queue": null, "inter_op_parallelism": null, "intra_op_parallelism": null }, - "extra_db_args": {} + "extra_fs_args": {} }, "scripts": [], "models": [ @@ -59,7 +59,7 @@ ] }, "telemetry_metadata": { - "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/d999ad89-020f-4e6a-b834-dbd88658ce84/model/my-model", + "status_dir": 
"/path/to/my-exp/.smartsim/telemetry/my-exp/d999ad89-020f-4e6a-b834-fsd88658ce84/model/my-model", "step_id": "4121050.30", "task_id": "25230", "managed": true @@ -68,61 +68,61 @@ "err_file": "/path/to/my-exp/my-model/my-model.err" } ], - "orchestrator": [], + "featurestore": [], "ensemble": [] }, { "run_id": "fd3cd1a8-cb8f-4f61-b847-73a8eb0881fa", "timestamp": 1697824102122439975, - "model": [], - "orchestrator": [ + "application": [], + "featurestore": [ { - "name": "orchestrator", + "name": "featurestore", "type": "redis", "interface": [ "ipogif0" ], "shards": [ { - "name": "orchestrator_1", + "name": "featurestore_1", "hostname": "10.128.0.70", "port": 2424, "cluster": true, - "conf_file": "nodes-orchestrator_1-2424.conf", - "out_file": "/path/to/my-exp/orchestrator/orchestrator.out", - "err_file": "/path/to/my-exp/orchestrator/orchestrator.err", + "conf_file": "nodes-featurestore_1-2424.conf", + "out_file": "/path/to/my-exp/featurestore/featurestore.out", + "err_file": "/path/to/my-exp/featurestore/featurestore.err", "telemetry_metadata": { - "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/fd3cd1a8-cb8f-4f61-b847-73a8eb0881fa/database/orchestrator/orchestrator", + "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/fd3cd1a8-cb8f-4f61-b847-73a8eb0881fa/database/featurestore/featurestore", "step_id": "4121050.31+2", "task_id": "25241", "managed": true } }, { - "name": "orchestrator_2", + "name": "featurestore_2", "hostname": "10.128.0.71", "port": 2424, "cluster": true, - "conf_file": "nodes-orchestrator_2-2424.conf", - "out_file": "/path/to/my-exp/orchestrator/orchestrator.out", - "err_file": "/path/to/my-exp/orchestrator/orchestrator.err", + "conf_file": "nodes-featurestore_2-2424.conf", + "out_file": "/path/to/my-exp/featurestore/featurestore.out", + "err_file": "/path/to/my-exp/featurestore/featurestore.err", "telemetry_metadata": { - "status_dir": 
"/path/to/my-exp/.smartsim/telemetry/my-exp/fd3cd1a8-cb8f-4f61-b847-73a8eb0881fa/database/orchestrator/orchestrator", + "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/fd3cd1a8-cb8f-4f61-b847-73a8eb0881fa/database/featurestore/featurestore", "step_id": "4121050.31+2", "task_id": "25241", "managed": true } }, { - "name": "orchestrator_0", + "name": "featurestore_0", "hostname": "10.128.0.69", "port": 2424, "cluster": true, - "conf_file": "nodes-orchestrator_0-2424.conf", - "out_file": "/path/to/my-exp/orchestrator/orchestrator.out", - "err_file": "/path/to/my-exp/orchestrator/orchestrator.err", + "conf_file": "nodes-featurestore_0-2424.conf", + "out_file": "/path/to/my-exp/featurestore/featurestore.out", + "err_file": "/path/to/my-exp/featurestore/featurestore.err", "telemetry_metadata": { - "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/fd3cd1a8-cb8f-4f61-b847-73a8eb0881fa/database/orchestrator/orchestrator", + "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/fd3cd1a8-cb8f-4f61-b847-73a8eb0881fa/database/featurestore/featurestore", "step_id": "4121050.31+2", "task_id": "25241", "managed": true @@ -136,8 +136,8 @@ { "run_id": "d65ae1df-cb5e-45d9-ab09-6fa641755997", "timestamp": 1697824127962219505, - "model": [], - "orchestrator": [], + "application": [], + "featurestore": [], "ensemble": [ { "name": "my-ens", @@ -156,7 +156,7 @@ ] }, "batch_settings": {}, - "models": [ + "applications": [ { "name": "my-ens_0", "path": "/path/to/my-exp/my-ens/my-ens_0", @@ -186,7 +186,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/d65ae1df-cb5e-45d9-ab09-6fa641755997/ensemble/my-ens/my-ens_0", "step_id": "4121050.32", @@ -225,7 +225,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/d65ae1df-cb5e-45d9-ab09-6fa641755997/ensemble/my-ens/my-ens_1", "step_id": 
"4121050.33", @@ -264,7 +264,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/d65ae1df-cb5e-45d9-ab09-6fa641755997/ensemble/my-ens/my-ens_2", "step_id": "4121050.34", @@ -303,7 +303,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/d65ae1df-cb5e-45d9-ab09-6fa641755997/ensemble/my-ens/my-ens_3", "step_id": "4121050.35", @@ -342,7 +342,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/d65ae1df-cb5e-45d9-ab09-6fa641755997/ensemble/my-ens/my-ens_4", "step_id": "4121050.36", @@ -381,7 +381,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/d65ae1df-cb5e-45d9-ab09-6fa641755997/ensemble/my-ens/my-ens_5", "step_id": "4121050.37", @@ -420,7 +420,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/d65ae1df-cb5e-45d9-ab09-6fa641755997/ensemble/my-ens/my-ens_6", "step_id": "4121050.38", @@ -459,7 +459,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/d65ae1df-cb5e-45d9-ab09-6fa641755997/ensemble/my-ens/my-ens_7", "step_id": "4121050.39", @@ -476,10 +476,10 @@ { "run_id": "e41f8e17-c4b2-441d-adf9-707443ee2c72", "timestamp": 1697835227560376025, - "model": [ + "application": [ { - "name": "my-model", - "path": "/path/to/my-exp/my-model", + "name": "my-application", + "path": "/path/to/my-exp/my-application", "exe_args": [ "hello", "world" @@ -501,20 +501,20 @@ "Configure": [], "Copy": [] }, - "colocated_db": { + "colocated_fs": { "settings": { "port": 5757, "ifname": "lo", "cpus": 1, "custom_pinning": 
"0", "debug": false, - "db_identifier": "COLO", + "fs_identifier": "COLO", "rai_args": { "threads_per_queue": null, "inter_op_parallelism": null, "intra_op_parallelism": null }, - "extra_db_args": {} + "extra_fs_args": {} }, "scripts": [], "models": [ @@ -536,61 +536,61 @@ "err_file": "/path/to/my-exp/my-model/my-model.err" } ], - "orchestrator": [], + "featurestore": [], "ensemble": [] }, { "run_id": "b33a5d27-6822-4795-8e0e-cfea18551fa4", "timestamp": 1697835261956135240, - "model": [], - "orchestrator": [ + "application": [], + "featurestore": [ { - "name": "orchestrator", + "name": "featurestore", "type": "redis", "interface": [ "ipogif0" ], "shards": [ { - "name": "orchestrator_0", + "name": "featurestore_0", "hostname": "10.128.0.2", "port": 2424, "cluster": true, - "conf_file": "nodes-orchestrator_0-2424.conf", - "out_file": "/path/to/my-exp/orchestrator/orchestrator.out", - "err_file": "/path/to/my-exp/orchestrator/orchestrator.err", + "conf_file": "nodes-featurestore_0-2424.conf", + "out_file": "/path/to/my-exp/featurestore/featurestore.out", + "err_file": "/path/to/my-exp/featurestore/featurestore.err", "telemetry_metadata": { - "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/b33a5d27-6822-4795-8e0e-cfea18551fa4/database/orchestrator/orchestrator", + "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/b33a5d27-6822-4795-8e0e-cfea18551fa4/database/featurestore/featurestore", "step_id": "4121904.1+2", "task_id": "28289", "managed": true } }, { - "name": "orchestrator_2", + "name": "featurestore_2", "hostname": "10.128.0.4", "port": 2424, "cluster": true, - "conf_file": "nodes-orchestrator_2-2424.conf", - "out_file": "/path/to/my-exp/orchestrator/orchestrator.out", - "err_file": "/path/to/my-exp/orchestrator/orchestrator.err", + "conf_file": "nodes-featurestore_2-2424.conf", + "out_file": "/path/to/my-exp/featurestore/featurestore.out", + "err_file": "/path/to/my-exp/featurestore/featurestore.err", "telemetry_metadata": { - "status_dir": 
"/path/to/my-exp/.smartsim/telemetry/my-exp/b33a5d27-6822-4795-8e0e-cfea18551fa4/database/orchestrator/orchestrator", + "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/b33a5d27-6822-4795-8e0e-cfea18551fa4/database/featurestore/featurestore", "step_id": "4121904.1+2", "task_id": "28289", "managed": true } }, { - "name": "orchestrator_1", + "name": "featurestore_1", "hostname": "10.128.0.3", "port": 2424, "cluster": true, - "conf_file": "nodes-orchestrator_1-2424.conf", - "out_file": "/path/to/my-exp/orchestrator/orchestrator.out", - "err_file": "/path/to/my-exp/orchestrator/orchestrator.err", + "conf_file": "nodes-featurestore_1-2424.conf", + "out_file": "/path/to/my-exp/featurestore/featurestore.out", + "err_file": "/path/to/my-exp/featurestore/featurestore.err", "telemetry_metadata": { - "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/b33a5d27-6822-4795-8e0e-cfea18551fa4/database/orchestrator/orchestrator", + "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/b33a5d27-6822-4795-8e0e-cfea18551fa4/database/featurestore/featurestore", "step_id": "4121904.1+2", "task_id": "28289", "managed": true @@ -604,8 +604,8 @@ { "run_id": "45772df2-fd80-43fd-adf0-d5e319870182", "timestamp": 1697835287798613875, - "model": [], - "orchestrator": [], + "application": [], + "featurestore": [], "ensemble": [ { "name": "my-ens", @@ -624,7 +624,7 @@ ] }, "batch_settings": {}, - "models": [ + "applications": [ { "name": "my-ens_0", "path": "/path/to/my-exp/my-ens/my-ens_0", @@ -654,7 +654,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/45772df2-fd80-43fd-adf0-d5e319870182/ensemble/my-ens/my-ens_0", "step_id": "4121904.2", @@ -693,7 +693,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/45772df2-fd80-43fd-adf0-d5e319870182/ensemble/my-ens/my-ens_1", "step_id": 
"4121904.3", @@ -732,7 +732,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/45772df2-fd80-43fd-adf0-d5e319870182/ensemble/my-ens/my-ens_2", "step_id": "4121904.4", @@ -771,7 +771,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/45772df2-fd80-43fd-adf0-d5e319870182/ensemble/my-ens/my-ens_3", "step_id": "4121904.5", @@ -810,7 +810,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/45772df2-fd80-43fd-adf0-d5e319870182/ensemble/my-ens/my-ens_4", "step_id": "4121904.6", @@ -849,7 +849,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/45772df2-fd80-43fd-adf0-d5e319870182/ensemble/my-ens/my-ens_5", "step_id": "4121904.7", @@ -888,7 +888,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/45772df2-fd80-43fd-adf0-d5e319870182/ensemble/my-ens/my-ens_6", "step_id": "4121904.8", @@ -927,7 +927,7 @@ ], "Copy": [] }, - "colocated_db": {}, + "colocated_fs": {}, "telemetry_metadata": { "status_dir": "/path/to/my-exp/.smartsim/telemetry/my-exp/45772df2-fd80-43fd-adf0-d5e319870182/ensemble/my-ens/my-ens_7", "step_id": "4121904.9", @@ -942,4 +942,4 @@ ] } ] -} +} \ No newline at end of file diff --git a/tests/test_ensemble.py b/tests/test_ensemble.py index 0632eee16f..1bfbd0b67a 100644 --- a/tests/test_ensemble.py +++ b/tests/test_ensemble.py @@ -24,282 +24,450 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -from copy import deepcopy +import typing as t +from glob import glob +from os import path as osp import pytest -from smartsim import Experiment -from smartsim.entity import Ensemble, Model -from smartsim.error import EntityExistsError, SSUnsupportedError, UserStrategyError -from smartsim.settings import RunSettings +from smartsim.builders.ensemble import Ensemble +from smartsim.builders.utils.strategies import ParamSet +from smartsim.entity.files import EntityFiles +from smartsim.settings.launch_settings import LaunchSettings -# The tests in this file belong to the slow_tests group -pytestmark = pytest.mark.slow_tests +pytestmark = pytest.mark.group_a +_2x2_PARAMS = {"SPAM": ["a", "b"], "EGGS": ["c", "d"]} +_2x2_EXE_ARG = {"EXE": [["a"], ["b", "c"]], "ARGS": [["d"], ["e", "f"]]} -""" -Test ensemble creation -TODO: test to add -- test batch settings/run_setting combinations and errors -- test replica creation -""" +@pytest.fixture +def get_gen_configure_dir(fileutils): + yield fileutils.get_test_conf_path(osp.join("generator_files", "tag_dir_template")) -# ---- helpers ------------------------------------------------------ +def user_created_function( + file_params: t.Mapping[str, t.Sequence[str]], + exe_arg_params: t.Mapping[str, t.Sequence[t.Sequence[str]]], + n_permutations: int = 0, +) -> list[ParamSet]: + return [ParamSet({}, {})] -def step_values(param_names, param_values, n_models=0): - permutations = [] - for p in zip(*param_values): - permutations.append(dict(zip(param_names, p))) - return permutations +@pytest.fixture +def mock_launcher_settings(wlmutils): + return LaunchSettings(wlmutils.get_test_launcher(), {}, {}) -# bad permutation strategy that doesn't return -# a list of dictionaries -def bad_strategy(names, values, n_models=0): - return -1 +def test_exe_property(): + e = Ensemble(name="test", exe="path/to/example_simulation_program") + exe = e.exe + assert exe == e.exe -# test bad perm strategy that returns a list but of lists -# not dictionaries 
-def bad_strategy_2(names, values, n_models=0): - return [values] +def test_exe_args_property(): + e = Ensemble("test", exe="path/to/example_simulation_program", exe_args="sleepy.py") + exe_args = e.exe_args + assert exe_args == e.exe_args -rs = RunSettings("python", exe_args="sleep.py") -# ----- Test param generation ---------------------------------------- +def test_exe_arg_parameters_property(): + exe_arg_parameters = {"-N": 2} + e = Ensemble( + "test", + exe="path/to/example_simulation_program", + exe_arg_parameters=exe_arg_parameters, + ) + exe_arg_parameters = e.exe_arg_parameters + assert exe_arg_parameters == e.exe_arg_parameters -def test_all_perm(): - """Test all permutation strategy""" - params = {"h": [5, 6]} - ensemble = Ensemble("all_perm", params, run_settings=rs, perm_strat="all_perm") - assert len(ensemble) == 2 - assert ensemble.entities[0].params["h"] == "5" - assert ensemble.entities[1].params["h"] == "6" +def test_files_property(get_gen_configure_dir): + tagged_files = sorted(glob(get_gen_configure_dir + "/*")) + files = EntityFiles(tagged=tagged_files) + e = Ensemble("test", exe="path/to/example_simulation_program", files=files) + files = e.files + assert files == e.files -def test_step(): - """Test step strategy""" - params = {"h": [5, 6], "g": [7, 8]} - ensemble = Ensemble("step", params, run_settings=rs, perm_strat="step") - assert len(ensemble) == 2 +def test_file_parameters_property(): + file_parameters = {"h": [5, 6, 7, 8]} + e = Ensemble( + "test", + exe="path/to/example_simulation_program", + file_parameters=file_parameters, + ) + file_parameters = e.file_parameters + assert file_parameters == e.file_parameters - model_1_params = {"h": "5", "g": "7"} - assert ensemble.entities[0].params == model_1_params - model_2_params = {"h": "6", "g": "8"} - assert ensemble.entities[1].params == model_2_params +def test_ensemble_init_empty_params(test_dir: str) -> None: + """Ensemble created without required args""" + with pytest.raises(TypeError): 
+ Ensemble() + + +@pytest.mark.parametrize( + "bad_settings", + [pytest.param(None, id="Nullish"), pytest.param("invalid", id="String")], +) +def test_ensemble_incorrect_launch_settings_type(bad_settings): + """test starting an ensemble with invalid launch settings""" + ensemble = Ensemble("ensemble-name", "echo", replicas=2) + with pytest.raises(TypeError): + ensemble.build_jobs(bad_settings) -def test_random(): - """Test random strategy""" - random_ints = [4, 5, 6, 7, 8] - params = {"h": random_ints} +def test_ensemble_type_exe(): ensemble = Ensemble( - "random_test", - params, - run_settings=rs, - perm_strat="random", - n_models=len(random_ints), + "ensemble-name", + exe="valid", + exe_args=["spam", "eggs"], ) - assert len(ensemble) == len(random_ints) - assigned_params = [m.params["h"] for m in ensemble.entities] - assert all([int(x) in random_ints for x in assigned_params]) - + with pytest.raises( + TypeError, match="exe argument was not of type str or PathLike str" + ): + ensemble.exe = 2 + + +@pytest.mark.parametrize( + "bad_settings", + [ + pytest.param([1, 2, 3], id="sequence of ints"), + pytest.param(0, id="null"), + pytest.param({"foo": "bar"}, id="dict"), + ], +) +def test_ensemble_type_exe_args(bad_settings): ensemble = Ensemble( - "random_test", - params, - run_settings=rs, - perm_strat="random", - n_models=len(random_ints) - 1, + "ensemble-name", + exe="echo", ) - assert len(ensemble) == len(random_ints) - 1 - assigned_params = [m.params["h"] for m in ensemble.entities] - assert all([int(x) in random_ints for x in assigned_params]) - - -def test_user_strategy(): - """Test a user provided strategy""" - params = {"h": [5, 6], "g": [7, 8]} - ensemble = Ensemble("step", params, run_settings=rs, perm_strat=step_values) - assert len(ensemble) == 2 - - model_1_params = {"h": "5", "g": "7"} - assert ensemble.entities[0].params == model_1_params - - model_2_params = {"h": "6", "g": "8"} - assert ensemble.entities[1].params == model_2_params - - -# ----- Model 
arguments ------------------------------------- - - -def test_arg_params(): - """Test parameterized exe arguments""" - params = {"H": [5, 6], "g_param": ["a", "b"]} - - # Copy rs to avoid modifying referenced object - rs_copy = deepcopy(rs) - rs_orig_args = rs_copy.exe_args + with pytest.raises( + TypeError, match="exe_args argument was not of type sequence of str" + ): + ensemble.exe_args = bad_settings + + +@pytest.mark.parametrize( + "exe_arg_params", + ( + pytest.param(["invalid"], id="Not a mapping"), + pytest.param({"key": [1, 2, 3]}, id="Value is not sequence of sequences"), + pytest.param( + {"key": [[1, 2, 3], [4, 5, 6]]}, + id="Value is not sequence of sequence of str", + ), + pytest.param( + {1: 2}, + id="key and value wrong type", + ), + pytest.param({"1": 2}, id="Value is not mapping of str and str"), + pytest.param({1: "2"}, id="Key is not str"), + pytest.param({1: 2}, id="Values not mapping of str and str"), + ), +) +def test_ensemble_type_exe_arg_parameters(exe_arg_params): ensemble = Ensemble( - "step", - params=params, - params_as_args=list(params.keys()), - run_settings=rs_copy, - perm_strat="step", + "ensemble-name", + exe="echo", + exe_args=["spam", "eggs"], ) - assert len(ensemble) == 2 - - exe_args_0 = rs_orig_args + ["-H", "5", "--g_param=a"] - assert ensemble.entities[0].run_settings.exe_args == exe_args_0 - - exe_args_1 = rs_orig_args + ["-H", "6", "--g_param=b"] - assert ensemble.entities[1].run_settings.exe_args == exe_args_1 - + with pytest.raises( + TypeError, + match="exe_arg_parameters argument was not of type mapping " + "of str and sequences of sequences of strings", + ): + ensemble.exe_arg_parameters = exe_arg_params -def test_arg_and_model_params_step(): - """Test parameterized exe arguments combined with - model parameters and step strategy - """ - params = {"H": [5, 6], "g_param": ["a", "b"], "h": [5, 6], "g": [7, 8]} - # Copy rs to avoid modifying referenced object - rs_copy = deepcopy(rs) - rs_orig_args = rs_copy.exe_args 
+def test_ensemble_type_files(): ensemble = Ensemble( - "step", - params, - params_as_args=["H", "g_param"], - run_settings=rs_copy, - perm_strat="step", + "ensemble-name", + exe="echo", + exe_args=["spam", "eggs"], ) - assert len(ensemble) == 2 - - exe_args_0 = rs_orig_args + ["-H", "5", "--g_param=a"] - assert ensemble.entities[0].run_settings.exe_args == exe_args_0 - - exe_args_1 = rs_orig_args + ["-H", "6", "--g_param=b"] - assert ensemble.entities[1].run_settings.exe_args == exe_args_1 - - model_1_params = {"H": "5", "g_param": "a", "h": "5", "g": "7"} - assert ensemble.entities[0].params == model_1_params - - model_2_params = {"H": "6", "g_param": "b", "h": "6", "g": "8"} - assert ensemble.entities[1].params == model_2_params - - -def test_arg_and_model_params_all_perms(): - """Test parameterized exe arguments combined with - model parameters and all_perm strategy - """ - params = {"h": [5, 6], "g_param": ["a", "b"]} - - # Copy rs to avoid modifying referenced object - rs_copy = deepcopy(rs) - rs_orig_args = rs_copy.exe_args + with pytest.raises(TypeError, match="files argument was not of type EntityFiles"): + ensemble.files = 2 + + +@pytest.mark.parametrize( + "file_params", + ( + pytest.param(["invalid"], id="Not a mapping"), + pytest.param({"key": [1, 2, 3]}, id="Key is not sequence of sequences"), + ), +) +def test_ensemble_type_file_parameters(file_params): ensemble = Ensemble( - "step", - params, - params_as_args=["g_param"], - run_settings=rs_copy, - perm_strat="all_perm", + "ensemble-name", + exe="echo", + exe_args=["spam", "eggs"], ) - assert len(ensemble) == 4 - - exe_args_0 = rs_orig_args + ["--g_param=a"] - assert ensemble.entities[0].run_settings.exe_args == exe_args_0 - assert ensemble.entities[2].run_settings.exe_args == exe_args_0 - - exe_args_1 = rs_orig_args + ["--g_param=b"] - assert ensemble.entities[1].run_settings.exe_args == exe_args_1 - assert ensemble.entities[3].run_settings.exe_args == exe_args_1 - - model_0_params = {"g_param": 
"a", "h": "5"} - assert ensemble.entities[0].params == model_0_params - model_1_params = {"g_param": "b", "h": "5"} - assert ensemble.entities[1].params == model_1_params - model_2_params = {"g_param": "a", "h": "6"} - assert ensemble.entities[2].params == model_2_params - model_3_params = {"g_param": "b", "h": "6"} - assert ensemble.entities[3].params == model_3_params + with pytest.raises( + TypeError, + match="file_parameters argument was not of type " + "mapping of str and sequence of str", + ): + ensemble.file_parameters = file_params -# ----- Error Handling -------------------------------------- - - -# unknown permuation strategy -def test_unknown_perm_strat(): - bad_strat = "not-a-strategy" - with pytest.raises(SSUnsupportedError): - e = Ensemble("ensemble", {}, run_settings=rs, perm_strat=bad_strat) +def test_ensemble_type_permutation_strategy(): + ensemble = Ensemble( + "ensemble-name", + exe="echo", + exe_args=["spam", "eggs"], + ) + with pytest.raises( + TypeError, + match="permutation_strategy argument was not of " + "type str or PermutationStrategyType", + ): + ensemble.permutation_strategy = 2 -def test_bad_perm_strat(): - params = {"h": [2, 3]} - with pytest.raises(UserStrategyError): - e = Ensemble("ensemble", params, run_settings=rs, perm_strat=bad_strategy) +def test_ensemble_type_max_permutations(): + ensemble = Ensemble( + "ensemble-name", + exe="echo", + exe_args=["spam", "eggs"], + ) + with pytest.raises( + TypeError, + match="max_permutations argument was not of type int", + ): + ensemble.max_permutations = "invalid" -def test_bad_perm_strat_2(): - params = {"h": [2, 3]} - with pytest.raises(UserStrategyError): - e = Ensemble("ensemble", params, run_settings=rs, perm_strat=bad_strategy_2) +def test_ensemble_type_replicas(): + ensemble = Ensemble( + "ensemble-name", + exe="echo", + exe_args=["spam", "eggs"], + ) + with pytest.raises( + TypeError, + match="replicas argument was not of type int", + ): + ensemble.replicas = "invalid" -# bad 
argument type in params -def test_incorrect_param_type(): - # can either be a list, str, or int - params = {"h": {"h": [5]}} - with pytest.raises(TypeError): - e = Ensemble("ensemble", params, run_settings=rs) +def test_ensemble_type_replicas_negative(): + ensemble = Ensemble( + "ensemble-name", + exe="echo", + exe_args=["spam", "eggs"], + ) + with pytest.raises( + ValueError, + match="Number of replicas must be a positive integer", + ): + ensemble.replicas = -20 -def test_add_model_type(): - params = {"h": 5} - e = Ensemble("ensemble", params, run_settings=rs) +def test_ensemble_type_build_jobs(): + ensemble = Ensemble("ensemble-name", "echo", replicas=2) with pytest.raises(TypeError): - # should be a Model not string - e.add_model("model") - - -def test_add_existing_model(): - params_1 = {"h": 5} - params_2 = {"z": 6} - model_1 = Model("identical_name", params_1, "", rs) - model_2 = Model("identical_name", params_2, "", rs) - e = Ensemble("ensemble", params_1, run_settings=rs) - e.add_model(model_1) - with pytest.raises(EntityExistsError): - e.add_model(model_2) - - -# ----- Other -------------------------------------- - - -def test_models_property(): - params = {"h": [5, 6, 7, 8]} - e = Ensemble("test", params, run_settings=rs) - models = e.models - assert models == [model for model in e] - - -def test_key_prefixing(): - params_1 = {"h": [5, 6, 7, 8]} - params_2 = {"z": 6} - e = Ensemble("test", params_1, run_settings=rs) - model = Model("model", params_2, "", rs) - e.add_model(model) - assert e.query_key_prefixing() == False - e.enable_key_prefixing() - assert e.query_key_prefixing() == True + ensemble.build_jobs("invalid") + + +def test_ensemble_user_created_strategy(mock_launcher_settings, test_dir): + jobs = Ensemble( + "test_ensemble", + "echo", + ("hello", "world"), + permutation_strategy=user_created_function, + ).build_jobs(mock_launcher_settings) + assert len(jobs) == 1 + + +def test_ensemble_without_any_members_raises_when_cast_to_jobs( + 
mock_launcher_settings, test_dir +): + with pytest.raises(ValueError): + Ensemble( + "test_ensemble", + "echo", + ("hello", "world"), + file_parameters=_2x2_PARAMS, + permutation_strategy="random", + max_permutations=30, + replicas=0, + ).build_jobs(mock_launcher_settings) + + +def test_strategy_error_raised_if_a_strategy_that_dne_is_requested(test_dir): + with pytest.raises(ValueError): + Ensemble( + "test_ensemble", + "echo", + ("hello",), + permutation_strategy="THIS-STRATEGY-DNE", + )._create_applications() + + +@pytest.mark.parametrize( + "file_parameters", + ( + pytest.param({"SPAM": ["eggs"]}, id="Non-Empty Params"), + pytest.param({}, id="Empty Params"), + pytest.param(None, id="Nullish Params"), + ), +) +def test_replicated_applications_have_eq_deep_copies_of_parameters( + file_parameters, test_dir +): + apps = list( + Ensemble( + "test_ensemble", + "echo", + ("hello",), + replicas=4, + file_parameters=file_parameters, + )._create_applications() + ) + assert len(apps) >= 2 # Sanitiy check to make sure the test is valid + assert all( + app_1.file_parameters == app_2.file_parameters + for app_1 in apps + for app_2 in apps + ) + assert all( + app_1.file_parameters is not app_2.file_parameters + for app_1 in apps + for app_2 in apps + if app_1 is not app_2 + ) -def test_ensemble_type(): - exp = Experiment("name") - ens_settings = RunSettings("python") - ensemble = exp.create_ensemble("name", replicas=4, run_settings=ens_settings) - assert ensemble.type == "Ensemble" +# fmt: off +@pytest.mark.parametrize( + " params, exe_arg_params, max_perms, replicas, expected_num_jobs", + (pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 30, 1, 16 , id="Set max permutation high"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, -1, 1, 16 , id="Set max permutation negative"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 0, 1, 1 , id="Set max permutation zero"), + pytest.param(_2x2_PARAMS, None, 4, 1, 4 , id="No exe arg params or Replicas"), + pytest.param( None, _2x2_EXE_ARG, 4, 1, 4 , id="No 
Parameters or Replicas"), + pytest.param( None, None, 4, 1, 1 , id="No Parameters, Exe_Arg_Param or Replicas"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 1, 1, 1 , id="Set max permutation to lowest"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 6, 2, 12 , id="Set max permutation, set replicas"), + pytest.param( {}, _2x2_EXE_ARG, 6, 2, 8 , id="Set params as dict, set max permutations and replicas"), + pytest.param(_2x2_PARAMS, {}, 6, 2, 8 , id="Set params as dict, set max permutations and replicas"), + pytest.param( {}, {}, 6, 2, 2 , id="Set params as dict, set max permutations and replicas") +)) +# fmt: on +def test_all_perm_strategy( + # Parameterized + params, + exe_arg_params, + max_perms, + replicas, + expected_num_jobs, + # Other fixtures + mock_launcher_settings, + test_dir, +): + jobs = Ensemble( + "test_ensemble", + "echo", + ("hello", "world"), + file_parameters=params, + exe_arg_parameters=exe_arg_params, + permutation_strategy="all_perm", + max_permutations=max_perms, + replicas=replicas, + ).build_jobs(mock_launcher_settings) + assert len(jobs) == expected_num_jobs + + +def test_all_perm_strategy_contents(mock_launcher_settings): + jobs = Ensemble( + "test_ensemble", + "echo", + ("hello", "world"), + file_parameters=_2x2_PARAMS, + exe_arg_parameters=_2x2_EXE_ARG, + permutation_strategy="all_perm", + max_permutations=16, + replicas=1, + ).build_jobs(mock_launcher_settings) + assert len(jobs) == 16 + + +# fmt: off +@pytest.mark.parametrize( + " params, exe_arg_params, max_perms, replicas, expected_num_jobs", + (pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 30, 1, 2 , id="Set max permutation high"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, -1, 1, 2 , id="Set max permutation negtive"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 0, 1, 1 , id="Set max permutation zero"), + pytest.param(_2x2_PARAMS, None, 4, 1, 1 , id="No exe arg params or Replicas"), + pytest.param( None, _2x2_EXE_ARG, 4, 1, 1 , id="No Parameters or Replicas"), + pytest.param( None, None, 4, 1, 1 , 
id="No Parameters, Exe_Arg_Param or Replicas"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 1, 1, 1 , id="Set max permutation to lowest"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 6, 2, 4 , id="Set max permutation, set replicas"), + pytest.param( {}, _2x2_EXE_ARG, 6, 2, 2 , id="Set params as dict, set max permutations and replicas"), + pytest.param(_2x2_PARAMS, {}, 6, 2, 2 , id="Set params as dict, set max permutations and replicas"), + pytest.param( {}, {}, 6, 2, 2 , id="Set params as dict, set max permutations and replicas") +)) +# fmt: on +def test_step_strategy( + # Parameterized + params, + exe_arg_params, + max_perms, + replicas, + expected_num_jobs, + # Other fixtures + mock_launcher_settings, + test_dir, +): + jobs = Ensemble( + "test_ensemble", + "echo", + ("hello", "world"), + file_parameters=params, + exe_arg_parameters=exe_arg_params, + permutation_strategy="step", + max_permutations=max_perms, + replicas=replicas, + ).build_jobs(mock_launcher_settings) + assert len(jobs) == expected_num_jobs + + +# fmt: off +@pytest.mark.parametrize( + " params, exe_arg_params, max_perms, replicas, expected_num_jobs", + (pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 30, 1, 16 , id="Set max permutation high"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, -1, 1, 16 , id="Set max permutation negative"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 0, 1, 1 , id="Set max permutation zero"), + pytest.param(_2x2_PARAMS, None, 4, 1, 4 , id="No exe arg params or Replicas"), + pytest.param( None, _2x2_EXE_ARG, 4, 1, 4 , id="No Parameters or Replicas"), + pytest.param( None, None, 4, 1, 1 , id="No Parameters, Exe_Arg_Param or Replicas"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 1, 1, 1 , id="Set max permutation to lowest"), + pytest.param(_2x2_PARAMS, _2x2_EXE_ARG, 6, 2, 12 , id="Set max permutation, set replicas"), + pytest.param( {}, _2x2_EXE_ARG, 6, 2, 8 , id="Set params as dict, set max permutations and replicas"), + pytest.param(_2x2_PARAMS, {}, 6, 2, 8 , id="Set params as dict, set max 
permutations and replicas"), + pytest.param( {}, {}, 6, 2, 2 , id="Set params as dict, set max permutations and replicas") +)) +# fmt: on +def test_random_strategy( + # Parameterized + params, + exe_arg_params, + max_perms, + replicas, + expected_num_jobs, + # Other fixtures + mock_launcher_settings, +): + jobs = Ensemble( + "test_ensemble", + "echo", + ("hello", "world"), + file_parameters=params, + exe_arg_parameters=exe_arg_params, + permutation_strategy="random", + max_permutations=max_perms, + replicas=replicas, + ).build_jobs(mock_launcher_settings) + assert len(jobs) == expected_num_jobs diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 4bae09e68a..45f3ecf8e5 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -23,348 +23,767 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-import os -import os.path as osp -import pathlib -import shutil -import typing as t - -import pytest - -from smartsim import Experiment -from smartsim._core.config import CONFIG -from smartsim._core.config.config import Config -from smartsim._core.utils import serialize -from smartsim.database import Orchestrator -from smartsim.entity import Model -from smartsim.error import SmartSimError -from smartsim.error.errors import SSUnsupportedError -from smartsim.settings import RunSettings -from smartsim.status import SmartSimStatus - -if t.TYPE_CHECKING: - import conftest +from __future__ import annotations -# The tests in this file belong to the slow_tests group -pytestmark = pytest.mark.slow_tests - +import dataclasses +import io +import itertools +import random +import re +import time +import typing as t +import uuid +from os import path as osp -def test_model_prefix(test_dir: str) -> None: - exp_name = "test_prefix" - exp = Experiment(exp_name) +import pytest - model = exp.create_model( - "model", - path=test_dir, - run_settings=RunSettings("python"), - enable_key_prefixing=True, +from smartsim._core import dispatch +from smartsim._core.control.launch_history import LaunchHistory +from smartsim._core.generation.generator import Job_Path +from smartsim._core.utils.launcher import LauncherProtocol, create_job_id +from smartsim.builders.ensemble import Ensemble +from smartsim.entity import entity +from smartsim.entity.application import Application +from smartsim.error import errors +from smartsim.experiment import Experiment +from smartsim.launchable import job +from smartsim.settings import launch_settings +from smartsim.settings.arguments import launch_arguments +from smartsim.status import InvalidJobStatus, JobStatus +from smartsim.types import LaunchedJobID + +pytestmark = pytest.mark.group_a + +_ID_GENERATOR = (str(i) for i in itertools.count()) + + +def random_id(): + return next(_ID_GENERATOR) + + +@pytest.fixture +def experiment(monkeypatch, test_dir, 
dispatcher): + """A simple experiment instance with a unique name anda unique name and its + own directory to be used by tests + """ + exp = Experiment(f"test-exp-{uuid.uuid4()}", test_dir) + monkeypatch.setattr(dispatch, "DEFAULT_DISPATCHER", dispatcher) + monkeypatch.setattr( + exp, + "_generate", + lambda generator, job, idx: Job_Path( + "/tmp/job", "/tmp/job/out.txt", "/tmp/job/err.txt" + ), ) - assert model._key_prefixing_enabled == True - - -def test_model_no_name(): - exp = Experiment("test_model_no_name") - with pytest.raises(AttributeError): - _ = exp.create_model(name=None, run_settings=RunSettings("python")) - - -def test_ensemble_no_name(): - exp = Experiment("test_ensemble_no_name") - with pytest.raises(AttributeError): - _ = exp.create_ensemble( - name=None, run_settings=RunSettings("python"), replicas=2 + yield exp + + +@pytest.fixture +def dispatcher(): + """A pre-configured dispatcher to be used by experiments that simply + dispatches any jobs with `MockLaunchArgs` to a `NoOpRecordLauncher` + """ + d = dispatch.Dispatcher() + to_record: dispatch.FormatterType[MockLaunchArgs, LaunchRecord] = ( + lambda settings, exe, path, env, out, err: LaunchRecord( + settings, exe, env, path, out, err ) - - -def test_bad_exp_path() -> None: - with pytest.raises(NotADirectoryError): - exp = Experiment("test", "not-a-directory") - - -def test_type_exp_path() -> None: - with pytest.raises(TypeError): - exp = Experiment("test", ["this-is-a-list-dummy"]) - - -def test_stop_type() -> None: - """Wrong argument type given to stop""" - exp = Experiment("name") - with pytest.raises(TypeError): - exp.stop("model") - - -def test_finished_new_model() -> None: - # finished should fail as this model hasn't been - # launched yet. 
- - model = Model("name", {}, "./", RunSettings("python")) - exp = Experiment("test") - with pytest.raises(ValueError): - exp.finished(model) - - -def test_status_typeerror() -> None: - exp = Experiment("test") - with pytest.raises(TypeError): - exp.get_status([]) - - -def test_status_pre_launch() -> None: - model = Model("name", {}, "./", RunSettings("python")) - exp = Experiment("test") - assert exp.get_status(model)[0] == SmartSimStatus.STATUS_NEVER_STARTED - - -def test_bad_ensemble_init_no_rs(test_dir: str) -> None: - """params supplied without run settings""" - exp = Experiment("test", exp_path=test_dir) - with pytest.raises(SmartSimError): - exp.create_ensemble("name", {"param1": 1}) - - -def test_bad_ensemble_init_no_params(test_dir: str) -> None: - """params supplied without run settings""" - exp = Experiment("test", exp_path=test_dir) - with pytest.raises(SmartSimError): - exp.create_ensemble("name", run_settings=RunSettings("python")) - - -def test_bad_ensemble_init_no_rs_bs(test_dir: str) -> None: - """ensemble init without run settings or batch settings""" - exp = Experiment("test", exp_path=test_dir) - with pytest.raises(SmartSimError): - exp.create_ensemble("name") - - -def test_stop_entity(test_dir: str) -> None: - exp_name = "test_stop_entity" - exp = Experiment(exp_name, exp_path=test_dir) - m = exp.create_model("model", path=test_dir, run_settings=RunSettings("sleep", "5")) - exp.start(m, block=False) - assert exp.finished(m) == False - exp.stop(m) - assert exp.finished(m) == True - - -def test_poll(test_dir: str) -> None: - # Ensure that a SmartSimError is not raised - exp_name = "test_exp_poll" - exp = Experiment(exp_name, exp_path=test_dir) - model = exp.create_model( - "model", path=test_dir, run_settings=RunSettings("sleep", "5") ) - exp.start(model, block=False) - exp.poll(interval=1) - exp.stop(model) + d.dispatch(MockLaunchArgs, with_format=to_record, to_launcher=NoOpRecordLauncher) + yield d -def test_summary(test_dir: str) -> None: - 
exp_name = "test_exp_summary" - exp = Experiment(exp_name, exp_path=test_dir) - m = exp.create_model( - "model", path=test_dir, run_settings=RunSettings("echo", "Hello") - ) - exp.start(m) - summary_str = exp.summary(style="plain") - print(summary_str) +@pytest.fixture +def job_maker(monkeypatch): + """A fixture to generate a never ending stream of `Job` instances each + configured with a unique `MockLaunchArgs` instance, but identical + executable. + """ - summary_lines = summary_str.split("\n") - assert 2 == len(summary_lines) + def iter_jobs(): + for i in itertools.count(): + settings = launch_settings.LaunchSettings("local") + monkeypatch.setattr(settings, "_arguments", MockLaunchArgs(i)) + yield job.Job(EchoHelloWorldEntity(), settings) - headers, values = [s.split() for s in summary_lines] - headers = ["Index"] + headers + jobs = iter_jobs() + yield lambda: next(jobs) - row = dict(zip(headers, values)) - assert m.name == row["Name"] - assert m.type == row["Entity-Type"] - assert 0 == int(row["RunID"]) - assert 0 == int(row["Returncode"]) +JobMakerType: t.TypeAlias = t.Callable[[], job.Job] -def test_launcher_detection( - wlmutils: "conftest.WLMUtils", monkeypatch: pytest.MonkeyPatch -) -> None: - if wlmutils.get_test_launcher() == "pals": - pytest.skip(reason="Launcher detection cannot currently detect pbs vs pals") - if wlmutils.get_test_launcher() == "local": - monkeypatch.setenv("PATH", "") # Remove all WLMs from PATH - if wlmutils.get_test_launcher() == "dragon": - pytest.skip(reason="Launcher detection cannot currently detect dragon") - exp = Experiment("test-launcher-detection", launcher="auto") +@dataclasses.dataclass(frozen=True, eq=False) +class NoOpRecordLauncher(LauncherProtocol): + """Simple launcher to track the order of and mapping of ids to `start` + method calls. 
It has exactly three attrs: - assert exp._launcher == wlmutils.get_test_launcher() + - `created_by_experiment`: + A back ref to the experiment used when calling + `NoOpRecordLauncher.create`. + - `launched_order`: + An append-only list of `LaunchRecord`s that it has "started". Notice + that this launcher will not actually open any subprocesses/run any + threads/otherwise execute the contents of the record on the system -def test_enable_disable_telemetry( - monkeypatch: pytest.MonkeyPatch, test_dir: str, config: Config -) -> None: - # Global telemetry defaults to `on` and can be modified by - # setting the value of env var SMARTSIM_FLAG_TELEMETRY to 0/1 - monkeypatch.setattr(os, "environ", {}) - exp = Experiment("my-exp", exp_path=test_dir) - exp.telemetry.enable() - assert exp.telemetry.is_enabled + - `ids_to_launched`: + A mapping where keys are the generated launched id returned from + a `NoOpRecordLauncher.start` call and the values are the + `LaunchRecord` that was passed into `NoOpRecordLauncher.start` to + cause the id to be generated. 
- exp.telemetry.disable() - assert not exp.telemetry.is_enabled + This is helpful for testing that launchers are handling the expected input + """ - exp.telemetry.enable() - assert exp.telemetry.is_enabled - - exp.telemetry.disable() - assert not exp.telemetry.is_enabled - - exp.start() - mani_path = ( - pathlib.Path(test_dir) / config.telemetry_subdir / serialize.MANIFEST_FILENAME + created_by_experiment: Experiment + launched_order: list[LaunchRecord] = dataclasses.field(default_factory=list) + ids_to_launched: dict[dispatch.LaunchedJobID, LaunchRecord] = dataclasses.field( + default_factory=dict ) - assert mani_path.exists() - -def test_telemetry_default( - monkeypatch: pytest.MonkeyPatch, test_dir: str, config: Config + __hash__ = object.__hash__ + + @classmethod + def create(cls, exp): + return cls(exp) + + def start(self, record: LaunchRecord): + id_ = create_job_id() + self.launched_order.append(record) + self.ids_to_launched[id_] = record + return id_ + + def get_status(self, *ids): + raise NotImplementedError + + def stop_jobs(self, *ids): + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class LaunchRecord: + launch_args: launch_arguments.LaunchArguments + entity: entity.SmartSimEntity + env: t.Mapping[str, str | None] + path: str + out: str + err: str + + @classmethod + def from_job(cls, job: job.Job): + """Create a launch record for what we would expect a launch record to + look like having gone through the launching process + + :param job: A job that has or will be launched through an experiment + and dispatched to a `NoOpRecordLauncher` + :returns: A `LaunchRecord` that should evaluate to being equivilient to + that of the one stored in the `NoOpRecordLauncher` + """ + args = job._launch_settings.launch_args + entity = job._entity.as_executable_sequence() + env = job._launch_settings.env_vars + path = "/tmp/job" + out = "/tmp/job/out.txt" + err = "/tmp/job/err.txt" + return cls(args, entity, env, path, out, err) + + +class 
MockLaunchArgs(launch_arguments.LaunchArguments): + """A `LaunchArguments` subclass that will evaluate as true with another if + and only if they were initialized with the same id. In practice this class + has no arguments to set. + """ + + def __init__(self, id_: int): + super().__init__({}) + self.id = id_ + + def __eq__(self, other): + if type(self) is not type(other): + return NotImplemented + return other.id == self.id + + def launcher_str(self): + return "test-launch-args" + + def set(self, arg, val): ... + + +class EchoHelloWorldEntity(entity.SmartSimEntity): + """A simple smartsim entity""" + + def __init__(self): + super().__init__("test-entity") + + def __eq__(self, other): + if type(self) is not type(other): + return NotImplemented + return self.as_executable_sequence() == other.as_executable_sequence() + + def as_executable_sequence(self): + return ("echo", "Hello", "World!") + + +# fmt: off +@pytest.mark.parametrize( + "num_jobs", [pytest.param(i, id=f"{i} job(s)") for i in (1, 2, 3, 5, 10, 100, 1_000)] +) +@pytest.mark.parametrize( + "make_jobs", ( + pytest.param(lambda maker, n: tuple(maker() for _ in range(n)), id="many job instances"), + pytest.param(lambda maker, n: (maker(),) * n , id="same job instance many times"), + ), +) +# fmt: on +def test_start_can_launch_jobs( + experiment: Experiment, + job_maker: JobMakerType, + make_jobs: t.Callable[[JobMakerType, int], tuple[job.Job, ...]], + num_jobs: int, ) -> None: - """Ensure the default values for telemetry configuration match expectation - that experiment telemetry is on""" - - # If env var related to telemetry doesn't exist, experiment should default to True - monkeypatch.setattr(os, "environ", {}) - exp = Experiment("my-exp", exp_path=test_dir) - assert exp.telemetry.is_enabled - - # If telemetry disabled in env, should get False - monkeypatch.setenv("SMARTSIM_FLAG_TELEMETRY", "0") - exp = Experiment("my-exp", exp_path=test_dir) - assert not exp.telemetry.is_enabled - - # If telemetry enabled 
in env, should get True - monkeypatch.setenv("SMARTSIM_FLAG_TELEMETRY", "1") - exp = Experiment("my-exp", exp_path=test_dir) - assert exp.telemetry.is_enabled - - -def test_error_on_cobalt() -> None: - with pytest.raises(SSUnsupportedError): - exp = Experiment("cobalt_exp", launcher="cobalt") - - -def test_default_orch_path( - monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" + jobs = make_jobs(job_maker, num_jobs) + assert ( + len(list(experiment._launch_history.iter_past_launchers())) == 0 + ), "Initialized w/ launchers" + launched_ids = experiment.start(*jobs) + assert ( + len(list(experiment._launch_history.iter_past_launchers())) == 1 + ), "Unexpected number of launchers" + ((launcher, exp_cached_ids),) = ( + experiment._launch_history.group_by_launcher().items() + ) + assert isinstance(launcher, NoOpRecordLauncher), "Unexpected launcher type" + assert launcher.created_by_experiment is experiment, "Not created by experiment" + assert ( + len(jobs) == len(launcher.launched_order) == len(launched_ids) == num_jobs + ), "Inconsistent number of jobs/launched jobs/launched ids/expected number of jobs" + expected_launched = [LaunchRecord.from_job(job) for job in jobs] + + # Check that `job_a, job_b, job_c, ...` are started in that order when + # calling `experiemnt.start(job_a, job_b, job_c, ...)` + assert expected_launched == list(launcher.launched_order), "Unexpected launch order" + assert sorted(launched_ids) == sorted(exp_cached_ids), "Exp did not cache ids" + + # Similarly, check that `id_a, id_b, id_c, ...` corresponds to + # `job_a, job_b, job_c, ...` when calling + # `id_a, id_b, id_c, ... 
= experiment.start(job_a, job_b, job_c, ...)`
+ # ``` + assert ids_to_launches == launcher.ids_to_launched, "Job was not re-launched" + assert sorted(ids_to_launches) == sorted(exp_cached_ids), "Exp did not cache ids" + + +class GetStatusLauncher(LauncherProtocol): + def __init__(self): + self.id_to_status = {create_job_id(): stat for stat in JobStatus} + + __hash__ = object.__hash__ + + @property + def known_ids(self): + return tuple(self.id_to_status) + + @classmethod + def create(cls, _): + raise NotImplementedError("{type(self).__name__} should not be created") + + def start(self, _): + raise NotImplementedError("{type(self).__name__} should not start anything") + + def _assert_ids(self, ids: LaunchedJobID): + if any(id_ not in self.id_to_status for id_ in ids): + raise errors.LauncherJobNotFound + + def get_status(self, *ids: LaunchedJobID): + self._assert_ids(ids) + return {id_: self.id_to_status[id_] for id_ in ids} + + def stop_jobs(self, *ids: LaunchedJobID): + self._assert_ids(ids) + stopped = {id_: JobStatus.CANCELLED for id_ in ids} + self.id_to_status |= stopped + return stopped + + +@pytest.fixture +def make_populated_experiment(monkeypatch, experiment): + def impl(num_active_launchers): + new_launchers = (GetStatusLauncher() for _ in range(num_active_launchers)) + id_to_launcher = { + id_: launcher for launcher in new_launchers for id_ in launcher.known_ids + } + monkeypatch.setattr( + experiment, "_launch_history", LaunchHistory(id_to_launcher) + ) + return experiment + + yield impl + + +def test_experiment_can_get_statuses(make_populated_experiment): + exp = make_populated_experiment(num_active_launchers=1) + (launcher,) = exp._launch_history.iter_past_launchers() + ids = tuple(launcher.known_ids) + recieved_stats = exp.get_status(*ids) + assert len(recieved_stats) == len(ids), "Unexpected number of statuses" + assert ( + dict(zip(ids, recieved_stats)) == launcher.id_to_status + ), "Statuses in wrong order" + + +@pytest.mark.parametrize( + "num_launchers", + [pytest.param(i, id=f"{i} 
launcher(s)") for i in (2, 3, 5, 10, 20, 100)], +) +def test_experiment_can_get_statuses_from_many_launchers( + make_populated_experiment, num_launchers +): + exp = make_populated_experiment(num_active_launchers=num_launchers) + launcher_and_rand_ids = ( + (launcher, random.choice(tuple(launcher.id_to_status))) + for launcher in exp._launch_history.iter_past_launchers() + ) + expected_id_to_stat = { + id_: launcher.id_to_status[id_] for launcher, id_ in launcher_and_rand_ids + } + query_ids = tuple(expected_id_to_stat) + stats = exp.get_status(*query_ids) + assert len(stats) == len(expected_id_to_stat), "Unexpected number of statuses" + assert dict(zip(query_ids, stats)) == expected_id_to_stat, "Statuses in wrong order" + + +def test_get_status_returns_not_started_for_unrecognized_ids( + monkeypatch, make_populated_experiment +): + exp = make_populated_experiment(num_active_launchers=1) + brand_new_id = create_job_id() + ((launcher, (id_not_known_by_exp, *rest)),) = ( + exp._launch_history.group_by_launcher().items() + ) + new_history = LaunchHistory({id_: launcher for id_ in rest}) + monkeypatch.setattr(exp, "_launch_history", new_history) + expected_stats = (InvalidJobStatus.NEVER_STARTED,) * 2 + actual_stats = exp.get_status(brand_new_id, id_not_known_by_exp) + assert expected_stats == actual_stats + + +def test_get_status_de_dups_ids_passed_to_launchers( + monkeypatch, make_populated_experiment +): + def track_calls(fn): + calls = [] + + def impl(*a, **kw): + calls.append((a, kw)) + return fn(*a, **kw) + + return calls, impl + + exp = make_populated_experiment(num_active_launchers=1) + ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + calls, tracked_get_status = track_calls(launcher.get_status) + monkeypatch.setattr(launcher, "get_status", tracked_get_status) + stats = exp.get_status(id_, id_, id_) + assert len(stats) == 3, "Unexpected number of statuses" + assert all(stat == stats[0] for stat in stats), "Statuses are not eq" + assert 
len(calls) == 1, "Launcher's `get_status` was called more than once" + (call,) = calls + assert call == ((id_,), {}), "IDs were not de-duplicated" + + +def test_wait_handles_empty_call_args(experiment): + """An exception is raised when there are no jobs to complete""" + with pytest.raises(ValueError, match="No job ids"): + experiment.wait() + + +def test_wait_does_not_block_unknown_id(experiment): + """If an experiment does not recognize a job id, it should not wait for its + completion + """ + now = time.perf_counter() + experiment.wait(create_job_id()) + assert time.perf_counter() - now < 1 + + +def test_wait_calls_prefered_impl(make_populated_experiment, monkeypatch): + """Make wait is calling the expected method for checking job statuses. + Right now we only have the "polling" impl, but in future this might change + to an event based system. + """ + exp = make_populated_experiment(1) + ((_, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + was_called = False + + def mocked_impl(*args, **kwargs): + nonlocal was_called + was_called = True + + monkeypatch.setattr(exp, "_poll_for_statuses", mocked_impl) + exp.wait(id_) + assert was_called + + +@pytest.mark.parametrize( + "num_polls", + [ + pytest.param(i, id=f"Poll for status {i} times") + for i in (1, 5, 10, 20, 100, 1_000) + ], +) +@pytest.mark.parametrize("verbose", [True, False]) +def test_poll_status_blocks_until_job_is_completed( + monkeypatch, make_populated_experiment, num_polls, verbose +): + """Make sure that the polling based implementation blocks the calling + thread. Use varying number of polls to simulate varying lengths of job time + for a job to complete. 
+ + Additionally check to make sure that the expected log messages are present + """ + exp = make_populated_experiment(1) + ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + (current_status,) = launcher.get_status(id_).values() + different_statuses = set(JobStatus) - {current_status} + (new_status, *_) = different_statuses + mock_log = io.StringIO() + + @dataclasses.dataclass + class ChangeStatusAfterNPolls: + n: int + from_: JobStatus + to: JobStatus + num_calls: int = dataclasses.field(default=0, init=False) + + def __call__(self, *args, **kwargs): + self.num_calls += 1 + ret_status = self.to if self.num_calls >= self.n else self.from_ + return (ret_status,) + + mock_get_status = ChangeStatusAfterNPolls(num_polls, current_status, new_status) + monkeypatch.setattr(exp, "get_status", mock_get_status) + monkeypatch.setattr( + "smartsim.experiment.logger.info", lambda s: mock_log.write(f"{s}\n") + ) + final_statuses = exp._poll_for_statuses( + [id_], different_statuses, timeout=10, interval=0, verbose=verbose + ) + assert final_statuses == {id_: new_status} + expected_log = io.StringIO() + expected_log.writelines( + f"Job({id_}): Running with status '{current_status.value}'\n" + for _ in range(num_polls - 1) + ) + expected_log.write(f"Job({id_}): Finished with status '{new_status.value}'\n") + assert mock_get_status.num_calls == num_polls + assert mock_log.getvalue() == (expected_log.getvalue() if verbose else "") + + +def test_poll_status_raises_when_called_with_infinite_iter_wait( + make_populated_experiment, +): + """Cannot wait forever between polls. 
That will just block the thread after + the first poll + """ + exp = make_populated_experiment(1) + ((_, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + with pytest.raises(ValueError, match="Polling interval cannot be infinite"): + exp._poll_for_statuses( + [id_], + [], + timeout=10, + interval=float("inf"), + ) -def test_default_model_path( - monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" -) -> None: - """Ensure the default file structure is created for Model""" - exp_name = "default-model-path" - exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) - monkeypatch.setattr(exp._control, "start", lambda *a, **kw: ...) - settings = exp.create_run_settings(exe="echo", exe_args="hello") - model = exp.create_model(name="model_name", run_settings=settings) - exp.start(model) - model_path = pathlib.Path(test_dir) / model.name - assert model_path.exists() - assert model.path == str(model_path) +def test_poll_for_status_raises_if_ids_not_found_within_timeout( + make_populated_experiment, +): + """If there is a timeout, a timeout error should be raised when it is exceeded""" + exp = make_populated_experiment(1) + ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + (current_status,) = launcher.get_status(id_).values() + different_statuses = set(JobStatus) - {current_status} + with pytest.raises( + TimeoutError, + match=re.escape( + f"Job ID(s) {id_} failed to reach terminal status before timeout" + ), + ): + exp._poll_for_statuses( + [id_], + different_statuses, + timeout=1, + interval=0, + ) -def test_default_ensemble_path( - monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" -) -> None: - """Ensure the default file structure is created for Ensemble""" - - exp_name = "default-ensemble-path" - exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) - monkeypatch.setattr(exp._control, "start", lambda *a, **kw: ...) 
- settings = exp.create_run_settings(exe="echo", exe_args="hello") - ensemble = exp.create_ensemble( - name="ensemble_name", run_settings=settings, replicas=2 - ) - exp.start(ensemble) - ensemble_path = pathlib.Path(test_dir) / ensemble.name - assert ensemble_path.exists() - assert ensemble.path == str(ensemble_path) - for member in ensemble.models: - member_path = ensemble_path / member.name - assert member_path.exists() - assert member.path == str(ensemble_path / member.name) - - -def test_user_orch_path( - monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" -) -> None: - """Ensure a relative path is used to created Orchestrator folder""" - - exp_name = "default-orch-path" - exp = Experiment(exp_name, launcher="local", exp_path=test_dir) - monkeypatch.setattr(exp._control, "start", lambda *a, **kw: ...) - db = exp.create_database( - port=wlmutils.get_test_port(), - interface=wlmutils.get_test_interface(), - path="./testing_folder1234", +@pytest.mark.parametrize( + "num_launchers", + [pytest.param(i, id=f"{i} launcher(s)") for i in (2, 3, 5, 10, 20, 100)], +) +@pytest.mark.parametrize( + "select_ids", + [ + pytest.param( + lambda history: history._id_to_issuer.keys(), id="All launched jobs" + ), + pytest.param( + lambda history: next(iter(history.group_by_launcher().values())), + id="All from one launcher", + ), + pytest.param( + lambda history: itertools.chain.from_iterable( + random.sample(tuple(ids), len(JobStatus) // 2) + for ids in history.group_by_launcher().values() + ), + id="Subset per launcher", + ), + pytest.param( + lambda history: random.sample( + tuple(history._id_to_issuer), len(history._id_to_issuer) // 3 + ), + id=f"Random subset across all launchers", + ), + ], +) +def test_experiment_can_stop_jobs(make_populated_experiment, num_launchers, select_ids): + exp = make_populated_experiment(num_launchers) + ids = (launcher.known_ids for launcher in exp._launch_history.iter_past_launchers()) + ids = 
tuple(itertools.chain.from_iterable(ids)) + before_stop_stats = exp.get_status(*ids) + to_cancel = tuple(select_ids(exp._launch_history)) + stats = exp.stop(*to_cancel) + after_stop_stats = exp.get_status(*ids) + assert stats == (JobStatus.CANCELLED,) * len(to_cancel) + assert dict(zip(ids, before_stop_stats)) | dict(zip(to_cancel, stats)) == dict( + zip(ids, after_stop_stats) ) - exp.start(db) - orch_path = pathlib.Path(osp.abspath("./testing_folder1234")) - assert orch_path.exists() - assert db.path == str(orch_path) - shutil.rmtree(orch_path) -def test_default_model_with_path( - monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" -) -> None: - """Ensure a relative path is used to created Model folder""" - - exp_name = "default-ensemble-path" - exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) - monkeypatch.setattr(exp._control, "start", lambda *a, **kw: ...) - settings = exp.create_run_settings(exe="echo", exe_args="hello") - model = exp.create_model( - name="model_name", run_settings=settings, path="./testing_folder1234" - ) - exp.start(model) - model_path = pathlib.Path(osp.abspath("./testing_folder1234")) - assert model_path.exists() - assert model.path == str(model_path) - shutil.rmtree(model_path) +def test_experiment_raises_if_asked_to_stop_no_jobs(experiment): + with pytest.raises(ValueError, match="No job ids provided"): + experiment.stop() -def test_default_ensemble_with_path( - monkeypatch: pytest.MonkeyPatch, test_dir: str, wlmutils: "conftest.WLMUtils" -) -> None: - """Ensure a relative path is used to created Ensemble folder""" - - exp_name = "default-ensemble-path" - exp = Experiment(exp_name, launcher=wlmutils.get_test_launcher(), exp_path=test_dir) - monkeypatch.setattr(exp._control, "start", lambda *a, **kw: ...) 
- settings = exp.create_run_settings(exe="echo", exe_args="hello") - ensemble = exp.create_ensemble( - name="ensemble_name", - run_settings=settings, - path="./testing_folder1234", - replicas=2, +@pytest.mark.parametrize( + "num_launchers", + [pytest.param(i, id=f"{i} launcher(s)") for i in (2, 3, 5, 10, 20, 100)], +) +def test_experiment_stop_does_not_raise_on_unknown_job_id( + make_populated_experiment, num_launchers +): + exp = make_populated_experiment(num_launchers) + new_id = create_job_id() + all_known_ids = tuple(exp._launch_history._id_to_issuer) + before_cancel = exp.get_status(*all_known_ids) + (stat,) = exp.stop(new_id) + assert stat == InvalidJobStatus.NEVER_STARTED + after_cancel = exp.get_status(*all_known_ids) + assert before_cancel == after_cancel + + +def test_start_raises_if_no_args_supplied(test_dir): + exp = Experiment(name="exp_name", exp_path=test_dir) + with pytest.raises(ValueError, match="No jobs provided to start"): + exp.start() + + +def test_stop_raises_if_no_args_supplied(test_dir): + exp = Experiment(name="exp_name", exp_path=test_dir) + with pytest.raises(ValueError, match="No job ids provided"): + exp.stop() + + +def test_get_status_raises_if_no_args_supplied(test_dir): + exp = Experiment(name="exp_name", exp_path=test_dir) + with pytest.raises(ValueError, match="No job ids provided"): + exp.get_status() + + +def test_poll_raises_if_no_args_supplied(test_dir): + exp = Experiment(name="exp_name", exp_path=test_dir) + with pytest.raises( + TypeError, match="missing 2 required positional arguments: 'ids' and 'statuses'" + ): + exp._poll_for_statuses() + + +def test_wait_raises_if_no_args_supplied(test_dir): + exp = Experiment(name="exp_name", exp_path=test_dir) + with pytest.raises(ValueError, match="No job ids to wait on provided"): + exp.wait() + + +def test_type_experiment_name_parameter(test_dir): + with pytest.raises(TypeError, match="name argument was not of type str"): + Experiment(name=1, exp_path=test_dir) + + +def 
test_type_start_parameters(test_dir): + exp = Experiment(name="exp_name", exp_path=test_dir) + with pytest.raises(TypeError, match="jobs argument was not of type Job"): + exp.start("invalid") + + +def test_type_get_status_parameters(test_dir): + exp = Experiment(name="exp_name", exp_path=test_dir) + with pytest.raises(TypeError, match="ids argument was not of type LaunchedJobID"): + exp.get_status(2) + + +def test_type_wait_parameter(test_dir): + exp = Experiment(name="exp_name", exp_path=test_dir) + with pytest.raises(TypeError, match="ids argument was not of type LaunchedJobID"): + exp.wait(2) + + +def test_type_stop_parameter(test_dir): + exp = Experiment(name="exp_name", exp_path=test_dir) + with pytest.raises(TypeError, match="ids argument was not of type LaunchedJobID"): + exp.stop(2) + + +@pytest.mark.parametrize( + "job_list", + ( + pytest.param( + [ + ( + job.Job( + Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + Ensemble("ensemble-name", "echo", replicas=2).build_jobs( + launch_settings.LaunchSettings("local") + ), + ) + ], + id="(job1, (job2, job_3))", + ), + pytest.param( + [ + ( + Ensemble("ensemble-name", "echo", replicas=2).build_jobs( + launch_settings.LaunchSettings("local") + ), + ( + job.Job( + Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + job.Job( + Application( + "test_name_2", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + ), + ) + ], + id="((job1, job2), (job3, job4))", + ), + pytest.param( + [ + ( + job.Job( + Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + ) + ], + id="(job,)", + ), + pytest.param( + [ + [ + job.Job( + Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + ( + 
Ensemble("ensemble-name", "echo", replicas=2).build_jobs( + launch_settings.LaunchSettings("local") + ), + job.Job( + Application( + "test_name_2", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + ), + ] + ], + id="[job_1, ((job_2, job_3), job_4)]", + ), + ), +) +def test_start_unpack( + test_dir: str, wlmutils, monkeypatch: pytest.MonkeyPatch, job_list: job.Job +): + """Test unpacking a sequences of jobs""" + + monkeypatch.setattr( + "smartsim._core.dispatch._LauncherAdapter.start", + lambda launch, exe, job_execution_path, env, out, err: random_id(), ) - exp.start(ensemble) - ensemble_path = pathlib.Path(osp.abspath("./testing_folder1234")) - assert ensemble_path.exists() - assert ensemble.path == str(ensemble_path) - for member in ensemble.models: - member_path = ensemble_path / member.name - assert member_path.exists() - assert member.path == str(member_path) - shutil.rmtree(ensemble_path) + + exp = Experiment(name="exp_name", exp_path=test_dir) + exp.start(*job_list) diff --git a/tests/test_file_operations.py b/tests/test_file_operations.py new file mode 100644 index 0000000000..327eb74286 --- /dev/null +++ b/tests/test_file_operations.py @@ -0,0 +1,786 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import base64 +import filecmp +import os +import pathlib +import pickle +from glob import glob +from os import path as osp + +import pytest + +from smartsim._core.entrypoints import file_operations +from smartsim._core.entrypoints.file_operations import get_parser + +pytestmark = pytest.mark.group_a + + +def get_gen_file(fileutils, filename): + return fileutils.get_test_conf_path(osp.join("generator_files", filename)) + + +def test_symlink_files(test_dir): + """ + Test operation to symlink files + """ + # Set source directory and file + source = pathlib.Path(test_dir) / "sym_source" + os.mkdir(source) + source_file = source / "sym_source.txt" + with open(source_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy") + + # Set path to be the destination directory + entity_path = os.path.join(test_dir, "entity_name") + + parser = get_parser() + cmd = f"symlink {source_file} {entity_path}" + args = cmd.split() + ns = parser.parse_args(args) + + file_operations.symlink(ns) + + # Assert the two files are the same file + link = pathlib.Path(test_dir) / "entity_name" + assert link.is_symlink() + assert os.readlink(link) == str(source_file) + + # Clean up the 
test directory + os.unlink(link) + os.remove(pathlib.Path(source) / "sym_source.txt") + os.rmdir(pathlib.Path(test_dir) / "sym_source") + + +def test_symlink_dir(test_dir): + """ + Test operation to symlink directories + """ + + source = pathlib.Path(test_dir) / "sym_source" + os.mkdir(source) + + # entity_path to be the dest dir + entity_path = os.path.join(test_dir, "entity_name") + + parser = get_parser() + cmd = f"symlink {source} {entity_path}" + args = cmd.split() + ns = parser.parse_args(args) + + file_operations.symlink(ns) + + link = pathlib.Path(test_dir) / "entity_name" + # Assert the two files are the same file + assert link.is_symlink() + assert os.readlink(link) == str(source) + + # Clean up the test directory + os.unlink(link) + os.rmdir(pathlib.Path(test_dir) / "sym_source") + + +def test_symlink_not_absolute(test_dir): + """Test that ValueError is raised when a relative path + is given to the symlink operation + """ + # Set source directory and file + source = pathlib.Path(test_dir) / "sym_source" + os.mkdir(source) + source_file = source / "sym_source.txt" + with open(source_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy") + + # Set path to be the destination directory + entity_path = ".." 
+ + parser = get_parser() + cmd = f"symlink {source_file} {entity_path}" + args = cmd.split() + + with pytest.raises(SystemExit) as e: + parser.parse_args(args) + + assert isinstance(e.value.__context__, argparse.ArgumentError) + assert "invalid _abspath value" in e.value.__context__.message + + +def test_copy_op_file(test_dir): + """Test the operation to copy the content of the source file to the destination path + with an empty file of the same name already in the directory""" + + to_copy = os.path.join(test_dir, "to_copy") + os.mkdir(to_copy) + + source_file = pathlib.Path(to_copy) / "copy_file.txt" + with open(source_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy") + + entity_path = os.path.join(test_dir, "entity_name") + os.mkdir(entity_path) + + dest_file = os.path.join(test_dir, "entity_name", "copy_file.txt") + with open(dest_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("") + + parser = get_parser() + cmd = f"copy {source_file} {dest_file}" + args = cmd.split() + ns = parser.parse_args(args) + + # Execute copy + file_operations.copy(ns) + + # Assert files were copied over + with open(dest_file, "r", encoding="utf-8") as dummy_file: + assert dummy_file.read() == "dummy" + + # Clean up + os.remove(pathlib.Path(to_copy) / "copy_file.txt") + os.rmdir(pathlib.Path(test_dir) / "to_copy") + + os.remove(pathlib.Path(entity_path) / "copy_file.txt") + os.rmdir(pathlib.Path(test_dir) / "entity_name") + + +def test_copy_op_dirs(test_dir): + """Test the operation that copies an entire directory tree source to a new location destination + that already exists""" + + to_copy = os.path.join(test_dir, "to_copy") + os.mkdir(to_copy) + + # write some test files in the dir + source_file = pathlib.Path(to_copy) / "copy_file.txt" + with open(source_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy1") + + source_file_2 = pathlib.Path(to_copy) / "copy_file_2.txt" + with open(source_file_2, "w+", encoding="utf-8") as 
dummy_file: + dummy_file.write("dummy2") + + # entity_path to be the dest dir + entity_path = os.path.join(test_dir, "entity_name") + os.mkdir(entity_path) + + parser = get_parser() + cmd = f"copy {to_copy} {entity_path} --dirs_exist_ok" + args = cmd.split() + ns = parser.parse_args(args) + + # Execute copy + file_operations.copy(ns) + + # Assert dirs were copied over + entity_files_1 = pathlib.Path(entity_path) / "copy_file.txt" + with open(entity_files_1, "r", encoding="utf-8") as dummy_file: + assert dummy_file.read() == "dummy1" + + entity_files_2 = pathlib.Path(entity_path) / "copy_file_2.txt" + with open(entity_files_2, "r", encoding="utf-8") as dummy_file: + assert dummy_file.read() == "dummy2" + + # Clean up + os.remove(pathlib.Path(to_copy) / "copy_file.txt") + os.remove(pathlib.Path(to_copy) / "copy_file_2.txt") + os.rmdir(pathlib.Path(test_dir) / "to_copy") + os.remove(pathlib.Path(entity_path) / "copy_file.txt") + os.remove(pathlib.Path(entity_path) / "copy_file_2.txt") + os.rmdir(pathlib.Path(test_dir) / "entity_name") + + +def test_copy_op_dirs_file_exists_error(test_dir): + """Test that a FileExistsError is raised when copying a directory tree source to a new location destination + when the destination already exists, and the flag --dirs_exist_ok is not included + """ + + to_copy = os.path.join(test_dir, "to_copy") + os.mkdir(to_copy) + + # write some test files in the dir + source_file = pathlib.Path(to_copy) / "copy_file.txt" + with open(source_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy1") + + source_file_2 = pathlib.Path(to_copy) / "copy_file_2.txt" + with open(source_file_2, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy2") + + # entity_path to be the dest dir + entity_path = os.path.join(test_dir, "entity_name") + os.mkdir(entity_path) + + parser = get_parser() + # command does not include the --dirs_exist_ok flag + cmd = f"copy {to_copy} {entity_path}" + args = cmd.split() + ns = 
parser.parse_args(args) + + # Execute copy + with pytest.raises(FileExistsError) as ex: + file_operations.copy(ns) + assert f"File exists" in ex.value.args + + # Clean up + os.remove(pathlib.Path(to_copy) / "copy_file.txt") + os.remove(pathlib.Path(to_copy) / "copy_file_2.txt") + os.rmdir(pathlib.Path(test_dir) / "to_copy") + os.rmdir(pathlib.Path(test_dir) / "entity_name") + + +def test_copy_op_bad_source_file(test_dir): + """Test that a FileNotFoundError is raised when there is a bad source file""" + + to_copy = os.path.join(test_dir, "to_copy") + os.mkdir(to_copy) + entity_path = os.path.join(test_dir, "entity_name") + os.mkdir(entity_path) + + bad_path = "/not/a/real/path" + # Execute copy + + parser = get_parser() + cmd = f"copy {bad_path} {entity_path}" + args = cmd.split() + ns = parser.parse_args(args) + + with pytest.raises(FileNotFoundError) as ex: + file_operations.copy(ns) + assert "No such file or directory" in ex.value.args + + # Clean up + os.rmdir(pathlib.Path(test_dir) / "to_copy") + os.rmdir(pathlib.Path(test_dir) / "entity_name") + + +def test_copy_op_bad_dest_path(test_dir): + """Test that a FileNotFoundError is raised when there is a bad destination file.""" + + to_copy = os.path.join(test_dir, "to_copy") + os.mkdir(to_copy) + + source_file = pathlib.Path(to_copy) / "copy_file.txt" + with open(source_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy1") + entity_path = os.path.join(test_dir, "entity_name") + os.mkdir(entity_path) + + bad_path = "/not/a/real/path" + + parser = get_parser() + cmd = f"copy {source_file} {bad_path}" + args = cmd.split() + ns = parser.parse_args(args) + + with pytest.raises(FileNotFoundError) as ex: + file_operations.copy(ns) + assert "No such file or directory" in ex.value.args + + # clean up + os.remove(pathlib.Path(to_copy) / "copy_file.txt") + os.rmdir(pathlib.Path(test_dir) / "to_copy") + os.rmdir(pathlib.Path(test_dir) / "entity_name") + + +def test_copy_not_absolute(test_dir): + + to_copy 
= os.path.join(test_dir, "to_copy") + os.mkdir(to_copy) + + source_file = pathlib.Path(to_copy) / "copy_file.txt" + with open(source_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy1") + entity_path = os.path.join(test_dir, "entity_name") + os.mkdir(entity_path) + + bad_path = ".." + + parser = get_parser() + cmd = f"copy {source_file} {bad_path}" + args = cmd.split() + + with pytest.raises(SystemExit) as e: + parser.parse_args(args) + + assert isinstance(e.value.__context__, argparse.ArgumentError) + assert "invalid _abspath value" in e.value.__context__.message + + # clean up + os.remove(pathlib.Path(to_copy) / "copy_file.txt") + os.rmdir(pathlib.Path(test_dir) / "to_copy") + os.rmdir(pathlib.Path(test_dir) / "entity_name") + + +def test_move_op(test_dir): + """Test the operation to move a file""" + + source_dir = os.path.join(test_dir, "from_here") + os.mkdir(source_dir) + dest_dir = os.path.join(test_dir, "to_here") + os.mkdir(dest_dir) + + dest_file = pathlib.Path(dest_dir) / "to_here.txt" + with open(dest_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write(" ") + + source_file = pathlib.Path(source_dir) / "app_move.txt" + with open(source_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy") + + assert osp.exists(source_file) + with open(source_file, "r", encoding="utf-8") as dummy_file: + assert dummy_file.read() == "dummy" + + parser = get_parser() + cmd = f"move {source_file} {dest_file}" + args = cmd.split() + ns = parser.parse_args(args) + + file_operations.move(ns) + + # Assert that the move was successful + assert not osp.exists(source_file) + assert osp.exists(dest_file) + with open(dest_file, "r", encoding="utf-8") as dummy_file: + assert dummy_file.read() == "dummy" + + # Clean up the directories + os.rmdir(source_dir) + os.remove(dest_file) + os.rmdir(dest_dir) + + +def test_move_not_absolute(test_dir): + """Test that a ValueError is raised when a relative + path is given to the move operation""" 
+ + source_dir = os.path.join(test_dir, "from_here") + os.mkdir(source_dir) + dest_dir = os.path.join(test_dir, "to_here") + os.mkdir(dest_dir) + + dest_file = ".." + + source_file = pathlib.Path(source_dir) / "app_move.txt" + with open(source_file, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy") + + parser = get_parser() + cmd = f"move {source_file} {dest_file}" + args = cmd.split() + + with pytest.raises(SystemExit) as e: + parser.parse_args(args) + + assert isinstance(e.value.__context__, argparse.ArgumentError) + assert "invalid _abspath value" in e.value.__context__.message + + +def test_remove_op_file(test_dir): + """Test the operation to delete a file""" + + # Make a test file with dummy text + to_del = pathlib.Path(test_dir) / "app_del.txt" + with open(to_del, "w+", encoding="utf-8") as dummy_file: + dummy_file.write("dummy") + + assert osp.exists(to_del) + with open(to_del, "r", encoding="utf-8") as dummy_file: + assert dummy_file.read() == "dummy" + + parser = get_parser() + cmd = f"remove {to_del}" + args = cmd.split() + ns = parser.parse_args(args) + + file_operations.remove(ns) + + # Assert file has been deleted + assert not osp.exists(to_del) + + +def test_remove_op_dir(test_dir): + """Test the operation to delete a directory""" + + # Make a test file with dummy text + to_del = pathlib.Path(test_dir) / "dir_del" + os.mkdir(to_del) + + parser = get_parser() + cmd = f"remove {to_del}" + args = cmd.split() + ns = parser.parse_args(args) + + file_operations.remove(ns) + + # Assert directory has been deleted + assert not osp.exists(to_del) + + +def test_remove_op_bad_path(test_dir): + """Test that FileNotFoundError is raised when a bad path is given to the + operation to delete a file""" + + to_del = pathlib.Path(test_dir) / "not_real.txt" + + parser = get_parser() + cmd = f"remove {to_del}" + args = cmd.split() + ns = parser.parse_args(args) + + with pytest.raises(FileNotFoundError) as ex: + file_operations.remove(ns) + assert "No such 
file or directory" in ex.value.args + + +def test_remove_op_not_absolute(): + """Test that ValueError is raised when a relative path + is given to the operation to delete a file""" + + to_del = ".." + + parser = get_parser() + cmd = f"remove {to_del}" + args = cmd.split() + + with pytest.raises(SystemExit) as e: + parser.parse_args(args) + + assert isinstance(e.value.__context__, argparse.ArgumentError) + assert "invalid _abspath value" in e.value.__context__.message + + +@pytest.mark.parametrize( + ["param_dict", "error_type"], + [ + pytest.param( + { + "5": 10, + "FIRST": "SECOND", + "17": 20, + "65": "70", + "placeholder": "group leftupper region", + "1200": "120", + "VALID": "valid", + }, + "None", + id="correct dict", + ), + pytest.param( + ["list", "of", "values"], + "TypeError", + id="incorrect dict", + ), + pytest.param({}, "ValueError", id="empty dict"), + ], +) +def test_configure_file_op(test_dir, fileutils, param_dict, error_type): + """Test configure file operation with correct parameter dictionary, empty dicitonary, and an incorrect type""" + + tag = ";" + + # retrieve files to compare after test + correct_path = fileutils.get_test_conf_path( + osp.join("generator_files", "easy", "correct/") + ) + + tagged_files = sorted(glob(test_dir + "/*")) + correct_files = sorted(glob(correct_path + "/*")) + + # Pickle the dictionary + pickled_dict = pickle.dumps(param_dict) + + # Encode the pickled dictionary with Base64 + encoded_dict = base64.b64encode(pickled_dict).decode("ascii") + + # Run configure op on test files + for tagged_file in tagged_files: + parser = get_parser() + cmd = f"configure {tagged_file} {tagged_file} {tag} {encoded_dict}" + args = cmd.split() + ns = parser.parse_args(args) + + if error_type == "ValueError": + with pytest.raises(ValueError) as ex: + file_operations.configure(ns) + assert "param dictionary is empty" in ex.value.args[0] + elif error_type == "TypeError": + with pytest.raises(TypeError) as ex: + file_operations.configure(ns) 
+ assert "param dict is not a valid dictionary" in ex.value.args[0] + else: + file_operations.configure(ns) + + if error_type == "None": + for written, correct in zip(tagged_files, correct_files): + assert filecmp.cmp(written, correct) + + +def test_configure_file_invalid_tags(fileutils, test_dir): + """Test configure file operation with an invalid tag""" + generator_files = pathlib.Path(fileutils.get_test_conf_path("generator_files")) + tagged_file = generator_files / "easy/marked/invalidtag.txt" + correct_file = generator_files / "easy/correct/invalidtag.txt" + target_file = pathlib.Path(test_dir, "invalidtag.txt") + + tag = ";" + param_dict = {"VALID": "valid"} + + # Pickle the dictionary + pickled_dict = pickle.dumps(param_dict) + + # Encode the pickled dictionary with Base64 + encoded_dict = base64.b64encode(pickled_dict).decode("ascii") + parser = get_parser() + cmd = f"configure {tagged_file} {test_dir} {tag} {encoded_dict}" + args = cmd.split() + ns = parser.parse_args(args) + + file_operations.configure(ns) + assert filecmp.cmp(correct_file, target_file) + + +def test_configure_file_not_absolute(): + """Test that ValueError is raised when tagged files + given to configure file op are not absolute paths + """ + + tagged_file = ".." 
+ tag = ";" + param_dict = {"5": 10} + # Pickle the dictionary + pickled_dict = pickle.dumps(param_dict) + + # Encode the pickled dictionary with Base64 + encoded_dict = base64.b64encode(pickled_dict) + parser = get_parser() + cmd = f"configure {tagged_file} {tagged_file} {tag} {encoded_dict}" + args = cmd.split() + + with pytest.raises(SystemExit) as e: + parser.parse_args(args) + + assert isinstance(e.value.__context__, argparse.ArgumentError) + assert "invalid _abspath value" in e.value.__context__.message + + +@pytest.mark.parametrize( + ["param_dict", "error_type"], + [ + pytest.param( + {"PARAM0": "param_value_1", "PARAM1": "param_value_2"}, + "None", + id="correct dict", + ), + pytest.param( + ["list", "of", "values"], + "TypeError", + id="incorrect dict", + ), + pytest.param({}, "ValueError", id="empty dict"), + ], +) +def test_configure_directory(test_dir, fileutils, param_dict, error_type): + """Test configure directory operation with correct parameter dictionary, empty dicitonary, and an incorrect type""" + tag = ";" + config = get_gen_file(fileutils, "tag_dir_template") + + # Pickle the dictionary + pickled_dict = pickle.dumps(param_dict) + # Encode the pickled dictionary with Base64 + encoded_dict = base64.b64encode(pickled_dict).decode("ascii") + + parser = get_parser() + cmd = f"configure {config} {test_dir} {tag} {encoded_dict}" + args = cmd.split() + ns = parser.parse_args(args) + + if error_type == "ValueError": + with pytest.raises(ValueError) as ex: + file_operations.configure(ns) + assert "param dictionary is empty" in ex.value.args[0] + elif error_type == "TypeError": + with pytest.raises(TypeError) as ex: + file_operations.configure(ns) + assert "param dict is not a valid dictionary" in ex.value.args[0] + else: + file_operations.configure(ns) + assert osp.isdir(osp.join(test_dir, "nested_0")) + assert osp.isdir(osp.join(test_dir, "nested_1")) + + with open(osp.join(test_dir, "nested_0", "tagged_0.sh")) as f: + line = f.readline() + assert 
line.strip() == f'echo "Hello with parameter 0 = param_value_1"' + + with open(osp.join(test_dir, "nested_1", "tagged_1.sh")) as f: + line = f.readline() + assert line.strip() == f'echo "Hello with parameter 1 = param_value_2"' + + +def test_configure_directory_not_absolute(): + """Test that ValueError is raised when tagged directories + given to configure op are not absolute paths + """ + + tagged_directory = ".." + tag = ";" + param_dict = {"5": 10} + # Pickle the dictionary + pickled_dict = pickle.dumps(param_dict) + + # Encode the pickled dictionary with Base64 + encoded_dict = base64.b64encode(pickled_dict) + parser = get_parser() + cmd = f"configure {tagged_directory} {tagged_directory} {tag} {encoded_dict}" + args = cmd.split() + + with pytest.raises(SystemExit) as e: + parser.parse_args(args) + + assert isinstance(e.value.__context__, argparse.ArgumentError) + assert "invalid _abspath value" in e.value.__context__.message + + +def test_parser_move(): + """Test that the parser succeeds when receiving expected args for the move operation""" + parser = get_parser() + + src_path = pathlib.Path("/absolute/file/src/path") + dest_path = pathlib.Path("/absolute/file/dest/path") + + cmd = f"move {src_path} {dest_path}" + args = cmd.split() + ns = parser.parse_args(args) + + assert ns.source == src_path + assert ns.dest == dest_path + + +def test_parser_remove(): + """Test that the parser succeeds when receiving expected args for the remove operation""" + parser = get_parser() + + file_path = pathlib.Path("/absolute/file/path") + cmd = f"remove {file_path}" + + args = cmd.split() + ns = parser.parse_args(args) + + assert ns.to_remove == file_path + + +def test_parser_symlink(): + """Test that the parser succeeds when receiving expected args for the symlink operation""" + parser = get_parser() + + src_path = pathlib.Path("/absolute/file/src/path") + dest_path = pathlib.Path("/absolute/file/dest/path") + cmd = f"symlink {src_path} {dest_path}" + + args = cmd.split() + 
+ ns = parser.parse_args(args) + + assert ns.source == src_path + assert ns.dest == dest_path + + +def test_parser_copy(): + """Test that the parser succeeds when receiving expected args for the copy operation""" + parser = get_parser() + + src_path = pathlib.Path("/absolute/file/src/path") + dest_path = pathlib.Path("/absolute/file/dest/path") + + cmd = f"copy {src_path} {dest_path}" + + args = cmd.split() + ns = parser.parse_args(args) + + assert ns.source == src_path + assert ns.dest == dest_path + + +def test_parser_configure_file_parse(): + """Test that the parser succeeds when receiving expected args for the configure file operation""" + parser = get_parser() + + src_path = pathlib.Path("/absolute/file/src/path") + dest_path = pathlib.Path("/absolute/file/dest/path") + tag_delimiter = ";" + + param_dict = { + "5": 10, + "FIRST": "SECOND", + "17": 20, + "65": "70", + "placeholder": "group leftupper region", + "1200": "120", + } + + pickled_dict = pickle.dumps(param_dict) + encoded_dict = base64.b64encode(pickled_dict) + + cmd = f"configure {src_path} {dest_path} {tag_delimiter} {encoded_dict}" + args = cmd.split() + ns = parser.parse_args(args) + + assert ns.source == src_path + assert ns.dest == dest_path + assert ns.tag_delimiter == tag_delimiter + assert ns.param_dict == str(encoded_dict) + + +def test_parser_configure_directory_parse(): + """Test that the parser succeeds when receiving expected args for the configure directory operation""" + parser = get_parser() + + src_path = pathlib.Path("/absolute/file/src/path") + dest_path = pathlib.Path("/absolute/file/dest/path") + tag_delimiter = ";" + + param_dict = { + "5": 10, + "FIRST": "SECOND", + "17": 20, + "65": "70", + "placeholder": "group leftupper region", + "1200": "120", + } + + pickled_dict = pickle.dumps(param_dict) + encoded_dict = base64.b64encode(pickled_dict) + + cmd = f"configure {src_path} {dest_path} {tag_delimiter} {encoded_dict}" + args = cmd.split() + ns = parser.parse_args(args) + + 
assert ns.source == src_path + assert ns.dest == dest_path + assert ns.tag_delimiter == tag_delimiter + assert ns.param_dict == str(encoded_dict) diff --git a/tests/test_generator.py b/tests/test_generator.py index fd9a5b8363..f949d8f663 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -24,356 +24,413 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import filecmp +import itertools +import pathlib +import unittest.mock +from glob import glob from os import path as osp import pytest -from tabulate import tabulate -from smartsim import Experiment -from smartsim._core.generation import Generator -from smartsim.database import Orchestrator -from smartsim.settings import RunSettings +from smartsim._core.commands import Command, CommandList +from smartsim._core.generation.generator import Generator +from smartsim._core.generation.operations.operations import ( + ConfigureOperation, + CopyOperation, + FileSysOperationSet, + GenerationContext, + SymlinkOperation, +) +from smartsim.entity import SmartSimEntity +from smartsim.launchable import Job -# The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a +ids = set() -rs = RunSettings("python", exe_args="sleep.py") +_ID_GENERATOR = (str(i) for i in itertools.count()) -""" -Test the generation of files and input data for an experiment -TODO - - test lists of inputs for each file type - - test empty directories - - test re-generation +def random_id(): + return next(_ID_GENERATOR) -""" +@pytest.fixture +def generator_instance(test_dir: str) -> Generator: + """Instance of Generator""" + # os.mkdir(root) + yield Generator(root=pathlib.Path(test_dir)) -def get_gen_file(fileutils, filename): - return fileutils.get_test_conf_path(osp.join("generator_files", filename)) +@pytest.fixture +def mock_index(): + """Fixture to create a mock destination path.""" + return 1 -def 
test_ensemble(fileutils, test_dir): - exp = Experiment("gen-test", launcher="local") - gen = Generator(test_dir) - params = {"THERMO": [10, 20, 30], "STEPS": [10, 20, 30]} - ensemble = exp.create_ensemble("test", params=params, run_settings=rs) +class EchoHelloWorldEntity(SmartSimEntity): + """A simple smartsim entity that meets the `ExecutableProtocol` protocol""" - config = get_gen_file(fileutils, "in.atm") - ensemble.attach_generator_files(to_configure=config) - gen.generate_experiment(ensemble) + def __init__(self): + self.name = "entity_name" + self.files = FileSysOperationSet([]) + self.file_parameters = None - assert len(ensemble) == 9 - assert osp.isdir(osp.join(test_dir, "test")) - for i in range(9): - assert osp.isdir(osp.join(test_dir, "test/test_" + str(i))) + def as_executable_sequence(self): + return ("echo", "Hello", "World!") -def test_ensemble_overwrite(fileutils, test_dir): - exp = Experiment("gen-test-overwrite", launcher="local") - - gen = Generator(test_dir, overwrite=True) - - params = {"THERMO": [10, 20, 30], "STEPS": [10, 20, 30]} - ensemble = exp.create_ensemble("test", params=params, run_settings=rs) - - config = get_gen_file(fileutils, "in.atm") - ensemble.attach_generator_files(to_configure=[config]) - gen.generate_experiment(ensemble) - - # re generate without overwrite - config = get_gen_file(fileutils, "in.atm") - ensemble.attach_generator_files(to_configure=[config]) - gen.generate_experiment(ensemble) - - assert len(ensemble) == 9 - assert osp.isdir(osp.join(test_dir, "test")) - for i in range(9): - assert osp.isdir(osp.join(test_dir, "test/test_" + str(i))) - - -def test_ensemble_overwrite_error(fileutils, test_dir): - exp = Experiment("gen-test-overwrite-error", launcher="local") - - gen = Generator(test_dir) - - params = {"THERMO": [10, 20, 30], "STEPS": [10, 20, 30]} - ensemble = exp.create_ensemble("test", params=params, run_settings=rs) - - config = get_gen_file(fileutils, "in.atm") - 
ensemble.attach_generator_files(to_configure=[config]) - gen.generate_experiment(ensemble) - - # re generate without overwrite - config = get_gen_file(fileutils, "in.atm") - ensemble.attach_generator_files(to_configure=[config]) - with pytest.raises(FileExistsError): - gen.generate_experiment(ensemble) - - -def test_full_exp(fileutils, test_dir, wlmutils): - exp = Experiment("gen-test", test_dir, launcher="local") - - model = exp.create_model("model", run_settings=rs) - script = fileutils.get_test_conf_path("sleep.py") - model.attach_generator_files(to_copy=script) - - orc = Orchestrator(wlmutils.get_test_port()) - params = {"THERMO": [10, 20, 30], "STEPS": [10, 20, 30]} - ensemble = exp.create_ensemble("test_ens", params=params, run_settings=rs) - - config = get_gen_file(fileutils, "in.atm") - ensemble.attach_generator_files(to_configure=config) - exp.generate(orc, ensemble, model) - - # test for ensemble - assert osp.isdir(osp.join(test_dir, "test_ens/")) - for i in range(9): - assert osp.isdir(osp.join(test_dir, "test_ens/test_ens_" + str(i))) - - # test for orc dir - assert osp.isdir(osp.join(test_dir, orc.name)) - - # test for model file - assert osp.isdir(osp.join(test_dir, "model")) - assert osp.isfile(osp.join(test_dir, "model/sleep.py")) - - -def test_dir_files(fileutils, test_dir): - """test the generate of models with files that - are directories with subdirectories and files - """ - - exp = Experiment("gen-test", test_dir, launcher="local") - - params = {"THERMO": [10, 20, 30], "STEPS": [10, 20, 30]} - ensemble = exp.create_ensemble("dir_test", params=params, run_settings=rs) - conf_dir = get_gen_file(fileutils, "test_dir") - ensemble.attach_generator_files(to_configure=conf_dir) - - exp.generate(ensemble, tag="@") - - assert osp.isdir(osp.join(test_dir, "dir_test/")) - for i in range(9): - model_path = osp.join(test_dir, "dir_test/dir_test_" + str(i)) - assert osp.isdir(model_path) - assert osp.isdir(osp.join(model_path, "test_dir_1")) - assert 
osp.isfile(osp.join(model_path, "test.in")) - - -def test_print_files(fileutils, test_dir, capsys): - """Test the stdout print of files attached to an ensemble""" - - exp = Experiment("print-attached-files-test", test_dir, launcher="local") +@pytest.fixture +def mock_job() -> unittest.mock.MagicMock: + """Fixture to create a mock Job.""" + job = unittest.mock.MagicMock( + entity=EchoHelloWorldEntity(), + get_launch_steps=unittest.mock.MagicMock( + side_effect=lambda: NotImplementedError() + ), + spec=Job, + ) + job.name = "test_job" + yield job - ensemble = exp.create_ensemble("dir_test", replicas=1, run_settings=rs) - ensemble.entities = [] - ensemble.print_attached_files() - captured = capsys.readouterr() - assert captured.out == "The ensemble is empty, no files to show.\n" +# UNIT TESTS - params = {"THERMO": [10, 20], "STEPS": [20, 30]} - ensemble = exp.create_ensemble("dir_test", params=params, run_settings=rs) - gen_dir = get_gen_file(fileutils, "test_dir") - symlink_dir = get_gen_file(fileutils, "to_symlink_dir") - copy_dir = get_gen_file(fileutils, "to_copy_dir") - ensemble.print_attached_files() - captured = capsys.readouterr() - expected_out = ( - tabulate( - [ - [model.name, "No file attached to this model."] - for model in ensemble.models - ], - headers=["Model name", "Files"], - tablefmt="grid", - ) - + "\n" - ) +def test_init_generator(generator_instance: Generator, test_dir: str): + """Test Generator init""" + assert generator_instance.root == pathlib.Path(test_dir) - assert captured.out == expected_out - - ensemble.attach_generator_files() - ensemble.print_attached_files() - captured = capsys.readouterr() - expected_out = ( - tabulate( - [ - [model.name, "No file attached to this entity."] - for model in ensemble.models - ], - headers=["Model name", "Files"], - tablefmt="grid", - ) - + "\n" - ) - assert captured.out == expected_out - ensemble.attach_generator_files( - to_configure=[gen_dir, copy_dir], to_copy=copy_dir, to_symlink=symlink_dir +def 
test_build_job_base_path( + generator_instance: Generator, mock_job: unittest.mock.MagicMock, mock_index +): + """Test Generator._build_job_base_path returns correct path""" + root_path = generator_instance._build_job_base_path(mock_job, mock_index) + expected_path = ( + generator_instance.root + / f"{mock_job.__class__.__name__.lower()}s" + / f"{mock_job.name}-{mock_index}" ) - - expected_out = tabulate( - [ - ["Copy", copy_dir], - ["Symlink", symlink_dir], - ["Configure", f"{gen_dir}\n{copy_dir}"], - ], - headers=["Strategy", "Files"], - tablefmt="grid", + assert root_path == expected_path + + +def test_build_job_run_path( + test_dir: str, + mock_job: unittest.mock.MagicMock, + generator_instance: Generator, + monkeypatch: pytest.MonkeyPatch, + mock_index, +): + """Test Generator._build_job_run_path returns correct path""" + monkeypatch.setattr( + Generator, + "_build_job_base_path", + lambda self, job, job_index: pathlib.Path(test_dir), ) - - assert all(str(model.files) == expected_out for model in ensemble.models) - - expected_out_multi = ( - tabulate( - [[model.name, expected_out] for model in ensemble.models], - headers=["Model name", "Files"], - tablefmt="grid", - ) - + "\n" + run_path = generator_instance._build_job_run_path(mock_job, mock_index) + expected_run_path = pathlib.Path(test_dir) / generator_instance.run_directory + assert run_path == expected_run_path + + +def test_build_job_log_path( + test_dir: str, + mock_job: unittest.mock.MagicMock, + generator_instance: Generator, + monkeypatch: pytest.MonkeyPatch, + mock_index, +): + """Test Generator._build_job_log_path returns correct path""" + monkeypatch.setattr( + Generator, + "_build_job_base_path", + lambda self, job, job_index: pathlib.Path(test_dir), ) - ensemble.print_attached_files() + log_path = generator_instance._build_job_log_path(mock_job, mock_index) + expected_log_path = pathlib.Path(test_dir) / generator_instance.log_directory + assert log_path == expected_log_path - captured = 
capsys.readouterr() - assert captured.out == expected_out_multi +def test_build_log_file_path(test_dir: str, generator_instance: Generator): + """Test Generator._build_log_file_path returns correct path""" + expected_path = pathlib.Path(test_dir) / "smartsim_params.txt" + assert generator_instance._build_log_file_path(test_dir) == expected_path -def test_multiple_tags(fileutils, test_dir): - """Test substitution of multiple tagged parameters on same line""" - exp = Experiment("test-multiple-tags", test_dir) - model_params = {"port": 6379, "password": "unbreakable_password"} - model_settings = RunSettings("bash", "multi_tags_template.sh") - parameterized_model = exp.create_model( - "multi-tags", run_settings=model_settings, params=model_params +def test_build_out_file_path( + test_dir: str, generator_instance: Generator, mock_job: unittest.mock.MagicMock +): + """Test Generator._build_out_file_path returns out path""" + out_file_path = generator_instance._build_out_file_path( + pathlib.Path(test_dir), mock_job.name ) - config = get_gen_file(fileutils, "multi_tags_template.sh") - parameterized_model.attach_generator_files(to_configure=[config]) - exp.generate(parameterized_model, overwrite=True) - exp.start(parameterized_model, block=True) + assert out_file_path == pathlib.Path(test_dir) / f"{mock_job.name}.out" - with open(osp.join(parameterized_model.path, "multi-tags.out")) as f: - log_content = f.read() - assert "My two parameters are 6379 and unbreakable_password, OK?" 
in log_content +def test_build_err_file_path( + test_dir: str, generator_instance: Generator, mock_job: unittest.mock.MagicMock +): + """Test Generator._build_err_file_path returns err path""" + err_file_path = generator_instance._build_err_file_path( + pathlib.Path(test_dir), mock_job.name + ) + assert err_file_path == pathlib.Path(test_dir) / f"{mock_job.name}.err" + + +def test_generate_job( + mock_job: unittest.mock.MagicMock, generator_instance: Generator, mock_index: int +): + """Test Generator.generate_job returns correct paths""" + job_paths = generator_instance.generate_job(mock_job, mock_index) + assert job_paths.run_path.name == Generator.run_directory + assert job_paths.out_path.name == f"{mock_job.entity.name}.out" + assert job_paths.err_path.name == f"{mock_job.entity.name}.err" + + +def test_execute_commands(generator_instance: Generator): + """Test Generator._execute_commands subprocess.run""" + with ( + unittest.mock.patch( + "smartsim._core.generation.generator.subprocess.run" + ) as run_process, + ): + cmd_list = CommandList(Command(["test", "command"])) + generator_instance._execute_commands(cmd_list) + run_process.assert_called_once() + + +def test_mkdir_file(generator_instance: Generator, test_dir: str): + """Test Generator._mkdir_file returns correct type and value""" + cmd = generator_instance._mkdir_file(pathlib.Path(test_dir)) + assert isinstance(cmd, Command) + assert cmd.command == ["mkdir", "-p", test_dir] + + +@pytest.mark.parametrize( + "dest", + ( + pytest.param(None, id="dest as None"), + pytest.param( + pathlib.Path("absolute/path"), + id="dest as valid path", + ), + ), +) +def test_copy_files_valid_dest( + dest, source, generator_instance: Generator, test_dir: str +): + to_copy = [CopyOperation(src=file, dest=dest) for file in source] + gen = GenerationContext(pathlib.Path(test_dir)) + cmd_list = generator_instance._copy_files(files=to_copy, context=gen) + assert isinstance(cmd_list, CommandList) + # Extract file paths from 
commands + cmd_src_paths = set() + for cmd in cmd_list.commands: + src_index = cmd.command.index("copy") + 1 + cmd_src_paths.add(cmd.command[src_index]) + # Assert all file paths are in the command list + file_paths = {str(file) for file in source} + assert file_paths == cmd_src_paths, "Not all file paths are in the command list" + + +@pytest.mark.parametrize( + "dest", + ( + pytest.param(None, id="dest as None"), + pytest.param( + pathlib.Path("absolute/path"), + id="dest as valid path", + ), + ), +) +def test_symlink_files_valid_dest( + dest, source, generator_instance: Generator, test_dir: str +): + to_symlink = [SymlinkOperation(src=file, dest=dest) for file in source] + gen = GenerationContext(pathlib.Path(test_dir)) + cmd_list = generator_instance._symlink_files(files=to_symlink, context=gen) + assert isinstance(cmd_list, CommandList) + # Extract file paths from commands + cmd_src_paths = set() + for cmd in cmd_list.commands: + print(cmd) + src_index = cmd.command.index("symlink") + 1 + cmd_src_paths.add(cmd.command[src_index]) + # Assert all file paths are in the command list + file_paths = {str(file) for file in source} + assert file_paths == cmd_src_paths, "Not all file paths are in the command list" + + +@pytest.mark.parametrize( + "dest", + ( + pytest.param(None, id="dest as None"), + pytest.param( + pathlib.Path("absolute/path"), + id="dest as valid path", + ), + ), +) +def test_configure_files_valid_dest( + dest, source, generator_instance: Generator, test_dir: str +): + file_param = { + "5": 10, + "FIRST": "SECOND", + "17": 20, + "65": "70", + "placeholder": "group leftupper region", + "1200": "120", + "VALID": "valid", + } + to_configure = [ + ConfigureOperation(src=file, dest=dest, file_parameters=file_param) + for file in source + ] + gen = GenerationContext(pathlib.Path(test_dir)) + cmd_list = generator_instance._configure_files(files=to_configure, context=gen) + assert isinstance(cmd_list, CommandList) + # Extract file paths from commands + 
cmd_src_paths = set() + for cmd in cmd_list.commands: + src_index = cmd.command.index("configure") + 1 + cmd_src_paths.add(cmd.command[src_index]) + # Assert all file paths are in the command list + file_paths = {str(file) for file in source} + assert file_paths == cmd_src_paths, "Not all file paths are in the command list" + + +@pytest.fixture +def run_directory(test_dir, generator_instance): + return pathlib.Path(test_dir) / generator_instance.run_directory + + +@pytest.fixture +def log_directory(test_dir, generator_instance): + return pathlib.Path(test_dir) / generator_instance.log_directory + + +def test_build_commands( + generator_instance: Generator, + run_directory: pathlib.Path, + log_directory: pathlib.Path, +): + """Test Generator._build_commands calls internal helper functions""" + with ( + unittest.mock.patch( + "smartsim._core.generation.Generator._append_mkdir_commands" + ) as mock_append_mkdir_commands, + unittest.mock.patch( + "smartsim._core.generation.Generator._append_file_operations" + ) as mock_append_file_operations, + ): + generator_instance._build_commands( + EchoHelloWorldEntity(), + run_directory, + log_directory, + ) + mock_append_mkdir_commands.assert_called_once() + mock_append_file_operations.assert_called_once() + + +def test_append_mkdir_commands( + generator_instance: Generator, + run_directory: pathlib.Path, + log_directory: pathlib.Path, +): + """Test Generator._append_mkdir_commands calls Generator._mkdir_file twice""" + with ( + unittest.mock.patch( + "smartsim._core.generation.Generator._mkdir_file" + ) as mock_mkdir_file, + ): + generator_instance._append_mkdir_commands( + CommandList(), + run_directory, + log_directory, + ) + assert mock_mkdir_file.call_count == 2 + + +def test_append_file_operations( + context: GenerationContext, generator_instance: Generator +): + """Test Generator._append_file_operations calls all file operations""" + with ( + unittest.mock.patch( + "smartsim._core.generation.Generator._copy_files" + ) as 
mock_copy_files, + unittest.mock.patch( + "smartsim._core.generation.Generator._symlink_files" + ) as mock_symlink_files, + unittest.mock.patch( + "smartsim._core.generation.Generator._configure_files" + ) as mock_configure_files, + ): + generator_instance._append_file_operations( + CommandList(), + EchoHelloWorldEntity(), + context, + ) + mock_copy_files.assert_called_once() + mock_symlink_files.assert_called_once() + mock_configure_files.assert_called_once() -def test_generation_log(fileutils, test_dir): - """Test that an error is issued when a tag is unused and make_fatal is True""" - exp = Experiment("gen-log-test", test_dir, launcher="local") +@pytest.fixture +def paths_to_copy(fileutils): + paths = fileutils.get_test_conf_path(osp.join("generator_files", "to_copy_dir")) + yield [pathlib.Path(path) for path in sorted(glob(paths + "/*"))] - params = {"THERMO": [10, 20], "STEPS": [10, 20]} - ensemble = exp.create_ensemble("dir_test", params=params, run_settings=rs) - conf_file = get_gen_file(fileutils, "in.atm") - ensemble.attach_generator_files(to_configure=conf_file) - def not_header(line): - """you can add other general checks in here""" - return not line.startswith("Generation start date and time:") +@pytest.fixture +def paths_to_symlink(fileutils): + paths = fileutils.get_test_conf_path(osp.join("generator_files", "to_symlink_dir")) + yield [pathlib.Path(path) for path in sorted(glob(paths + "/*"))] - exp.generate(ensemble, verbose=True) - log_file = osp.join(test_dir, "smartsim_params.txt") - ground_truth = get_gen_file( - fileutils, osp.join("log_params", "smartsim_params.txt") +@pytest.fixture +def paths_to_configure(fileutils): + paths = fileutils.get_test_conf_path( + osp.join("generator_files", "easy", "correct/") ) - - with open(log_file) as f1, open(ground_truth) as f2: - assert not not_header(f1.readline()) - f1 = filter(not_header, f1) - f2 = filter(not_header, f2) - assert all(x == y for x, y in zip(f1, f2)) - - for entity in ensemble: - assert 
filecmp.cmp( - osp.join(entity.path, "smartsim_params.txt"), - get_gen_file( - fileutils, - osp.join("log_params", "dir_test", entity.name, "smartsim_params.txt"), - ), - ) - - -def test_config_dir(fileutils, test_dir): - """Test the generation and configuration of models with - tagged files that are directories with subdirectories and files - """ - exp = Experiment("config-dir", launcher="local") - - gen = Generator(test_dir) - - params = {"PARAM0": [0, 1], "PARAM1": [2, 3]} - ensemble = exp.create_ensemble("test", params=params, run_settings=rs) - - config = get_gen_file(fileutils, "tag_dir_template") - ensemble.attach_generator_files(to_configure=config) - gen.generate_experiment(ensemble) - - assert osp.isdir(osp.join(test_dir, "test")) - - def _check_generated(test_num, param_0, param_1): - conf_test_dir = osp.join(test_dir, "test", f"test_{test_num}") - assert osp.isdir(conf_test_dir) - assert osp.isdir(osp.join(conf_test_dir, "nested_0")) - assert osp.isdir(osp.join(conf_test_dir, "nested_1")) - - with open(osp.join(conf_test_dir, "nested_0", "tagged_0.sh")) as f: - line = f.readline() - assert line.strip() == f'echo "Hello with parameter 0 = {param_0}"' - - with open(osp.join(conf_test_dir, "nested_1", "tagged_1.sh")) as f: - line = f.readline() - assert line.strip() == f'echo "Hello with parameter 1 = {param_1}"' - - _check_generated(0, 0, 2) - _check_generated(1, 0, 3) - _check_generated(2, 1, 2) - _check_generated(3, 1, 3) - - -def test_no_gen_if_file_not_exist(fileutils): - """Test that generation of file with non-existant config - raises a FileNotFound exception - """ - exp = Experiment("file-not-found", launcher="local") - ensemble = exp.create_ensemble("test", params={"P": [0, 1]}, run_settings=rs) - config = get_gen_file(fileutils, "path_not_exist") - with pytest.raises(FileNotFoundError): - ensemble.attach_generator_files(to_configure=config) - - -def test_no_gen_if_symlink_to_dir(fileutils): - """Test that when configuring a directory containing a 
symlink - a ValueError exception is raised to prevent circular file - structure configuration - """ - exp = Experiment("circular-config-files", launcher="local") - ensemble = exp.create_ensemble("test", params={"P": [0, 1]}, run_settings=rs) - config = get_gen_file(fileutils, "circular_config") - with pytest.raises(ValueError): - ensemble.attach_generator_files(to_configure=config) - - -def test_no_file_overwrite(): - exp = Experiment("test_no_file_overwrite", launcher="local") - ensemble = exp.create_ensemble("test", params={"P": [0, 1]}, run_settings=rs) - with pytest.raises(ValueError): - ensemble.attach_generator_files( - to_configure=["/normal/file.txt", "/path/to/smartsim_params.txt"] - ) - with pytest.raises(ValueError): - ensemble.attach_generator_files( - to_symlink=["/normal/file.txt", "/path/to/smartsim_params.txt"] - ) - with pytest.raises(ValueError): - ensemble.attach_generator_files( - to_copy=["/normal/file.txt", "/path/to/smartsim_params.txt"] - ) + yield [pathlib.Path(path) for path in sorted(glob(paths + "/*"))] + + +@pytest.fixture +def context(test_dir: str): + yield GenerationContext(pathlib.Path(test_dir)) + + +@pytest.fixture +def operations_list(paths_to_copy, paths_to_symlink, paths_to_configure): + op_list = [] + for file in paths_to_copy: + op_list.append(CopyOperation(src=file)) + for file in paths_to_symlink: + op_list.append(SymlinkOperation(src=file)) + for file in paths_to_configure: + op_list.append(SymlinkOperation(src=file)) + return op_list + + +@pytest.fixture +def formatted_command_list(operations_list: list, context: GenerationContext): + new_list = CommandList() + for file in operations_list: + new_list.append(file.format(context)) + return new_list + + +def test_execute_commands( + operations_list: list, formatted_command_list, generator_instance: Generator +): + """Test Generator._execute_commands calls with appropriate type and num times""" + with ( + unittest.mock.patch( + 
"smartsim._core.generation.generator.subprocess.run" + ) as mock_run, + ): + generator_instance._execute_commands(formatted_command_list) + assert mock_run.call_count == len(formatted_command_list) diff --git a/tests/test_init.py b/tests/test_init.py index dfb58bd557..3014f81935 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -24,29 +24,20 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import tempfile import pytest -# The tests in this file belong to the group_a group -pytestmark = pytest.mark.group_a +pytestmark = [pytest.mark.group_a, pytest.mark.group_b, pytest.mark.slow_tests] __author__ = "Sam Partee" -try: - from smartsim import * +def test_import_ss(monkeypatch): + with tempfile.TemporaryDirectory() as empty_dir: + # Move to an empty directory so `smartsim` dir is not in cwd + monkeypatch.chdir(empty_dir) - _top_import_error = None -except Exception as e: - _top_import_error = e - - -def test_import_ss(): - # Test either above import has failed for some reason - # "import *" is discouraged outside of the module level, hence we - # rely on setting up the variable above - assert _top_import_error is None - - -test_import_ss() + # Make sure SmartSim is importable + import smartsim diff --git a/tests/test_intervals.py b/tests/test_intervals.py new file mode 100644 index 0000000000..1b865867f2 --- /dev/null +++ b/tests/test_intervals.py @@ -0,0 +1,87 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import contextlib +import operator +import time + +import pytest + +from smartsim._core.control.interval import SynchronousTimeInterval + +pytestmark = pytest.mark.group_a + + +@pytest.mark.parametrize( + "timeout", [pytest.param(i, id=f"{i} second(s)") for i in range(10)] +) +def test_sync_timeout_finite(timeout, monkeypatch): + """Test that the sync timeout intervals are correctly calculated""" + monkeypatch.setattr(time, "perf_counter", lambda *_, **__: 0) + t = SynchronousTimeInterval(timeout) + assert t.delta == timeout + assert t.elapsed == 0 + assert t.remaining == timeout + assert (operator.not_ if timeout > 0 else bool)(t.expired) + assert not t.infinite + future = timeout + 2 + monkeypatch.setattr(time, "perf_counter", lambda *_, **__: future) + assert t.elapsed == future + assert t.remaining == 0 + assert t.expired + assert not t.infinite + new_t = t.new_interval() + assert new_t.delta == timeout + assert new_t.elapsed == 0 + assert new_t.remaining == timeout + assert 
(operator.not_ if timeout > 0 else bool)(new_t.expired) + assert not new_t.infinite + + +def test_sync_timeout_can_block_thread(): + """Test that the sync timeout can block the calling thread""" + timeout = 1 + now = time.perf_counter() + SynchronousTimeInterval(timeout).block() + later = time.perf_counter() + assert abs(later - now - timeout) <= 0.25 + + +def test_sync_timeout_infinte(): + """Passing in `None` to a sync timeout creates a timeout with an infinite + delta time + """ + t = SynchronousTimeInterval(None) + assert t.remaining == float("inf") + assert t.infinite + with pytest.raises(RuntimeError, match="block thread forever"): + t.block() + + +def test_sync_timeout_raises_on_invalid_value(monkeypatch): + """Cannot make a sync time interval with a negative time delta""" + with pytest.raises(ValueError): + SynchronousTimeInterval(-1) diff --git a/tests/test_launch_history.py b/tests/test_launch_history.py new file mode 100644 index 0000000000..3b4cd5bcc5 --- /dev/null +++ b/tests/test_launch_history.py @@ -0,0 +1,205 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import contextlib +import itertools + +import pytest + +from smartsim._core.control.launch_history import LaunchHistory +from smartsim._core.utils.launcher import LauncherProtocol, create_job_id + +pytestmark = pytest.mark.group_a + + +class MockLancher(LauncherProtocol): + __hash__ = object.__hash__ + + @classmethod + def create(cls, _): + raise NotImplementedError + + def start(self, _): + raise NotImplementedError + + def get_status(self, *_): + raise NotImplementedError + + def stop_jobs(self, *_): + raise NotImplementedError + + +LAUNCHER_INSTANCE_A = MockLancher() +LAUNCHER_INSTANCE_B = MockLancher() + + +@pytest.mark.parametrize( + "initial_state, to_save", + ( + pytest.param( + {}, + [(MockLancher(), create_job_id())], + id="Empty state, one save", + ), + pytest.param( + {}, + [(MockLancher(), create_job_id()), (MockLancher(), create_job_id())], + id="Empty state, many save", + ), + pytest.param( + {}, + [ + (LAUNCHER_INSTANCE_A, create_job_id()), + (LAUNCHER_INSTANCE_A, create_job_id()), + ], + id="Empty state, repeat launcher instance", + ), + pytest.param( + {create_job_id(): MockLancher()}, + [(MockLancher(), create_job_id())], + id="Preexisting state, one save", + ), + pytest.param( + {create_job_id(): MockLancher()}, + [(MockLancher(), create_job_id()), (MockLancher(), create_job_id())], + id="Preexisting state, many save", + ), + pytest.param( + {create_job_id(): LAUNCHER_INSTANCE_A}, + [(LAUNCHER_INSTANCE_A, 
create_job_id())], + id="Preexisting state, repeat launcher instance", + ), + ), +) +def test_save_launch(initial_state, to_save): + history = LaunchHistory(initial_state) + launcher = MockLancher() + + assert history._id_to_issuer == initial_state + for launcher, id_ in to_save: + history.save_launch(launcher, id_) + assert history._id_to_issuer == initial_state | {id_: l for l, id_ in to_save} + + +def test_save_launch_raises_if_id_already_in_use(): + launcher = MockLancher() + other_launcher = MockLancher() + id_ = create_job_id() + history = LaunchHistory() + history.save_launch(launcher, id_) + with pytest.raises(ValueError): + history.save_launch(other_launcher, id_) + + +@pytest.mark.parametrize( + "ids_to_issuer, expected_num_launchers", + ( + pytest.param( + {create_job_id(): MockLancher()}, + 1, + id="One launch, one instance", + ), + pytest.param( + {create_job_id(): LAUNCHER_INSTANCE_A for _ in range(5)}, + 1, + id="Many launch, one instance", + ), + pytest.param( + {create_job_id(): MockLancher() for _ in range(5)}, + 5, + id="Many launch, many instance", + ), + ), +) +def test_iter_past_launchers(ids_to_issuer, expected_num_launchers): + history = LaunchHistory(ids_to_issuer) + assert len(list(history.iter_past_launchers())) == expected_num_launchers + known_launchers = set(history._id_to_issuer.values()) + assert all( + launcher in known_launchers for launcher in history.iter_past_launchers() + ) + + +ID_A = create_job_id() +ID_B = create_job_id() +ID_C = create_job_id() + + +@pytest.mark.parametrize( + "init_state, ids, expected_group_by", + ( + pytest.param( + {ID_A: LAUNCHER_INSTANCE_A, ID_B: LAUNCHER_INSTANCE_A}, + None, + {LAUNCHER_INSTANCE_A: {ID_A, ID_B}}, + id="All known ids, single launcher", + ), + pytest.param( + {ID_A: LAUNCHER_INSTANCE_A, ID_B: LAUNCHER_INSTANCE_A}, + {ID_A}, + {LAUNCHER_INSTANCE_A: {ID_A}}, + id="Subset known ids, single launcher", + ), + pytest.param( + {ID_A: LAUNCHER_INSTANCE_A, ID_B: LAUNCHER_INSTANCE_B}, + None, + 
{LAUNCHER_INSTANCE_A: {ID_A}, LAUNCHER_INSTANCE_B: {ID_B}}, + id="All known ids, many launchers", + ), + pytest.param( + {ID_A: LAUNCHER_INSTANCE_A, ID_B: LAUNCHER_INSTANCE_B}, + {ID_A}, + {LAUNCHER_INSTANCE_A: {ID_A}}, + id="Subset known ids, many launchers, same issuer", + ), + pytest.param( + { + ID_A: LAUNCHER_INSTANCE_A, + ID_B: LAUNCHER_INSTANCE_B, + ID_C: LAUNCHER_INSTANCE_A, + }, + {ID_A, ID_B}, + {LAUNCHER_INSTANCE_A: {ID_A}, LAUNCHER_INSTANCE_B: {ID_B}}, + id="Subset known ids, many launchers, many issuer", + ), + ), +) +def test_group_by_launcher(init_state, ids, expected_group_by): + histroy = LaunchHistory(init_state) + assert histroy.group_by_launcher(ids) == expected_group_by + + +@pytest.mark.parametrize( + "ctx, unknown_ok", + ( + pytest.param(pytest.raises(ValueError), False, id="unknown_ok=False"), + pytest.param(contextlib.nullcontext(), True, id="unknown_ok=True"), + ), +) +def test_group_by_launcher_encounters_unknown_launch_id(ctx, unknown_ok): + histroy = LaunchHistory() + with ctx: + assert histroy.group_by_launcher([create_job_id()], unknown_ok=unknown_ok) == {} diff --git a/tests/test_operations.py b/tests/test_operations.py new file mode 100644 index 0000000000..abfc141d89 --- /dev/null +++ b/tests/test_operations.py @@ -0,0 +1,364 @@ +import base64 +import os +import pathlib +import pickle + +import pytest + +from smartsim._core.commands import Command +from smartsim._core.generation.operations.operations import ( + ConfigureOperation, + CopyOperation, + FileSysOperationSet, + GenerationContext, + SymlinkOperation, + _check_run_path, + _create_dest_path, + configure_cmd, + copy_cmd, + default_tag, + symlink_cmd, +) +from smartsim._core.generation.operations.utils.helpers import check_src_and_dest_path + +pytestmark = pytest.mark.group_a + + +@pytest.fixture +def generation_context(test_dir: str): + """Fixture to create a GenerationContext object.""" + return GenerationContext(pathlib.Path(test_dir)) + + +@pytest.fixture +def 
file_system_operation_set( + copy_operation: CopyOperation, + symlink_operation: SymlinkOperation, + configure_operation: ConfigureOperation, +): + """Fixture to create a FileSysOperationSet object.""" + return FileSysOperationSet([copy_operation, symlink_operation, configure_operation]) + + +# TODO is this test even necessary +@pytest.mark.parametrize( + "job_run_path, dest", + ( + pytest.param( + pathlib.Path("/absolute/src"), + pathlib.Path("relative/dest"), + id="Valid paths", + ), + pytest.param( + pathlib.Path("/absolute/src"), + pathlib.Path(""), + id="Empty destination path", + ), + ), +) +def test_check_src_and_dest_path_valid(job_run_path, dest): + """Test valid path inputs for helpers.check_src_and_dest_path""" + check_src_and_dest_path(job_run_path, dest) + + +@pytest.mark.parametrize( + "job_run_path, dest, error", + ( + pytest.param( + pathlib.Path("relative/src"), + pathlib.Path("relative/dest"), + ValueError, + id="Relative src Path", + ), + pytest.param( + pathlib.Path("/absolute/src"), + pathlib.Path("/absolute/src"), + ValueError, + id="Absolute dest Path", + ), + pytest.param( + 123, + pathlib.Path("relative/dest"), + TypeError, + id="non Path src", + ), + pytest.param( + pathlib.Path("/absolute/src"), + 123, + TypeError, + id="non Path dest", + ), + ), +) +def test_check_src_and_dest_path_invalid(job_run_path, dest, error): + """Test invalid path inputs for helpers.check_src_and_dest_path""" + with pytest.raises(error): + check_src_and_dest_path(job_run_path, dest) + + +@pytest.mark.parametrize( + "job_run_path, dest, expected", + ( + pytest.param( + pathlib.Path("/absolute/root"), + pathlib.Path("relative/dest"), + "/absolute/root/relative/dest", + id="Valid paths", + ), + pytest.param( + pathlib.Path("/absolute/root"), + pathlib.Path(""), + "/absolute/root", + id="Empty destination path", + ), + ), +) +def test_create_dest_path_valid(job_run_path, dest, expected): + """Test valid path inputs for operations._create_dest_path""" + assert 
_create_dest_path(job_run_path, dest) == expected + + +@pytest.mark.parametrize( + "job_run_path, error", + ( + pytest.param( + pathlib.Path("relative/path"), ValueError, id="Run path is not absolute" + ), + pytest.param(1234, TypeError, id="Run path is not pathlib.path"), + ), +) +def test_check_run_path_invalid(job_run_path, error): + """Test invalid path inputs for operations._check_run_path""" + with pytest.raises(error): + _check_run_path(job_run_path) + + +def test_valid_init_generation_context(test_dir: str): + """Validate GenerationContext init""" + generation_context = GenerationContext(pathlib.Path(test_dir)) + assert isinstance(generation_context, GenerationContext) + assert generation_context.job_run_path == pathlib.Path(test_dir) + + +def test_invalid_init_generation_context(): + """Validate GenerationContext init""" + with pytest.raises(TypeError): + GenerationContext(1234) + with pytest.raises(TypeError): + GenerationContext("") + + +def test_init_copy_operation(mock_src: pathlib.Path, mock_dest: pathlib.Path): + """Validate CopyOperation init""" + copy_operation = CopyOperation(mock_src, mock_dest) + assert isinstance(copy_operation, CopyOperation) + assert copy_operation.src == mock_src + assert copy_operation.dest == mock_dest + + +def test_copy_operation_format( + copy_operation: CopyOperation, + mock_dest: str, + mock_src: str, + generation_context: GenerationContext, + test_dir: str, +): + """Validate CopyOperation.format""" + exec = copy_operation.format(generation_context) + assert isinstance(exec, Command) + assert str(mock_src) in exec.command + assert copy_cmd in exec.command + assert _create_dest_path(test_dir, mock_dest) in exec.command + + +def test_init_symlink_operation(mock_src: str, mock_dest: str): + """Validate SymlinkOperation init""" + symlink_operation = SymlinkOperation(mock_src, mock_dest) + assert isinstance(symlink_operation, SymlinkOperation) + assert symlink_operation.src == mock_src + assert symlink_operation.dest == 
mock_dest + + +def test_symlink_operation_format( + symlink_operation: SymlinkOperation, + mock_src: str, + mock_dest: str, + generation_context: GenerationContext, +): + """Validate SymlinkOperation.format""" + exec = symlink_operation.format(generation_context) + assert isinstance(exec, Command) + assert str(mock_src) in exec.command + assert symlink_cmd in exec.command + + normalized_path = os.path.normpath(mock_src) + parent_dir = os.path.dirname(normalized_path) + final_dest = _create_dest_path(generation_context.job_run_path, mock_dest) + new_dest = os.path.join(final_dest, parent_dir) + assert new_dest in exec.command + + +def test_init_configure_operation(mock_src: str, mock_dest: str): + """Validate ConfigureOperation init""" + configure_operation = ConfigureOperation( + src=mock_src, dest=mock_dest, file_parameters={"FOO": "BAR"} + ) + assert isinstance(configure_operation, ConfigureOperation) + assert configure_operation.src == mock_src + assert configure_operation.dest == mock_dest + assert configure_operation.tag == default_tag + decoded_dict = base64.b64decode(configure_operation.file_parameters.encode("ascii")) + unpickled_dict = pickle.loads(decoded_dict) + assert unpickled_dict == {"FOO": "BAR"} + + +def test_configure_operation_format( + configure_operation: ConfigureOperation, + test_dir: str, + mock_dest: str, + mock_src: str, + generation_context: GenerationContext, +): + """Validate ConfigureOperation.format""" + exec = configure_operation.format(generation_context) + assert isinstance(exec, Command) + assert str(mock_src) in exec.command + assert configure_cmd in exec.command + assert _create_dest_path(test_dir, mock_dest) in exec.command + + +def test_init_file_sys_operation_set( + copy_operation: CopyOperation, + symlink_operation: SymlinkOperation, + configure_operation: ConfigureOperation, +): + """Test initialize FileSystemOperationSet""" + file_system_operation_set = FileSysOperationSet( + [copy_operation, symlink_operation, 
configure_operation] + ) + assert isinstance(file_system_operation_set.operations, list) + assert len(file_system_operation_set.operations) == 3 + + +def test_add_copy_operation(file_system_operation_set: FileSysOperationSet): + """Test FileSystemOperationSet.add_copy""" + orig_num_ops = len(file_system_operation_set.copy_operations) + file_system_operation_set.add_copy(src=pathlib.Path("/src")) + assert len(file_system_operation_set.copy_operations) == orig_num_ops + 1 + + +def test_add_symlink_operation(file_system_operation_set: FileSysOperationSet): + """Test FileSystemOperationSet.add_symlink""" + orig_num_ops = len(file_system_operation_set.symlink_operations) + file_system_operation_set.add_symlink(src=pathlib.Path("/src")) + assert len(file_system_operation_set.symlink_operations) == orig_num_ops + 1 + + +def test_add_configure_operation( + file_system_operation_set: FileSysOperationSet, +): + """Test FileSystemOperationSet.add_configuration""" + orig_num_ops = len(file_system_operation_set.configure_operations) + file_system_operation_set.add_configuration( + src=pathlib.Path("/src"), file_parameters={"FOO": "BAR"} + ) + assert len(file_system_operation_set.configure_operations) == orig_num_ops + 1 + + +@pytest.mark.parametrize( + "dest,error", + ( + pytest.param(123, TypeError, id="dest as integer"), + pytest.param("", TypeError, id="dest as empty str"), + pytest.param( + pathlib.Path("/absolute/path"), ValueError, id="dest as absolute str" + ), + ), +) +def test_copy_files_invalid_dest(dest, error, source): + """Test invalid copy destination""" + with pytest.raises(error): + _ = [CopyOperation(src=file, dest=dest) for file in source] + + +@pytest.mark.parametrize( + "src,error", + ( + pytest.param(123, TypeError, id="src as integer"), + pytest.param("", TypeError, id="src as empty str"), + pytest.param( + pathlib.Path("relative/path"), ValueError, id="src as relative str" + ), + ), +) +def test_copy_files_invalid_src(src, error): + """Test invalid copy 
source""" + with pytest.raises(error): + _ = CopyOperation(src=src) + + +@pytest.mark.parametrize( + "dest,error", + ( + pytest.param(123, TypeError, id="dest as integer"), + pytest.param("", TypeError, id="dest as empty str"), + pytest.param( + pathlib.Path("/absolute/path"), ValueError, id="dest as absolute str" + ), + ), +) +def test_symlink_files_invalid_dest(dest, error, source): + """Test invalid symlink destination""" + with pytest.raises(error): + _ = [SymlinkOperation(src=file, dest=dest) for file in source] + + +@pytest.mark.parametrize( + "src,error", + ( + pytest.param(123, TypeError, id="src as integer"), + pytest.param("", TypeError, id="src as empty str"), + pytest.param( + pathlib.Path("relative/path"), ValueError, id="src as relative str" + ), + ), +) +def test_symlink_files_invalid_src(src, error): + """Test invalid symlink source""" + with pytest.raises(error): + _ = SymlinkOperation(src=src) + + +@pytest.mark.parametrize( + "dest,error", + ( + pytest.param(123, TypeError, id="dest as integer"), + pytest.param("", TypeError, id="dest as empty str"), + pytest.param( + pathlib.Path("/absolute/path"), ValueError, id="dest as absolute str" + ), + ), +) +def test_configure_files_invalid_dest(dest, error, source): + """Test invalid configure destination""" + with pytest.raises(error): + _ = [ + ConfigureOperation(src=file, dest=dest, file_parameters={"FOO": "BAR"}) + for file in source + ] + + +@pytest.mark.parametrize( + "src,error", + ( + pytest.param(123, TypeError, id="src as integer"), + pytest.param("", TypeError, id="src as empty str"), + pytest.param( + pathlib.Path("relative/path"), ValueError, id="src as relative str" + ), + ), +) +def test_configure_files_invalid_src(src, error): + """Test invalid configure source""" + with pytest.raises(error): + _ = ConfigureOperation(src=src, file_parameters={"FOO": "BAR"}) diff --git a/tests/test_permutation_strategies.py b/tests/test_permutation_strategies.py new file mode 100644 index 
0000000000..314c21063b --- /dev/null +++ b/tests/test_permutation_strategies.py @@ -0,0 +1,203 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import dataclasses + +import pytest + +from smartsim.builders.utils import strategies +from smartsim.builders.utils.strategies import ParamSet +from smartsim.error import errors + +pytestmark = pytest.mark.group_a + + +def test_strategy_registration(monkeypatch): + monkeypatch.setattr(strategies, "_REGISTERED_STRATEGIES", {}) + assert strategies._REGISTERED_STRATEGIES == {} + + new_strat = lambda params, exe_args, nmax: [] + decorator = strategies._register("new_strat") + assert strategies._REGISTERED_STRATEGIES == {} + + ret_val = decorator(new_strat) + assert ret_val is new_strat + assert strategies._REGISTERED_STRATEGIES == {"new_strat": new_strat} + + +def test_strategies_cannot_be_overwritten(monkeypatch): + monkeypatch.setattr( + strategies, + "_REGISTERED_STRATEGIES", + {"some-strategy": lambda params, exe_args, nmax: []}, + ) + with pytest.raises(ValueError): + strategies._register("some-strategy")(lambda params, exe_args, nmax: []) + + +def test_strategy_retreval(monkeypatch): + new_strat_a = lambda params, exe_args, nmax: [] + new_strat_b = lambda params, exe_args, nmax: [] + + monkeypatch.setattr( + strategies, + "_REGISTERED_STRATEGIES", + {"new_strat_a": new_strat_a, "new_strat_b": new_strat_b}, + ) + assert strategies.resolve("new_strat_a") == new_strat_a + assert strategies.resolve("new_strat_b") == new_strat_b + + +def test_user_strategy_error_raised_when_attempting_to_get_unknown_strat(): + with pytest.raises(ValueError): + strategies.resolve("NOT-REGISTERED") + + +def broken_strategy(p, n, e): + raise Exception("This custom strategy raised an error") + + +@pytest.mark.parametrize( + "strategy", + ( + pytest.param(broken_strategy, id="Strategy raises during execution"), + pytest.param(lambda params, exe_args, nmax: 123, id="Does not return a list"), + pytest.param( + lambda params, exe_args, nmax: [1, 2, 3], + id="Does not return a list of ParamSet", + ), + ), +) +def 
test_custom_strategy_raises_user_strategy_error_if_something_goes_wrong(strategy): + with pytest.raises(errors.UserStrategyError): + strategies.resolve(strategy)({"SPAM": ["EGGS"]}, {"HELLO": [["WORLD"]]}, 123) + + +@pytest.mark.parametrize( + "strategy, expected_output", + ( + pytest.param( + strategies.create_all_permutations, + ( + [ + ParamSet( + params={"SPAM": "a", "EGGS": "c"}, exe_args={"EXE": ["a"]} + ), + ParamSet( + params={"SPAM": "a", "EGGS": "c"}, + exe_args={"EXE": ["b", "c"]}, + ), + ParamSet( + params={"SPAM": "a", "EGGS": "d"}, exe_args={"EXE": ["a"]} + ), + ParamSet( + params={"SPAM": "a", "EGGS": "d"}, + exe_args={"EXE": ["b", "c"]}, + ), + ParamSet( + params={"SPAM": "b", "EGGS": "c"}, exe_args={"EXE": ["a"]} + ), + ParamSet( + params={"SPAM": "b", "EGGS": "c"}, + exe_args={"EXE": ["b", "c"]}, + ), + ParamSet( + params={"SPAM": "b", "EGGS": "d"}, exe_args={"EXE": ["a"]} + ), + ParamSet( + params={"SPAM": "b", "EGGS": "d"}, + exe_args={"EXE": ["b", "c"]}, + ), + ] + ), + id="All Permutations", + ), + pytest.param( + strategies.step_values, + ( + [ + ParamSet( + params={"SPAM": "a", "EGGS": "c"}, exe_args={"EXE": ["a"]} + ), + ParamSet( + params={"SPAM": "b", "EGGS": "d"}, + exe_args={"EXE": ["b", "c"]}, + ), + ] + ), + id="Step Values", + ), + pytest.param( + strategies.random_permutations, + ( + [ + ParamSet( + params={"SPAM": "a", "EGGS": "c"}, exe_args={"EXE": ["a"]} + ), + ParamSet( + params={"SPAM": "a", "EGGS": "c"}, + exe_args={"EXE": ["b", "c"]}, + ), + ParamSet( + params={"SPAM": "a", "EGGS": "d"}, exe_args={"EXE": ["a"]} + ), + ParamSet( + params={"SPAM": "a", "EGGS": "d"}, + exe_args={"EXE": ["b", "c"]}, + ), + ParamSet( + params={"SPAM": "b", "EGGS": "c"}, exe_args={"EXE": ["a"]} + ), + ParamSet( + params={"SPAM": "b", "EGGS": "c"}, + exe_args={"EXE": ["b", "c"]}, + ), + ParamSet( + params={"SPAM": "b", "EGGS": "d"}, exe_args={"EXE": ["a"]} + ), + ParamSet( + params={"SPAM": "b", "EGGS": "d"}, + exe_args={"EXE": ["b", "c"]}, + ), + ] 
+ ), + id="Uncapped Random Permutations", + ), + ), +) +def test_strategy_returns_expected_set(strategy, expected_output): + params = {"SPAM": ["a", "b"], "EGGS": ["c", "d"]} + exe_args = {"EXE": [["a"], ["b", "c"]]} + output = list(strategy(params, exe_args, 50)) + assert len(output) == len(expected_output) + assert all(item in expected_output for item in output) + assert all(item in output for item in expected_output) + + +def test_param_set_is_frozen(): + param = ParamSet("set1", "set2") + with pytest.raises(dataclasses.FrozenInstanceError): + param.exe_args = "change" diff --git a/tests/test_shell_launcher.py b/tests/test_shell_launcher.py new file mode 100644 index 0000000000..f371d793f1 --- /dev/null +++ b/tests/test_shell_launcher.py @@ -0,0 +1,392 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import contextlib +import os +import pathlib +import subprocess +import sys +import textwrap +import unittest.mock + +import psutil +import pytest + +from smartsim._core.shell.shell_launcher import ShellLauncher, ShellLauncherCommand, sp +from smartsim._core.utils import helpers +from smartsim._core.utils.shell import * +from smartsim.entity import entity +from smartsim.error.errors import LauncherJobNotFound +from smartsim.status import JobStatus + +pytestmark = pytest.mark.group_a + + +class EchoHelloWorldEntity(entity.SmartSimEntity): + """A simple smartsim entity""" + + def __init__(self): + super().__init__("test-entity") + + def __eq__(self, other): + if type(self) is not type(other): + return NotImplemented + return self.as_executable_sequence() == other.as_executable_sequence() + + def as_executable_sequence(self): + return (helpers.expand_exe_path("echo"), "Hello", "World!") + + +def create_directory(directory_path: str) -> pathlib.Path: + """Creates the execution directory for testing.""" + tmp_dir = pathlib.Path(directory_path) + tmp_dir.mkdir(exist_ok=True, parents=True) + return tmp_dir + + +def generate_output_files(tmp_dir: pathlib.Path): + """Generates output and error files within the run directory for testing.""" + out_file = tmp_dir / "tmp.out" + err_file = tmp_dir / "tmp.err" + return out_file, err_file + + +def generate_directory(test_dir: str): + """Generates a execution 
directory, output file, and error file for testing.""" + execution_dir = create_directory(os.path.join(test_dir, "tmp")) + out_file, err_file = generate_output_files(execution_dir) + return execution_dir, out_file, err_file + + +@pytest.fixture +def shell_launcher(): + launcher = ShellLauncher() + yield launcher + if any(proc.poll() is None for proc in launcher._launched.values()): + raise RuntimeError("Test leaked processes") + + +@pytest.fixture +def make_shell_command(test_dir): + run_dir, out_file_, err_file_ = generate_directory(test_dir) + + @contextlib.contextmanager + def impl( + args: t.Sequence[str], + working_dir: str | os.PathLike[str] = run_dir, + env: dict[str, str] | None = None, + out_file: str | os.PathLike[str] = out_file_, + err_file: str | os.PathLike[str] = err_file_, + ): + with ( + open(out_file, "w", encoding="utf-8") as out, + open(err_file, "w", encoding="utf-8") as err, + ): + yield ShellLauncherCommand( + env or {}, pathlib.Path(working_dir), out, err, tuple(args) + ) + + yield impl + + +@pytest.fixture +def shell_cmd(make_shell_command) -> ShellLauncherCommand: + """Fixture to create an instance of Generator.""" + with make_shell_command(EchoHelloWorldEntity().as_executable_sequence()) as hello: + yield hello + + +# UNIT TESTS + + +def test_shell_launcher_command_init(shell_cmd: ShellLauncherCommand, test_dir: str): + """Test that ShellLauncherCommand initializes correctly""" + assert shell_cmd.env == {} + assert shell_cmd.path == pathlib.Path(test_dir) / "tmp" + assert shell_cmd.stdout.name == os.path.join(test_dir, "tmp", "tmp.out") + assert shell_cmd.stderr.name == os.path.join(test_dir, "tmp", "tmp.err") + assert shell_cmd.command_tuple == EchoHelloWorldEntity().as_executable_sequence() + + +def test_shell_launcher_init(shell_launcher: ShellLauncher): + """Test that ShellLauncher initializes correctly""" + assert shell_launcher._launched == {} + + +def test_check_popen_inputs(shell_launcher: ShellLauncher, test_dir: str): + """Test 
that ShellLauncher.check_popen_inputs throws correctly""" + cmd = ShellLauncherCommand( + {}, + pathlib.Path(test_dir) / "directory_dne", + subprocess.DEVNULL, + subprocess.DEVNULL, + EchoHelloWorldEntity().as_executable_sequence(), + ) + with pytest.raises(ValueError): + _ = shell_launcher.start(cmd) + + +def test_shell_launcher_start_calls_popen( + shell_launcher: ShellLauncher, shell_cmd: ShellLauncherCommand +): + """Test that the process leading up to the shell launcher popen call was correct""" + with unittest.mock.patch( + "smartsim._core.shell.shell_launcher.sp.Popen" + ) as mock_open: + _ = shell_launcher.start(shell_cmd) + mock_open.assert_called_once() + + +def test_shell_launcher_start_calls_popen_with_value( + shell_launcher: ShellLauncher, shell_cmd: ShellLauncherCommand +): + """Test that popen was called with correct values""" + with unittest.mock.patch( + "smartsim._core.shell.shell_launcher.sp.Popen" + ) as mock_open: + _ = shell_launcher.start(shell_cmd) + mock_open.assert_called_once_with( + shell_cmd.command_tuple, + cwd=shell_cmd.path, + env=shell_cmd.env, + stdout=shell_cmd.stdout, + stderr=shell_cmd.stderr, + ) + + +def test_popen_returns_popen_object( + shell_launcher: ShellLauncher, shell_cmd: ShellLauncherCommand, test_dir: str +): + """Test that the popen call returns a popen object""" + id = shell_launcher.start(shell_cmd) + with shell_launcher._launched[id] as proc: + assert isinstance(proc, sp.Popen) + + +def test_popen_writes_to_output_file( + shell_launcher: ShellLauncher, shell_cmd: ShellLauncherCommand, test_dir: str +): + """Test that popen writes to .out file upon successful process call""" + _, out_file, err_file = generate_directory(test_dir) + id = shell_launcher.start(shell_cmd) + proc = shell_launcher._launched[id] + assert proc.wait() == 0 + assert proc.returncode == 0 + with open(out_file, "r", encoding="utf-8") as out: + assert out.read() == "Hello World!\n" + with open(err_file, "r", encoding="utf-8") as err: + assert 
err.read() == "" + + +def test_popen_fails_with_invalid_cmd(shell_launcher: ShellLauncher, test_dir: str): + """Test that popen returns a non zero returncode after failure""" + run_dir, out_file, err_file = generate_directory(test_dir) + with ( + open(out_file, "w", encoding="utf-8") as out, + open(err_file, "w", encoding="utf-8") as err, + ): + args = (helpers.expand_exe_path("ls"), "--flag_dne") + cmd = ShellLauncherCommand({}, run_dir, out, err, args) + id = shell_launcher.start(cmd) + proc = shell_launcher._launched[id] + proc.wait() + assert proc.returncode != 0 + with open(out_file, "r", encoding="utf-8") as out: + assert out.read() == "" + with open(err_file, "r", encoding="utf-8") as err: + content = err.read() + assert "unrecognized option" in content + + +def test_popen_issues_unique_ids( + shell_launcher: ShellLauncher, shell_cmd: ShellLauncherCommand, test_dir: str +): + """Validate that all ids are unique within ShellLauncher._launched""" + seen = set() + for _ in range(5): + id = shell_launcher.start(shell_cmd) + assert id not in seen, "Duplicate ID issued" + seen.add(id) + assert len(shell_launcher._launched) == 5 + assert all(proc.wait() == 0 for proc in shell_launcher._launched.values()) + + +def test_retrieve_status_dne(shell_launcher: ShellLauncher): + """Test tht ShellLauncher returns the status of completed Jobs""" + with pytest.raises(LauncherJobNotFound): + _ = shell_launcher.get_status("dne") + + +def test_shell_launcher_returns_complete_status( + shell_launcher: ShellLauncher, shell_cmd: ShellLauncherCommand +): + """Test tht ShellLauncher returns the status of completed Jobs""" + for _ in range(5): + id = shell_launcher.start(shell_cmd) + proc = shell_launcher._launched[id] + proc.wait() + code = shell_launcher.get_status(id)[id] + assert code == JobStatus.COMPLETED + + +def test_shell_launcher_returns_failed_status( + shell_launcher: ShellLauncher, test_dir: str +): + """Test tht ShellLauncher returns the status of completed Jobs""" + 
run_dir, out_file, err_file = generate_directory(test_dir) + with ( + open(out_file, "w", encoding="utf-8") as out, + open(err_file, "w", encoding="utf-8") as err, + ): + args = (helpers.expand_exe_path("ls"), "--flag_dne") + cmd = ShellLauncherCommand({}, run_dir, out, err, args) + for _ in range(5): + id = shell_launcher.start(cmd) + proc = shell_launcher._launched[id] + proc.wait() + code = shell_launcher.get_status(id)[id] + assert code == JobStatus.FAILED + + +def test_shell_launcher_returns_running_status( + shell_launcher: ShellLauncher, test_dir: str +): + """Test tht ShellLauncher returns the status of completed Jobs""" + run_dir, out_file, err_file = generate_directory(test_dir) + with ( + open(out_file, "w", encoding="utf-8") as out, + open(err_file, "w", encoding="utf-8") as err, + ): + cmd = ShellLauncherCommand( + {}, run_dir, out, err, (helpers.expand_exe_path("sleep"), "5") + ) + for _ in range(5): + id = shell_launcher.start(cmd) + code = shell_launcher.get_status(id)[id] + assert code == JobStatus.RUNNING + assert all(proc.wait() == 0 for proc in shell_launcher._launched.values()) + + +@pytest.mark.parametrize( + "psutil_status,job_status", + [ + pytest.param(psutil.STATUS_RUNNING, JobStatus.RUNNING, id="running"), + pytest.param(psutil.STATUS_SLEEPING, JobStatus.RUNNING, id="sleeping"), + pytest.param(psutil.STATUS_WAKING, JobStatus.RUNNING, id="waking"), + pytest.param(psutil.STATUS_DISK_SLEEP, JobStatus.RUNNING, id="disk_sleep"), + pytest.param(psutil.STATUS_DEAD, JobStatus.FAILED, id="dead"), + pytest.param(psutil.STATUS_TRACING_STOP, JobStatus.PAUSED, id="tracing_stop"), + pytest.param(psutil.STATUS_WAITING, JobStatus.PAUSED, id="waiting"), + pytest.param(psutil.STATUS_STOPPED, JobStatus.PAUSED, id="stopped"), + pytest.param(psutil.STATUS_LOCKED, JobStatus.PAUSED, id="locked"), + pytest.param(psutil.STATUS_PARKED, JobStatus.PAUSED, id="parked"), + pytest.param(psutil.STATUS_IDLE, JobStatus.PAUSED, id="idle"), + 
pytest.param(psutil.STATUS_ZOMBIE, JobStatus.COMPLETED, id="zombie"), + pytest.param( + "some-brand-new-unknown-status-str", JobStatus.UNKNOWN, id="unknown" + ), + ], +) +def test_get_status_maps_correctly( + psutil_status, job_status, monkeypatch: pytest.MonkeyPatch, test_dir: str +): + """Test tht ShellLauncher.get_status returns correct mapping""" + shell_launcher = ShellLauncher() + run_dir, out_file, err_file = generate_directory(test_dir) + with ( + open(out_file, "w", encoding="utf-8") as out, + open(err_file, "w", encoding="utf-8") as err, + ): + cmd = ShellLauncherCommand( + {}, run_dir, out, err, EchoHelloWorldEntity().as_executable_sequence() + ) + id = shell_launcher.start(cmd) + proc = shell_launcher._launched[id] + monkeypatch.setattr(proc, "poll", lambda: None) + monkeypatch.setattr(psutil.Process, "status", lambda self: psutil_status) + value = shell_launcher.get_status(id) + assert value.get(id) == job_status + assert proc.wait() == 0 + + +@pytest.mark.parametrize( + "args", + ( + pytest.param(("sleep", "60"), id="Sleep for a minute"), + *( + pytest.param( + ( + sys.executable, + "-c", + textwrap.dedent(f"""\ + import signal, time + signal.signal(signal.{signal_name}, + lambda n, f: print("Ignoring")) + time.sleep(60) + """), + ), + id=f"Process Swallows {signal_name}", + ) + for signal_name in ("SIGINT", "SIGTERM") + ), + ), +) +def test_launcher_can_stop_processes(shell_launcher, make_shell_command, args): + with make_shell_command(args) as cmd: + start = time.perf_counter() + id_ = shell_launcher.start(cmd) + time.sleep(0.1) + assert {id_: JobStatus.RUNNING} == shell_launcher.get_status(id_) + assert JobStatus.FAILED == shell_launcher._stop(id_, wait_time=0.25) + end = time.perf_counter() + assert {id_: JobStatus.FAILED} == shell_launcher.get_status(id_) + proc = shell_launcher._launched[id_] + assert proc.poll() is not None + assert proc.poll() != 0 + assert 0.1 < end - start < 1 + + +def test_launcher_can_stop_many_processes( + 
make_shell_command, shell_launcher, shell_cmd +): + with ( + make_shell_command(("sleep", "60")) as sleep_60, + make_shell_command(("sleep", "45")) as sleep_45, + make_shell_command(("sleep", "30")) as sleep_30, + ): + id_60 = shell_launcher.start(sleep_60) + id_45 = shell_launcher.start(sleep_45) + id_30 = shell_launcher.start(sleep_30) + id_short = shell_launcher.start(shell_cmd) + time.sleep(0.1) + assert { + id_60: JobStatus.FAILED, + id_45: JobStatus.FAILED, + id_30: JobStatus.FAILED, + id_short: JobStatus.COMPLETED, + } == shell_launcher.stop_jobs(id_30, id_45, id_60, id_short) From 879b96eeec8bdd1adb8e7154c75c813ce80d2e7d Mon Sep 17 00:00:00 2001 From: Alyssa Cote <46540273+AlyssaCote@users.noreply.github.com> Date: Mon, 28 Oct 2024 13:30:34 -0700 Subject: [PATCH 3/5] Fix logging bug in dragon_install (#761) Logs are improved during dragon install when there is a platform and asset type mismatch. [ committed by @AlyssaCote ] [ reviewed by @ankona ] --- doc/changelog.md | 1 + smartsim/_core/_cli/scripts/dragon_install.py | 8 ++- tests/_legacy/test_dragon_installer.py | 63 +++++++++++++++++++ 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/doc/changelog.md b/doc/changelog.md index 752957bfdc..f1d81a1c59 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,6 +13,7 @@ Jump to: Description +- Fix dragon build logging bug - Merge core refactor into MLI feature branch - Implement asynchronous notifications for shared data - Quick bug fix in _validate diff --git a/smartsim/_core/_cli/scripts/dragon_install.py b/smartsim/_core/_cli/scripts/dragon_install.py index 7a7d75f1d2..40d2d7b64e 100644 --- a/smartsim/_core/_cli/scripts/dragon_install.py +++ b/smartsim/_core/_cli/scripts/dragon_install.py @@ -250,7 +250,11 @@ def filter_assets( if len(assets) > 0: asset = next((asset for asset in assets if _platform_filter(asset.name)), None) - if not asset: + if not asset and not is_crayex_platform(): + # non-Cray platform, HSN asset + 
logger.warning(f"Platform does not support HSN assets") + elif not asset and is_crayex_platform(): + # Cray platform, non HSN asset asset = assets[0] logger.warning(f"Platform-specific package not found. Using {asset.name}") @@ -267,9 +271,9 @@ def retrieve_asset_info(request: DragonInstallRequest) -> GitReleaseAsset: platform_result = check_platform() if not platform_result.is_cray: - logger.warning("Installing Dragon without HSTA support") for msg in platform_result.failures: logger.warning(msg) + logger.warning("Installing Dragon without HSTA support") if asset is None: raise SmartSimCLIActionCancelled("No dragon runtime asset available to install") diff --git a/tests/_legacy/test_dragon_installer.py b/tests/_legacy/test_dragon_installer.py index 8ce7404c5f..ff325a6aa0 100644 --- a/tests/_legacy/test_dragon_installer.py +++ b/tests/_legacy/test_dragon_installer.py @@ -29,6 +29,7 @@ import tarfile import typing as t from collections import namedtuple +from unittest.mock import MagicMock import pytest from github.GitRelease import GitRelease @@ -44,6 +45,7 @@ DragonInstallRequest, cleanup, create_dotenv, + filter_assets, install_dragon, install_package, retrieve_asset, @@ -577,3 +579,64 @@ def test_create_dotenv_format(monkeypatch: pytest.MonkeyPatch, test_dir: str): for line in lines: line_split = line.split("=") assert len(line_split) == 2 + + +@pytest.mark.parametrize( + "_platform_filter_return,is_cray_platform_return,returned_asset_bool", + [ + pytest.param( + True, + True, + True, + id="cray platform, hsn asset", + ), + pytest.param( + False, + False, + False, + id="non cray platform, hsn asset", + ), + pytest.param( + False, + True, + True, + id="cray platform, non hsn asset", + ), + pytest.param( + True, + False, + True, + id="non cray platform, non hsn asset", + ), + ], +) +def test_filter_assets( + monkeypatch, + test_dir, + _platform_filter_return, + is_cray_platform_return, + returned_asset_bool, +): + request = DragonInstallRequest(test_dir, 
version="0.10") + monkeypatch.setattr( + "smartsim._core._cli.scripts.dragon_install._platform_filter", + MagicMock(return_value=_platform_filter_return), + ) + monkeypatch.setattr( + "smartsim._core._cli.scripts.dragon_install._version_filter", + MagicMock(return_value=True), + ) + monkeypatch.setattr( + "smartsim._core._cli.scripts.dragon_install._pin_filter", + MagicMock(return_value=True), + ) + monkeypatch.setattr( + "smartsim._core._cli.scripts.dragon_install.is_crayex_platform", + MagicMock(return_value=is_cray_platform_return), + ) + mocked_asset = MagicMock() + asset = filter_assets(request, [mocked_asset]) + if returned_asset_bool: + assert asset is not None + else: + assert asset is None From 2d8b902f0fbbb77ae9271753cd852e225a396bca Mon Sep 17 00:00:00 2001 From: Chris McBride Date: Tue, 29 Oct 2024 15:03:56 -0400 Subject: [PATCH 4/5] Fix set-hostlist regression from merge (#770) Fix a regression introduced during the merge of core refactor into v1.0. - Ensure host list argument is passed to `DragonRunRequestView` constructor [ committed by @ankona ] [ reviewed by @al-rigazzi ] --- doc/changelog.md | 1 + .../_core/launcher/dragon/dragon_launcher.py | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/doc/changelog.md b/doc/changelog.md index f1d81a1c59..febf5d0b7d 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,6 +13,7 @@ Jump to: Description +- Fix regression on hostlist param to DragonRunRequest - Fix dragon build logging bug - Merge core refactor into MLI feature branch - Implement asynchronous notifications for shared data diff --git a/smartsim/_core/launcher/dragon/dragon_launcher.py b/smartsim/_core/launcher/dragon/dragon_launcher.py index 752b6c2495..824a60588c 100644 --- a/smartsim/_core/launcher/dragon/dragon_launcher.py +++ b/smartsim/_core/launcher/dragon/dragon_launcher.py @@ -381,6 +381,20 @@ def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT: from 
smartsim.settings.arguments.launch.dragon import DragonLaunchArguments +def _host_list_from_run_args( + run_args: t.Dict[str, "int | str | float | None"] +) -> t.Optional[str]: + """Extract a host list from a run args dictionary. + + :param run_args: Dictionary containing available run arguments + :returns: a CSV string containing host names or None + """ + hosts: t.Optional[str] = None + if host_arg := run_args.get("host-list", None): + hosts = str(host_arg) + return hosts + + def _as_run_request_args_and_policy( run_req_args: DragonLaunchArguments, exe: t.Sequence[str], @@ -396,13 +410,14 @@ def _as_run_request_args_and_policy( exe_, *args = exe run_args = dict[str, "int | str | float | None"](run_req_args._launch_args) policy = DragonRunPolicy.from_run_args(run_args) + hosts = _host_list_from_run_args(run_args) return ( DragonRunRequestView( exe=exe_, exe_args=args, path=path, env=env, - # TODO: Not sure how this info is injected + hostlist=hosts, name=None, output_file=stdout_path, error_file=stderr_path, From f4e76a4070d80f3ee6af1877fb492fa78d27a3f9 Mon Sep 17 00:00:00 2001 From: Alyssa Cote <46540273+AlyssaCote@users.noreply.github.com> Date: Tue, 29 Oct 2024 13:10:52 -0700 Subject: [PATCH 5/5] Rewrite RequestBatch (#767) RequestBatch rewrite. 
[ committed by @AlyssaCote ] [ reviewed by @ankona @al-rigazzi ] --- doc/changelog.md | 1 + .../control/request_dispatcher.py | 23 +-- .../infrastructure/control/worker_manager.py | 68 ++++--- .../mli/infrastructure/worker/torch_worker.py | 33 +-- .../_core/mli/infrastructure/worker/worker.py | 192 +++++++++++------- .../test_core_machine_learning_worker.py | 102 +++++++--- tests/dragon_wlm/test_device_manager.py | 22 +- tests/dragon_wlm/test_error_handling.py | 33 ++- tests/dragon_wlm/test_request_dispatcher.py | 96 ++++++++- tests/dragon_wlm/test_torch_worker.py | 19 +- 10 files changed, 407 insertions(+), 182 deletions(-) diff --git a/doc/changelog.md b/doc/changelog.md index febf5d0b7d..98ef6c0f50 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,6 +13,7 @@ Jump to: Description +- RequestBatch rewrite - Fix regression on hostlist param to DragonRunRequest - Fix dragon build logging bug - Merge core refactor into MLI feature branch diff --git a/smartsim/_core/mli/infrastructure/control/request_dispatcher.py b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py index e22a2c8f62..6dafd9b7c9 100644 --- a/smartsim/_core/mli/infrastructure/control/request_dispatcher.py +++ b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py @@ -172,7 +172,7 @@ def can_be_removed(self) -> bool: """ return self.empty() and self._disposable - def flush(self) -> list[t.Any]: + def flush(self) -> list[InferenceRequest]: """Get all requests from queue. 
:returns: Requests waiting to be executed @@ -334,7 +334,7 @@ def _check_callback(self, request: InferenceRequest) -> bool: :param request: The request to validate :returns: False if callback validation fails for the request, True otherwise """ - if request.callback: + if request.callback_desc: return True logger.error("No callback channel provided in request") @@ -376,9 +376,7 @@ def _on_iteration(self) -> None: tensor_bytes_list = bytes_list[1:] self._perf_timer.start_timings() - request = self._worker.deserialize_message( - request_bytes, self._callback_factory - ) + request = self._worker.deserialize_message(request_bytes) if request.has_input_meta and tensor_bytes_list: request.raw_inputs = tensor_bytes_list @@ -387,7 +385,11 @@ def _on_iteration(self) -> None: if not self._validate_request(request): exception_handler( ValueError("Error validating the request"), - request.callback, + ( + self._callback_factory(request.callback_desc) + if request.callback_desc + else None + ), None, ) self._perf_timer.measure_time("validate_request") @@ -493,10 +495,8 @@ def flush_requests(self) -> None: if queue.ready: self._perf_timer.measure_time("find_queue") try: - batch = RequestBatch( - requests=queue.flush(), - inputs=None, - model_id=queue.model_id, + batch = RequestBatch.from_requests( + queue.flush(), queue.model_id ) finally: self._perf_timer.measure_time("flush_requests") @@ -528,9 +528,6 @@ def flush_requests(self) -> None: self._perf_timer.measure_time("transform_input") batch.inputs = transformed_inputs - for request in batch.requests: - request.raw_inputs = [] - request.input_meta = [] try: self._outgoing_queue.put(batch) diff --git a/smartsim/_core/mli/infrastructure/control/worker_manager.py b/smartsim/_core/mli/infrastructure/control/worker_manager.py index bf6fddb81d..1c93276074 100644 --- a/smartsim/_core/mli/infrastructure/control/worker_manager.py +++ b/smartsim/_core/mli/infrastructure/control/worker_manager.py @@ -132,8 +132,10 @@ def 
_check_feature_stores(self, batch: RequestBatch) -> bool: fs_model: t.Set[str] = set() if batch.model_id.key: fs_model = {batch.model_id.descriptor} - fs_inputs = {key.descriptor for key in batch.input_keys} - fs_outputs = {key.descriptor for key in batch.output_keys} + fs_inputs = {key.descriptor for keys in batch.input_keys for key in keys} + fs_outputs = { + key.descriptor for keys in batch.output_key_refs.values() for key in keys + } # identify which feature stores are requested and unknown fs_desired = fs_model.union(fs_inputs).union(fs_outputs) @@ -159,7 +161,7 @@ def _validate_batch(self, batch: RequestBatch) -> bool: :param batch: The batch of requests to validate :returns: False if the request fails any validation checks, True otherwise """ - if batch is None or not batch.has_valid_requests: + if batch is None or not batch.has_callbacks: return False return self._check_feature_stores(batch) @@ -188,7 +190,7 @@ def _on_iteration(self) -> None: return if not self._device_manager: - for request in batch.requests: + for callback_desc in batch.callback_descriptors: msg = "No Device Manager found. WorkerManager._on_start() " "must be called after initialization. If possible, " "you should use `WorkerManager.execute()` instead of " @@ -200,7 +202,7 @@ def _on_iteration(self) -> None: "and will not be processed." 
exception_handler( RuntimeError(msg), - request.callback, + self._callback_factory(callback_desc), "Error acquiring device manager", ) return @@ -212,10 +214,10 @@ def _on_iteration(self) -> None: feature_stores=self._feature_stores, ) except Exception as exc: - for request in batch.requests: + for callback_desc in batch.callback_descriptors: exception_handler( exc, - request.callback, + self._callback_factory(callback_desc), "Error loading model on device or getting device.", ) return @@ -226,18 +228,20 @@ def _on_iteration(self) -> None: try: model_result = LoadModelResult(device.get_model(batch.model_id.key)) except Exception as exc: - for request in batch.requests: + for callback_desc in batch.callback_descriptors: exception_handler( - exc, request.callback, "Error getting model from device." + exc, + self._callback_factory(callback_desc), + "Error getting model from device.", ) return self._perf_timer.measure_time("load_model") if not batch.inputs: - for request in batch.requests: + for callback_desc in batch.callback_descriptors: exception_handler( ValueError("Error batching inputs"), - request.callback, + self._callback_factory(callback_desc), None, ) return @@ -248,8 +252,12 @@ def _on_iteration(self) -> None: batch, model_result, transformed_input, device.name ) except Exception as e: - for request in batch.requests: - exception_handler(e, request.callback, "Error while executing.") + for callback_desc in batch.callback_descriptors: + exception_handler( + e, + self._callback_factory(callback_desc), + "Error while executing.", + ) return self._perf_timer.measure_time("execute") @@ -258,24 +266,35 @@ def _on_iteration(self) -> None: batch, execute_result ) except Exception as e: - for request in batch.requests: + for callback_desc in batch.callback_descriptors: exception_handler( - e, request.callback, "Error while transforming the output." 
+ e, + self._callback_factory(callback_desc), + "Error while transforming the output.", ) return - for request, transformed_output in zip(batch.requests, transformed_outputs): + for callback_desc, transformed_output in zip( + batch.callback_descriptors, transformed_outputs + ): reply = InferenceReply() - if request.has_output_keys: + if batch.output_key_refs: try: + output_keys = batch.output_key_refs[callback_desc] reply.output_keys = self._worker.place_output( - request, + output_keys, transformed_output, self._feature_stores, ) + except KeyError: + # the callback is not in the output_key_refs dict + # because it doesn't have output_keys associated with it + continue except Exception as e: exception_handler( - e, request.callback, "Error while placing the output." + e, + self._callback_factory(callback_desc), + "Error while placing the output.", ) continue else: @@ -302,12 +321,11 @@ def _on_iteration(self) -> None: self._perf_timer.measure_time("serialize_resp") - if request.callback: - request.callback.send(serialized_resp) - if reply.has_outputs: - # send tensor data after response - for output in reply.outputs: - request.callback.send(output) + callback = self._callback_factory(callback_desc) + callback.send(serialized_resp) + if reply.has_outputs: + for output in reply.outputs: + callback.send(output) self._perf_timer.measure_time("send") self._perf_timer.end_timings() diff --git a/smartsim/_core/mli/infrastructure/worker/torch_worker.py b/smartsim/_core/mli/infrastructure/worker/torch_worker.py index 64e94e5eb6..f8d6e7c2de 100644 --- a/smartsim/_core/mli/infrastructure/worker/torch_worker.py +++ b/smartsim/_core/mli/infrastructure/worker/torch_worker.py @@ -34,7 +34,6 @@ from .....error import SmartSimError from .....log import get_logger -from ...mli_schemas.tensor import tensor_capnp from .worker import ( ExecuteResult, FetchInputResult, @@ -42,6 +41,7 @@ LoadModelResult, MachineLearningWorkerBase, RequestBatch, + TensorMeta, TransformInputResult, 
TransformOutputResult, ) @@ -64,7 +64,8 @@ def load_model( """Given a loaded MachineLearningModel, ensure it is loaded into device memory. - :param request: The request that triggered the pipeline + :param batch: The batch that triggered the pipeline + :param fetch_result: The fetched model :param device: The device on which the model must be placed :returns: LoadModelResult wrapping the model loaded for the request :raises ValueError: If model reference object is not found @@ -97,13 +98,13 @@ def load_model( @staticmethod def transform_input( batch: RequestBatch, - fetch_results: list[FetchInputResult], + fetch_results: FetchInputResult, mem_pool: MemoryPool, ) -> TransformInputResult: """Given a collection of data, perform a transformation on the data and put the raw tensor data on a MemoryPool allocation. - :param request: The request that triggered the pipeline + :param batch: The batch that triggered the pipeline :param fetch_result: Raw outputs from fetching inputs out of a feature store :param mem_pool: The memory pool used to access batched input tensors :returns: The transformed inputs wrapped in a TransformInputResult @@ -116,32 +117,32 @@ def transform_input( all_dims: list[list[int]] = [] all_dtypes: list[str] = [] - if fetch_results[0].meta is None: + if fetch_results.meta is None: raise ValueError("Cannot reconstruct tensor without meta information") # Traverse inputs to get total number of samples and compute slices # Assumption: first dimension is samples, all tensors in the same input # have same number of samples # thus we only look at the first tensor for each input - for res_idx, fetch_result in enumerate(fetch_results): - if fetch_result.meta is None or any( - item_meta is None for item_meta in fetch_result.meta + for res_idx, res_meta_list in enumerate(fetch_results.meta): + if res_meta_list is None or any( + item_meta is None for item_meta in res_meta_list ): raise ValueError("Cannot reconstruct tensor without meta information") - 
first_tensor_desc: tensor_capnp.TensorDescriptor = fetch_result.meta[0] + first_tensor_desc: TensorMeta = res_meta_list[0] # type: ignore num_samples = first_tensor_desc.dimensions[0] slices.append(slice(total_samples, total_samples + num_samples)) total_samples = total_samples + num_samples - if res_idx == len(fetch_results) - 1: + if res_idx == len(fetch_results.meta) - 1: # For each tensor in the last input, get remaining dimensions # Assumptions: all inputs have the same number of tensors and # last N-1 dimensions match across inputs for corresponding tensors # thus: resulting array will be of size (num_samples, all_other_dims) - for item_meta in fetch_result.meta: - tensor_desc: tensor_capnp.TensorDescriptor = item_meta - tensor_dims = list(tensor_desc.dimensions) + for item_meta in res_meta_list: + tensor_desc: TensorMeta = item_meta # type: ignore + tensor_dims = tensor_desc.dimensions all_dims.append([total_samples, *tensor_dims[1:]]) - all_dtypes.append(str(tensor_desc.dataType)) + all_dtypes.append(tensor_desc.datatype) for result_tensor_idx, (dims, dtype) in enumerate(zip(all_dims, all_dtypes)): itemsize = np.empty((1), dtype=dtype).itemsize @@ -151,8 +152,8 @@ def transform_input( try: mem_view[:alloc_size] = b"".join( [ - fetch_result.inputs[result_tensor_idx] - for fetch_result in fetch_results + fetch_result[result_tensor_idx] + for fetch_result in fetch_results.inputs ] ) except IndexError as e: diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index 9556b8e438..b122a1d9ba 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -36,14 +36,13 @@ from .....error import SmartSimError from .....log import get_logger -from ...comm.channel.channel import CommChannelBase from ...message_handler import MessageHandler from ...mli_schemas.model.model_capnp import Model +from ...mli_schemas.tensor.tensor_capnp import 
TensorDescriptor from ..storage.feature_store import FeatureStore, ModelKey, TensorKey if t.TYPE_CHECKING: from smartsim._core.mli.mli_schemas.response.response_capnp import Status - from smartsim._core.mli.mli_schemas.tensor.tensor_capnp import TensorDescriptor logger = get_logger(__name__) @@ -57,10 +56,10 @@ class InferenceRequest: def __init__( self, model_key: t.Optional[ModelKey] = None, - callback: t.Optional[CommChannelBase] = None, + callback_desc: t.Optional[str] = None, raw_inputs: t.Optional[t.List[bytes]] = None, input_keys: t.Optional[t.List[TensorKey]] = None, - input_meta: t.Optional[t.List[t.Any]] = None, + input_meta: t.Optional[t.List[TensorDescriptor]] = None, output_keys: t.Optional[t.List[TensorKey]] = None, raw_model: t.Optional[Model] = None, batch_size: int = 0, @@ -68,7 +67,8 @@ def __init__( """Initialize the InferenceRequest. :param model_key: A tuple containing a (key, descriptor) pair - :param callback: The channel used for notification of inference completion + :param callback_desc: The channel descriptor used for notification + of inference completion :param raw_inputs: Raw bytes of tensor inputs :param input_keys: A list of tuples containing a (key, descriptor) pair :param input_meta: Metadata about the input data @@ -80,7 +80,7 @@ def __init__( """A tuple containing a (key, descriptor) pair""" self.raw_model = raw_model """Raw bytes of an ML model""" - self.callback = callback + self.callback_desc = callback_desc """The channel used for notification of inference completion""" self.raw_inputs = raw_inputs or [] """Raw bytes of tensor inputs""" @@ -191,6 +191,20 @@ def has_output_keys(self) -> bool: return self.output_keys is not None and bool(self.output_keys) +@dataclass +class TensorMeta: + """Metadata about a tensor, built from TensorDescriptors.""" + + dimensions: t.List[int] + """Dimensions of the tensor""" + order: str + """Order of the tensor in row major ("c"), or + column major ("f") format""" + datatype: str + """Datatype 
of the tensor as specified by the TensorDescriptor + NumericalType enums. Examples include "float32", "int8", etc.""" + + class LoadModelResult: """A wrapper around a loaded model.""" @@ -253,7 +267,11 @@ def __init__(self, result: t.Any, slices: list[slice]) -> None: class FetchInputResult: """A wrapper around fetched inputs.""" - def __init__(self, result: t.List[bytes], meta: t.Optional[t.List[t.Any]]) -> None: + def __init__( + self, + result: t.List[t.List[bytes]], + meta: t.List[t.List[t.Optional[TensorMeta]]], + ) -> None: """Initialize the FetchInputResult. :param result: List of input tensor bytes @@ -316,20 +334,31 @@ def __init__(self, result: bytes) -> None: class RequestBatch: """A batch of aggregated inference requests.""" - requests: list[InferenceRequest] - """List of InferenceRequests in the batch""" + raw_model: t.Optional[Model] + """Raw bytes of the model""" + callback_descriptors: t.List[str] + """The descriptors for channels used for notification of inference completion""" + raw_inputs: t.List[t.List[bytes]] + """Raw bytes of tensor inputs""" + input_meta: t.List[t.List[TensorMeta]] + """Metadata about the input data""" + input_keys: t.List[t.List[TensorKey]] + """A list of tuples containing a (key, descriptor) pair""" + output_key_refs: t.Dict[str, t.List[TensorKey]] + """A dictionary mapping callbacks descriptors to output keys""" inputs: t.Optional[TransformInputResult] """Transformed batch of input tensors""" model_id: "ModelIdentifier" """Model (key, descriptor) tuple""" @property - def has_valid_requests(self) -> bool: - """Returns whether the batch contains at least one request. + def has_callbacks(self) -> bool: + """Determines if the batch has at least one callback channel + available for sending results. 
- :returns: True if at least one request is available + :returns: True if at least one callback is present """ - return len(self.requests) > 0 + return len(self.callback_descriptors) > 0 @property def has_raw_model(self) -> bool: @@ -339,37 +368,49 @@ def has_raw_model(self) -> bool: """ return self.raw_model is not None - @property - def raw_model(self) -> t.Optional[t.Any]: - """Returns the raw model to use to execute for this batch - if it is available. - - :returns: A model if available, otherwise None""" - if self.has_valid_requests: - return self.requests[0].raw_model - return None - - @property - def input_keys(self) -> t.List[TensorKey]: - """All input keys available in this batch's requests. - - :returns: All input keys belonging to requests in this batch""" - keys = [] - for request in self.requests: - keys.extend(request.input_keys) - - return keys - - @property - def output_keys(self) -> t.List[TensorKey]: - """All output keys available in this batch's requests. - - :returns: All output keys belonging to requests in this batch""" - keys = [] - for request in self.requests: - keys.extend(request.output_keys) - - return keys + @classmethod + def from_requests( + cls, + requests: t.List[InferenceRequest], + model_id: ModelIdentifier, + ) -> "RequestBatch": + """Create a RequestBatch from a list of requests. 
+ + :param requests: The requests to batch + :param model_id: The model identifier + :returns: A RequestBatch instance + """ + return cls( + raw_model=requests[0].raw_model, + callback_descriptors=[ + request.callback_desc for request in requests if request.callback_desc + ], + raw_inputs=[ + request.raw_inputs for request in requests if request.raw_inputs + ], + input_meta=[ + [ + TensorMeta( + dimensions=list(meta.dimensions), + order=str(meta.order), + datatype=str(meta.dataType), + ) + for meta in request.input_meta + ] + for request in requests + if request.input_meta + ], + input_keys=[ + request.input_keys for request in requests if request.input_keys + ], + output_key_refs={ + request.callback_desc: request.output_keys + for request in requests + if request.callback_desc and request.output_keys + }, + inputs=None, + model_id=model_id, + ) class MachineLearningWorkerCore: @@ -378,13 +419,10 @@ class MachineLearningWorkerCore: @staticmethod def deserialize_message( data_blob: bytes, - callback_factory: t.Callable[[str], CommChannelBase], ) -> InferenceRequest: """Deserialize a message from a byte stream into an InferenceRequest. 
:param data_blob: The byte stream to deserialize - :param callback_factory: A factory method that can create an instance - of the desired concrete comm channel type :returns: The raw input message deserialized into an InferenceRequest """ request = MessageHandler.deserialize_request(data_blob) @@ -400,7 +438,6 @@ def deserialize_message( model_bytes = request.model.data callback_key = request.replyChannel.descriptor - comm_channel = callback_factory(callback_key) input_keys: t.Optional[t.List[TensorKey]] = None input_bytes: t.Optional[t.List[bytes]] = None output_keys: t.Optional[t.List[TensorKey]] = None @@ -422,7 +459,7 @@ def deserialize_message( inference_request = InferenceRequest( model_key=model_key, - callback=comm_channel, + callback_desc=callback_key, raw_inputs=input_bytes, input_meta=input_meta, input_keys=input_keys, @@ -497,7 +534,7 @@ def fetch_model( @staticmethod def fetch_inputs( batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] - ) -> t.List[FetchInputResult]: + ) -> FetchInputResult: """Given a collection of ResourceKeys, identify the physical location and input metadata. 
@@ -507,49 +544,52 @@ def fetch_inputs( :raises ValueError: If neither an input key or an input tensor are provided :raises SmartSimError: If a tensor for a given key cannot be retrieved """ - fetch_results = [] - for request in batch.requests: - if request.raw_inputs: - fetch_results.append( - FetchInputResult(request.raw_inputs, request.input_meta) - ) - continue - - if not feature_stores: - raise ValueError("No input and no feature store provided") - - if request.has_input_keys: - data: t.List[bytes] = [] + if not batch.raw_inputs and not batch.input_keys: + raise ValueError("No input source") - for fs_key in request.input_keys: + if not feature_stores: + raise ValueError("No feature stores provided") + + data_list: t.List[t.List[bytes]] = [] + meta_list: t.List[t.List[t.Optional[TensorMeta]]] = [] + # meta_list will be t.List[t.List[TensorMeta]] once input_key metadata + # is available to be retrieved from the feature store + + if batch.raw_inputs: + for raw_inputs, input_meta in zip(batch.raw_inputs, batch.input_meta): + data_list.append(raw_inputs) + meta_list.append(input_meta) # type: ignore + + if batch.input_keys: + for batch_keys in batch.input_keys: + batch_data: t.List[bytes] = [] + for fs_key in batch_keys: try: feature_store = feature_stores[fs_key.descriptor] tensor_bytes = t.cast(bytes, feature_store[fs_key.key]) - data.append(tensor_bytes) + batch_data.append(tensor_bytes) except KeyError as ex: logger.exception(ex) raise SmartSimError( f"Tensor could not be retrieved with key {fs_key.key}" ) from ex - fetch_results.append( - FetchInputResult(data, meta=None) - ) # fixme: need to get both tensor and descriptor - continue + data_list.append(batch_data) + meta_list.append([None] * len(batch_data)) + # fixme: need to get both tensor and descriptor + # this will eventually append meta info retrieved from the feature store - raise ValueError("No input source") - - return fetch_results + return FetchInputResult(result=data_list, meta=meta_list) 
@staticmethod def place_output( - request: InferenceRequest, + output_keys: t.List[TensorKey], transform_result: TransformOutputResult, feature_stores: t.Dict[str, FeatureStore], ) -> t.Collection[t.Optional[TensorKey]]: """Given a collection of data, make it available as a shared resource in the feature store. - :param request: The request that triggered the pipeline + :param output_keys: The output_keys that will be placed in the feature store :param transform_result: Transformed version of the inference result :param feature_stores: Available feature stores used for persistence :returns: A collection of keys that were placed in the feature store @@ -563,7 +603,7 @@ def place_output( # accurately placed, datum might need to include this. # Consider parallelizing all PUT feature_store operations - for fs_key, v in zip(request.output_keys, transform_result.outputs): + for fs_key, v in zip(output_keys, transform_result.outputs): feature_store = feature_stores[fs_key.descriptor] feature_store[fs_key.key] = v keys.append(fs_key) @@ -596,7 +636,7 @@ def load_model( @abstractmethod def transform_input( batch: RequestBatch, - fetch_results: list[FetchInputResult], + fetch_results: FetchInputResult, mem_pool: MemoryPool, ) -> TransformInputResult: """Given a collection of data, perform a transformation on the data and put diff --git a/tests/dragon_wlm/test_core_machine_learning_worker.py b/tests/dragon_wlm/test_core_machine_learning_worker.py index f9295d9e86..febdb00c8d 100644 --- a/tests/dragon_wlm/test_core_machine_learning_worker.py +++ b/tests/dragon_wlm/test_core_machine_learning_worker.py @@ -42,6 +42,7 @@ TransformOutputResult, ) +from .channel import FileSystemCommChannel from .feature_store import FileSystemFeatureStore, MemoryFeatureStore # The tests in this file belong to the dragon group @@ -100,7 +101,7 @@ def test_fetch_model_disk(persist_torch_model: pathlib.Path, test_dir: str) -> N model_key = ModelKey(key=key, descriptor=fsd) request = 
InferenceRequest(model_key=model_key) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_model(batch, {fsd: feature_store}) assert fetch_result.model_bytes @@ -118,7 +119,7 @@ def test_fetch_model_disk_missing() -> None: model_key = ModelKey(key=key, descriptor=fsd) request = InferenceRequest(model_key=model_key) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) with pytest.raises(sse.SmartSimError) as ex: worker.fetch_model(batch, {fsd: feature_store}) @@ -143,7 +144,7 @@ def test_fetch_model_feature_store(persist_torch_model: pathlib.Path) -> None: model_key = ModelKey(key=key, descriptor=feature_store.descriptor) request = InferenceRequest(model_key=model_key) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_model(batch, {fsd: feature_store}) assert fetch_result.model_bytes @@ -161,7 +162,7 @@ def test_fetch_model_feature_store_missing() -> None: model_key = ModelKey(key=key, descriptor=feature_store.descriptor) request = InferenceRequest(model_key=model_key) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) # todo: consider that raising this exception shows impl. replace... 
with pytest.raises(sse.SmartSimError) as ex: @@ -184,7 +185,7 @@ def test_fetch_model_memory(persist_torch_model: pathlib.Path) -> None: model_key = ModelKey(key=key, descriptor=feature_store.descriptor) request = InferenceRequest(model_key=model_key) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_model(batch, {fsd: feature_store}) assert fetch_result.model_bytes @@ -202,14 +203,14 @@ def test_fetch_input_disk(persist_torch_tensor: pathlib.Path) -> None: request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)]) model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) worker = MachineLearningWorkerCore feature_store[tensor_name] = persist_torch_tensor.read_bytes() fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) - assert fetch_result[0].inputs is not None + assert fetch_result.inputs[0] is not None def test_fetch_input_disk_missing() -> None: @@ -224,7 +225,7 @@ def test_fetch_input_disk_missing() -> None: request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) with pytest.raises(sse.SmartSimError) as ex: worker.fetch_inputs(batch, {fsd: feature_store}) @@ -249,12 +250,12 @@ def test_fetch_input_feature_store(persist_torch_tensor: pathlib.Path) -> None: feature_store[tensor_name] = persist_torch_tensor.read_bytes() model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) - assert fetch_result[0].inputs + assert fetch_result.inputs[0] assert ( - list(fetch_result[0].inputs)[0][:10] 
== persist_torch_tensor.read_bytes()[:10] + list(fetch_result.inputs[0])[0][:10] == persist_torch_tensor.read_bytes()[:10] ) @@ -287,12 +288,11 @@ def test_fetch_multi_input_feature_store(persist_torch_tensor: pathlib.Path) -> ) model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) - raw_bytes = list(fetch_result[0].inputs) - assert raw_bytes + raw_bytes = list(fetch_result.inputs[0]) assert raw_bytes[0][:10] == persist_torch_tensor.read_bytes()[:10] assert raw_bytes[1][:10] == body2[:10] assert raw_bytes[2][:10] == body3[:10] @@ -309,7 +309,7 @@ def test_fetch_input_feature_store_missing() -> None: request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) with pytest.raises(sse.SmartSimError) as ex: worker.fetch_inputs(batch, {fsd: feature_store}) @@ -331,13 +331,43 @@ def test_fetch_input_memory(persist_torch_tensor: pathlib.Path) -> None: request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) model_key = ModelKey(key="test-model", descriptor=fsd) - batch = RequestBatch([request], None, model_key) + batch = RequestBatch.from_requests([request], model_key) fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) - assert fetch_result[0].inputs is not None + assert fetch_result.inputs[0] is not None -def test_place_outputs() -> None: +def test_fetch_inputs_no_input_source(): + """Verify that the ML worker fails when there are no inputs provided""" + worker = MachineLearningWorkerCore + request1 = InferenceRequest() + request2 = InferenceRequest() + + batch = RequestBatch.from_requests( + [request1, request2], ModelKey("test-model", "desc") + ) + + with pytest.raises(ValueError) as exc: 
+ fetch_result = worker.fetch_inputs(batch, {}) + assert str(exc.value) == "No input source" + + +def test_fetch_inputs_no_feature_store(): + """Verify that the ML worker fails when there is no feature store""" + worker = MachineLearningWorkerCore + request1 = InferenceRequest(raw_inputs=[b"abcdef"]) + request2 = InferenceRequest(raw_inputs=[b"ghijkl"]) + + batch = RequestBatch.from_requests( + [request1, request2], ModelKey("test-model", "desc") + ) + + with pytest.raises(ValueError) as exc: + fetch_result = worker.fetch_inputs(batch, {}) + assert str(exc.value) == "No feature stores provided" + + +def test_place_outputs(test_dir: str) -> None: """Verify outputs are shared using the feature store""" worker = MachineLearningWorkerCore @@ -351,18 +381,42 @@ def test_place_outputs() -> None: TensorKey(key=key_name + "2", descriptor=fsd), TensorKey(key=key_name + "3", descriptor=fsd), ] + + keys2 = [ + TensorKey(key=key_name + "4", descriptor=fsd), + TensorKey(key=key_name + "5", descriptor=fsd), + TensorKey(key=key_name + "6", descriptor=fsd), + ] data = [b"abcdef", b"ghijkl", b"mnopqr"] + data2 = [b"stuvwx", b"yzabcd", b"efghij"] - for fsk, v in zip(keys, data): - feature_store[fsk.key] = v + callback1 = FileSystemCommChannel(pathlib.Path(test_dir) / "callback1").descriptor + callback2 = FileSystemCommChannel(pathlib.Path(test_dir) / "callback2").descriptor - request = InferenceRequest(output_keys=keys) + model_id = ModelKey(key="test-model", descriptor=fsd) + request = InferenceRequest(callback_desc=callback1, output_keys=keys) + request2 = InferenceRequest(callback_desc=callback2, output_keys=keys2) transform_result = TransformOutputResult(data, [1], "c", "float32") + transform_result2 = TransformOutputResult(data2, [1], "c", "float32") - worker.place_output(request, transform_result, {fsd: feature_store}) + request_batch = RequestBatch.from_requests([request, request2], model_id) + + worker.place_output( + request_batch.output_key_refs[callback1], + 
transform_result, + {fsd: feature_store}, + ) + + worker.place_output( + request_batch.output_key_refs[callback2], + transform_result2, + {fsd: feature_store}, + ) - for i in range(3): - assert feature_store[keys[i].key] == data[i] + all_keys = keys + keys2 + all_data = data + data2 + for i in range(6): + assert feature_store[all_keys[i].key] == all_data[i] @pytest.mark.parametrize( diff --git a/tests/dragon_wlm/test_device_manager.py b/tests/dragon_wlm/test_device_manager.py index d270e921cb..8a82c3d32b 100644 --- a/tests/dragon_wlm/test_device_manager.py +++ b/tests/dragon_wlm/test_device_manager.py @@ -123,7 +123,7 @@ def test_device_manager_model_in_request(): request = InferenceRequest( model_key=model_key, - callback=None, + callback_desc=None, raw_inputs=None, input_keys=[tensor_key], input_meta=None, @@ -132,10 +132,13 @@ def test_device_manager_model_in_request(): batch_size=0, ) - request_batch = RequestBatch( + request_batch = RequestBatch.from_requests( [request], - TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), - model_id=model_key, + model_key, + ) + + request_batch.inputs = TransformInputResult( + b"transformed", [slice(0, 1)], [[1, 2]], ["float32"] ) with device_manager.get_device( @@ -161,7 +164,7 @@ def test_device_manager_model_key(): request = InferenceRequest( model_key=model_key, - callback=None, + callback_desc=None, raw_inputs=None, input_keys=[tensor_key], input_meta=None, @@ -170,10 +173,13 @@ def test_device_manager_model_key(): batch_size=0, ) - request_batch = RequestBatch( + request_batch = RequestBatch.from_requests( [request], - TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), - model_id=model_key, + model_key, + ) + + request_batch.inputs = TransformInputResult( + b"transformed", [slice(0, 1)], [[1, 2]], ["float32"] ) with device_manager.get_device( diff --git a/tests/dragon_wlm/test_error_handling.py b/tests/dragon_wlm/test_error_handling.py index aacd47b556..a0bfbd7e9f 100644 
--- a/tests/dragon_wlm/test_error_handling.py +++ b/tests/dragon_wlm/test_error_handling.py @@ -24,11 +24,14 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pathlib import typing as t from unittest.mock import MagicMock import pytest +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel + dragon = pytest.importorskip("dragon") import multiprocessing as mp @@ -123,9 +126,13 @@ def setup_worker_manager_model_bytes( tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + callback_descriptor = FileSystemCommChannel( + pathlib.Path(test_dir) / "callback1" + ).descriptor + inf_request = InferenceRequest( model_key=None, - callback=None, + callback_desc=callback_descriptor, raw_inputs=None, input_keys=[tensor_key], input_meta=None, @@ -136,10 +143,13 @@ def setup_worker_manager_model_bytes( model_id = ModelKey(key="key", descriptor=app_feature_store.descriptor) - request_batch = RequestBatch( + request_batch = RequestBatch.from_requests( [inf_request], - TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), - model_id=model_id, + model_id, + ) + + request_batch.inputs = TransformInputResult( + b"transformed", [slice(0, 1)], [[1, 2]], ["float32"] ) dispatcher_task_queue.put(request_batch) @@ -182,9 +192,13 @@ def setup_worker_manager_model_key( output_key = TensorKey(key="key", descriptor=app_feature_store.descriptor) model_id = ModelKey(key="model key", descriptor=app_feature_store.descriptor) + callback_descriptor = FileSystemCommChannel( + pathlib.Path(test_dir) / "callback2" + ).descriptor + request = InferenceRequest( model_key=model_id, - callback=None, + callback_desc=callback_descriptor, raw_inputs=None, input_keys=[tensor_key], input_meta=None, @@ -192,10 +206,13 @@ def 
setup_worker_manager_model_key( raw_model=b"model", batch_size=0, ) - request_batch = RequestBatch( + request_batch = RequestBatch.from_requests( [request], - TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), - model_id=model_id, + model_id, + ) + + request_batch.inputs = TransformInputResult( + b"transformed", [slice(0, 1)], [[1, 2]], ["float32"] ) dispatcher_task_queue.put(request_batch) diff --git a/tests/dragon_wlm/test_request_dispatcher.py b/tests/dragon_wlm/test_request_dispatcher.py index 8dc0f67a31..1b1fb1b013 100644 --- a/tests/dragon_wlm/test_request_dispatcher.py +++ b/tests/dragon_wlm/test_request_dispatcher.py @@ -26,6 +26,7 @@ import gc import os +import pathlib import time import typing as t from queue import Empty @@ -49,7 +50,6 @@ # isort: on - from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel from smartsim._core.mli.comm.channel.dragon_util import create_local @@ -69,9 +69,13 @@ from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( DragonFeatureStore, ) +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey, TensorKey from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.infrastructure.worker.worker import InferenceRequest, TensorMeta +from smartsim._core.mli.message_handler import MessageHandler from smartsim.log import get_logger +from .utils.channel import FileSystemCommChannel from .utils.msg_pump import mock_messages logger = get_logger(__name__) @@ -161,7 +165,6 @@ def test_request_dispatcher( processes.append(process) process.start() assert process.returncode is None, "The message pump failed to start" - # give dragon some time to populate the message queues for i in range(15): try: @@ -177,7 +180,7 @@ def test_request_dispatcher( raise exc assert batch is not None - assert batch.has_valid_requests + assert 
batch.has_callbacks model_key = batch.model_id.key @@ -200,7 +203,7 @@ def test_request_dispatcher( ) ) - assert len(batch.requests) == 2 + assert len(batch.callback_descriptors) == 2 assert batch.model_id.key == model_key assert model_key in request_dispatcher._queues assert model_key in request_dispatcher._active_queues @@ -235,3 +238,88 @@ def test_request_dispatcher( # Try to remove the dispatcher and free the memory del request_dispatcher gc.collect() + + +def test_request_batch(test_dir: str) -> None: + """Test the RequestBatch.from_requests instantiates properly""" + tensor_key = TensorKey(key="key", descriptor="desc1") + tensor_key2 = TensorKey(key="key2", descriptor="desc1") + output_key = TensorKey(key="key", descriptor="desc2") + output_key2 = TensorKey(key="key2", descriptor="desc2") + model_id1 = ModelKey(key="model key", descriptor="model desc") + model_id2 = ModelKey(key="model key2", descriptor="model desc") + tensor_desc = MessageHandler.build_tensor_descriptor("c", "float32", [1, 2]) + req_batch_model_id = ModelKey(key="req key", descriptor="desc") + + callback1 = FileSystemCommChannel(pathlib.Path(test_dir) / "callback1").descriptor + callback2 = FileSystemCommChannel(pathlib.Path(test_dir) / "callback2").descriptor + callback3 = FileSystemCommChannel(pathlib.Path(test_dir) / "callback3").descriptor + + request1 = InferenceRequest( + model_key=model_id1, + callback_desc=callback1, + raw_inputs=[b"input data"], + input_keys=[tensor_key], + input_meta=[tensor_desc], + output_keys=[output_key], + raw_model=b"model", + batch_size=0, + ) + + request2 = InferenceRequest( + model_key=model_id2, + callback_desc=callback2, + raw_inputs=None, + input_keys=None, + input_meta=None, + output_keys=[output_key, output_key2], + raw_model=b"model", + batch_size=0, + ) + + request3 = InferenceRequest( + model_key=model_id2, + callback_desc=callback3, + raw_inputs=None, + input_keys=[tensor_key, tensor_key2], + input_meta=[tensor_desc], + output_keys=None, + 
raw_model=b"model", + batch_size=0, + ) + + request_batch = RequestBatch.from_requests( + [request1, request2, request3], req_batch_model_id + ) + + assert len(request_batch.callback_descriptors) == 3 + for callback in request_batch.callback_descriptors: + assert isinstance(callback, str) + assert len(request_batch.output_key_refs.keys()) == 2 + assert request_batch.has_callbacks + assert request_batch.model_id == req_batch_model_id + assert request_batch.inputs == None + assert request_batch.raw_model == b"model" + assert request_batch.raw_inputs == [[b"input data"]] + assert request_batch.input_meta == [ + [TensorMeta([1, 2], "c", "float32")], + [TensorMeta([1, 2], "c", "float32")], + ] + assert request_batch.input_keys == [ + [tensor_key], + [tensor_key, tensor_key2], + ] + assert request_batch.output_key_refs == { + callback1: [output_key], + callback2: [output_key, output_key2], + } + + +def test_request_batch_has_no_callbacks(): + """Verify that a request batch with no callbacks is correctly identified""" + request1 = InferenceRequest() + request2 = InferenceRequest() + + batch = RequestBatch.from_requests([request1, request2], ModelKey("model", "desc")) + + assert not batch.has_callbacks diff --git a/tests/dragon_wlm/test_torch_worker.py b/tests/dragon_wlm/test_torch_worker.py index 2a9e7d01bd..5be0177c5a 100644 --- a/tests/dragon_wlm/test_torch_worker.py +++ b/tests/dragon_wlm/test_torch_worker.py @@ -110,7 +110,7 @@ def get_request() -> InferenceRequest: return InferenceRequest( model_key=ModelKey(key="model", descriptor="xyz"), - callback=None, + callback_desc=None, raw_inputs=tensor_numpy, input_keys=None, input_meta=serialized_tensors_descriptors, @@ -121,10 +121,13 @@ def get_request() -> InferenceRequest: def get_request_batch_from_request( - request: InferenceRequest, inputs: t.Optional[TransformInputResult] = None + request: InferenceRequest, ) -> RequestBatch: - return RequestBatch([request], inputs, request.model_key) + return 
RequestBatch.from_requests( + [request], + request.model_key, + ) sample_request: InferenceRequest = get_request() @@ -146,13 +149,13 @@ def test_load_model(mlutils) -> None: def test_transform_input(mlutils) -> None: fetch_input_result = FetchInputResult( - sample_request.raw_inputs, sample_request.input_meta + sample_request_batch.raw_inputs, sample_request_batch.input_meta ) mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) transform_input_result = worker.transform_input( - sample_request_batch, [fetch_input_result], mem_pool + sample_request_batch, fetch_input_result, mem_pool ) batch = get_batch().numpy() @@ -184,15 +187,15 @@ def test_execute(mlutils) -> None: Net().to(torch_device[mlutils.get_test_device().lower()]) ) fetch_input_result = FetchInputResult( - sample_request.raw_inputs, sample_request.input_meta + sample_request_batch.raw_inputs, sample_request_batch.input_meta ) - request_batch = get_request_batch_from_request(sample_request, fetch_input_result) + request_batch = get_request_batch_from_request(sample_request) mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) transform_result = worker.transform_input( - request_batch, [fetch_input_result], mem_pool + request_batch, fetch_input_result, mem_pool ) execute_result = worker.execute(