diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 6c1361b46c..e57994633d 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -32,108 +32,61 @@ on:
release:
types: [published]
-
-env:
- HOMEBREW_NO_ANALYTICS: "ON" # Make Homebrew installation a little quicker
- HOMEBREW_NO_AUTO_UPDATE: "ON"
- HOMEBREW_NO_BOTTLE_SOURCE_FALLBACK: "ON"
- HOMEBREW_NO_GITHUB_API: "ON"
- HOMEBREW_NO_INSTALL_CLEANUP: "ON"
- CIBW_SKIP: "pp* *i686*" # skip building for PyPy
- CIBW_ARCHS_MACOS: x86_64
- CIBW_ARCHS_LINUX: x86_64 # ppc64le # uncomment to enable powerPC build
- CIBW_ENVIRONMENT_MACOS: PATH="$(brew --prefix)/opt/make/libexec/gnubin:$PATH"
- MACOSX_DEPLOYMENT_TARGET: "10.09"
-
-
jobs:
- build_wheels:
- name: Build wheels on ${{ matrix.os }}
- runs-on: ${{ matrix.os }}
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-22.04, macos-12]
-
+ build_dists:
+ name: Build Distributions
+ runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
+ with:
+ python-version: '3.9'
- - name: Install cibuildwheel
- run: python -m pip install cibuildwheel>=2.12.3
+ - name: Install build
+ run: python -m pip install 'build>=1.2.2,<2'
- name: Install build-essentials
- if: contains(matrix.os, 'ubuntu')
run: |
sudo add-apt-repository ppa:ubuntu-toolchain-r/test
sudo apt-get update
- sudo apt-get install -y build-essential
- sudo apt-get install -y wget
+ sudo apt-get install -y build-essential wget
- - name: Install GNU make for MacOS
- if: contains(matrix.os, 'macos')
- run: brew install make || true
+ - name: Build Distributions
+ run: python -m build .
- - name: list target wheels
- run: |
- python -m cibuildwheel . --print-build-identifiers
-
- - name: Build wheels
- run: python -m cibuildwheel --output-dir wheelhouse
- env:
- CIBW_ENVIRONMENT_MACOS: PATH="$(brew --prefix)/opt/make/libexec/gnubin:$PATH"
- MACOSX_DEPLOYMENT_TARGET: "10.09"
-
- - uses: actions/upload-artifact@v2
+ - uses: actions/upload-artifact@v3
with:
- path: ./wheelhouse/*.whl
-
-
- build_sdist:
- name: Build source distribution
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
-
- - uses: actions/setup-python@v5
- name: Install Python
- with:
- python-version: '3.9'
-
- - name: Build sdist
- run: |
- python -m pip install cmake>=3.13
- python setup.py sdist
-
- - uses: actions/upload-artifact@v2
- with:
- path: dist/*.tar.gz
+ name: distributables
+ path: ./dist/*
upload_pypi:
- needs: [build_wheels, build_sdist]
- runs-on: ubuntu-latest
+ needs: [build_dists]
+ runs-on: ubuntu-22.04
steps:
- - uses: actions/download-artifact@v2
+ - uses: actions/download-artifact@v3
with:
- name: artifact
+ name: distributables
path: dist
- uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI }}
- #repository_url: https://test.pypi.org/legacy/
-
+ # repository-url: https://test.pypi.org/legacy/
createPullRequest:
- runs-on: ubuntu-latest
+ needs: [upload_pypi]
+ runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Create pull request
run: |
- gh pr create -B develop -H master --title 'Merge master into develop' --body 'This PR brings develop up to date with master for release.'
+ gh pr create -B develop \
+ -H master \
+ --title 'Merge master into develop' \
+ --body 'This PR brings develop up to date with master for release.'
env:
GH_TOKEN: ${{ github.token }}
diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
index 04d7e25ae1..ca6ec4cc40 100644
--- a/.github/workflows/run_tests.yml
+++ b/.github/workflows/run_tests.yml
@@ -49,12 +49,12 @@ env:
jobs:
run_tests:
- name: Run tests ${{ matrix.subset }} with ${{ matrix.os }}, Python ${{ matrix.py_v}}, RedisAI ${{ matrix.rai }}
+ name: Run tests ${{ matrix.subset }} with ${{ matrix.os }}, Python ${{ matrix.py_v}}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
- subset: [backends, slow_tests, group_a, group_b]
+ subset: [backends, slow_tests, group_a, group_b, dragon]
os: [macos-12, macos-14, ubuntu-22.04] # Operating systems
compiler: [8] # GNU compiler version
rai: [1.2.7] # Redis AI versions
@@ -63,9 +63,6 @@ jobs:
- os: macos-14
py_v: "3.9"
- env:
- SMARTSIM_REDISAI: ${{ matrix.rai }}
-
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
@@ -117,15 +114,26 @@ jobs:
- name: Install SmartSim (with ML backends)
run: |
python -m pip install git+https://github.com/CrayLabs/SmartRedis.git@develop#egg=smartredis
- python -m pip install .[dev,mypy,ml]
+ python -m pip install .[dev,mypy]
- - name: Install ML Runtimes with Smart (with pt, tf, and onnx support)
- if: contains( matrix.os, 'ubuntu' ) || contains( matrix.os, 'macos-12')
- run: smart build --device cpu --onnx -v
+ - name: Install ML Runtimes
+ if: matrix.subset != 'dragon'
+ run: smart build --device cpu -v
- - name: Install ML Runtimes with Smart (no ONNX,TF on Apple Silicon)
- if: contains( matrix.os, 'macos-14' )
- run: smart build --device cpu --no_tf -v
+
+ - name: Install ML Runtimes (with dragon)
+ if: matrix.subset == 'dragon'
+ env:
+ SMARTSIM_DRAGON_TOKEN: ${{ secrets.DRAGON_TOKEN }}
+ run: |
+ if [ -n "${SMARTSIM_DRAGON_TOKEN}" ]; then
+ smart build --device cpu -v --dragon-repo dragonhpc/dragon-nightly --dragon-version 0.10
+ else
+ smart build --device cpu -v --dragon
+ fi
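+        # Add the LD_LIBRARY_PATH entries from the dragon .env config file to the environment for later steps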
+ SP=$(python -c 'import site; print(site.getsitepackages()[0])')/smartsim/_core/config/dragon/.env
+ LLP=$(cat $SP | grep LD_LIBRARY_PATH | awk '{split($0, array, "="); print array[2]}')
+ echo "LD_LIBRARY_PATH=$LLP:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Run mypy
run: |
@@ -151,9 +159,16 @@ jobs:
echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV
py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ ./tests/backends
+ # Run pytest (dragon subtests)
+ - name: Run Dragon Pytest
+ if: (matrix.subset == 'dragon' && matrix.os == 'ubuntu-22.04')
+ run: |
+ echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV
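+        # Launch pytest through the dragon runtime so the dragon-marked tests execute inside a Dragon process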
+ dragon -s py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ -m ${{ matrix.subset }} ./tests
+
# Run pytest (test subsets)
- name: Run Pytest
- if: "!contains(matrix.subset, 'backends')" # if not running backend tests
+ if: (matrix.subset != 'backends' && matrix.subset != 'dragon') # if not running backend tests or dragon tests
run: |
echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV
py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ -m ${{ matrix.subset }} ./tests
diff --git a/.gitignore b/.gitignore
index 77b91d5865..97132aff7e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,7 @@ tests/test_output
# Dependencies
smartsim/_core/.third-party
smartsim/_core/.dragon
+smartsim/_core/build
# Docs
_build
diff --git a/.wci.yml b/.wci.yml
index 6194f19391..cf53334c3a 100644
--- a/.wci.yml
+++ b/.wci.yml
@@ -22,8 +22,8 @@
language: Python
release:
- version: 0.7.0
- date: 2024-05-14
+ version: 0.8.0
+ date: 2024-09-25
documentation:
general: https://www.craylabs.org/docs/overview.html
diff --git a/Makefile b/Makefile
index bddbda722b..b4ceef2194 100644
--- a/Makefile
+++ b/Makefile
@@ -150,11 +150,11 @@ tutorials-dev:
@docker compose build tutorials-dev
@docker run -p 8888:8888 smartsim-tutorials:dev-latest
-# help: tutorials-prod - Build and start a docker container to run the tutorials (v0.7.0)
+# help: tutorials-prod - Build and start a docker container to run the tutorials (v0.8.0)
.PHONY: tutorials-prod
tutorials-prod:
@docker compose build tutorials-prod
- @docker run -p 8888:8888 smartsim-tutorials:v0.7.0
+ @docker run -p 8888:8888 smartsim-tutorials:v0.8.0
# help:
@@ -164,22 +164,22 @@ tutorials-prod:
# help: test - Run all tests
.PHONY: test
test:
- @python -m pytest --ignore=tests/full_wlm/
+ @python -m pytest --ignore=tests/full_wlm/ --ignore=tests/dragon_wlm
# help: test-verbose - Run all tests verbosely
.PHONY: test-verbose
test-verbose:
- @python -m pytest -vv --ignore=tests/full_wlm/
+ @python -m pytest -vv --ignore=tests/full_wlm/ --ignore=tests/dragon_wlm
# help: test-debug - Run all tests with debug output
.PHONY: test-debug
test-debug:
- @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/
+ @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ --ignore=tests/dragon_wlm
# help: test-cov - Run all tests with coverage
.PHONY: test-cov
test-cov:
- @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/
+ @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ --ignore=tests/dragon_wlm
# help: test-full - Run all WLM tests with Python coverage (full test suite)
@@ -192,3 +192,8 @@ test-full:
.PHONY: test-wlm
test-wlm:
@python -m pytest -vv tests/full_wlm/ tests/on_wlm
+
+# help: test-dragon - Run dragon-specific tests
+.PHONY: test-dragon
+test-dragon:
+ @dragon pytest tests/dragon_wlm
diff --git a/README.md b/README.md
index c0986042eb..610d6608c0 100644
--- a/README.md
+++ b/README.md
@@ -643,11 +643,11 @@ from C, C++, Fortran and Python with the SmartRedis Clients:
1.2.7 |
PyTorch |
- 2.0.1 |
+ 2.1.0 |
TensorFlow\Keras |
- 2.13.1 |
+ 2.15.0 |
ONNX |
diff --git a/doc/_static/version_names.json b/doc/_static/version_names.json
index bc095f84af..8b127e5867 100644
--- a/doc/_static/version_names.json
+++ b/doc/_static/version_names.json
@@ -1,7 +1,8 @@
{
"version_names":[
"develop (unstable)",
- "0.7.0 (stable)",
+ "0.8.0 (stable)",
+ "0.7.0",
"0.6.2",
"0.6.1",
"0.6.0",
@@ -15,6 +16,7 @@
"version_urls": [
"https://www.craylabs.org/develop/overview.html",
"https://www.craylabs.org/docs/overview.html",
+ "https://www.craylabs.org/docs/versions/0.7.0/overview.html",
"https://www.craylabs.org/docs/versions/0.6.2/overview.html",
"https://www.craylabs.org/docs/versions/0.6.1/overview.html",
"https://www.craylabs.org/docs/versions/0.6.0/overview.html",
diff --git a/doc/changelog.md b/doc/changelog.md
index 740197ce5d..752957bfdc 100644
--- a/doc/changelog.md
+++ b/doc/changelog.md
@@ -9,12 +9,80 @@ Jump to:
## SmartSim
-### Development branch
+### MLI branch
-To be released at some future point in time
+Description
+
+- Merge core refactor into MLI feature branch
+- Implement asynchronous notifications for shared data
+- Quick bug fix in _validate
+- Add helper methods to MLI classes
+- Update error handling for consistency
+- Parameterize installation of dragon package with `smart build`
+- Update docstrings
+- Filenames conform to snake case
+- Update SmartSim environment variables using new naming convention
+- Refactor `exception_handler`
+- Add RequestDispatcher and the possibility of batching inference requests
+- Enable hostname selection for dragon tasks
+- Remove pydantic dependency from MLI code
+- Update MLI environment variables using new naming convention
+- Reduce a copy by using torch.from_numpy instead of torch.tensor
+- Enable dynamic feature store selection
+- Fix dragon package installation bug
+- Adjust schemas for better performance
+- Add TorchWorker first implementation and mock inference app example
+- Add error handling in Worker Manager pipeline
+- Add EnvironmentConfigLoader for ML Worker Manager
+- Add Model schema with model metadata included
+- Remove device from schemas, MessageHandler, and tests
+- Add ML worker manager, sample worker, and feature store
+- Add schemas and MessageHandler class for de/serialization of
+ inference requests and response messages
+
+
+### Develop
+
+To be released at some point in the future
+
+Description
+
+- Implement workaround for Tensorflow that allows RedisAI to build with GCC-14
+- Add instructions for installing SmartSim on PML's Scylla
+
+Detailed Notes
+
+- In libtensorflow, the input argument to TF_SessionRun seems to be mistyped to
+ TF_Output instead of TF_Input. These two types differ only in name. GCC-14
+ catches this and throws an error, even though earlier versions allow this. To
+ solve this problem, patches are applied to the Tensorflow backend in RedisAI.
+ Future versions of Tensorflow may fix this problem, but for now this seems to be
+ the best workaround.
+ ([SmartSim-PR738](https://github.com/CrayLabs/SmartSim/pull/738))
+- PML's Scylla is still under development. The usual SmartSim
+ build instructions do not apply because the GPU dependencies
+ have yet to be installed at a system-wide level. Scylla has
+ its own entry in the documentation.
+ ([SmartSim-PR733](https://github.com/CrayLabs/SmartSim/pull/733))
+
+
+### 0.8.0
+
+Released on 27 September, 2024
Description
+- Add instructions for Frontier to set the MIOPEN cache
+- Refine Frontier documentation for proper use of miniforge3
+- Refactor to the RedisAI build to allow more flexibility in versions
+ and sources of ML backends
+- Add Dockerfiles with GPU support
+- Fine grain build support for GPUs
+- Update Torch to 2.1.0, Tensorflow to 2.15.0
+- Better error messages in build process
+- Allow specifying Model and Ensemble parameters with
+ number-like types (e.g. numpy types)
+- Pin watchdog to 4.x
- Update codecov to 4.5.0
- Remove build of Redis from setup.py
- Mitigate dependency installation issues
@@ -30,6 +98,46 @@ Description
Detailed Notes
+- On Frontier, the MIOPEN cache may need to be set prior to using
+  RedisAI in ``smart validate``. The instructions for Frontier
+ have been updated accordingly.
+ ([SmartSim-PR727](https://github.com/CrayLabs/SmartSim/pull/727))
+- On Frontier, the recommended way to activate conda environments is
+  to use ``source activate``. This also means that ``conda init``
+ is not needed. The instructions for Frontier have been updated to
+ reflect this.
+ ([SmartSim-PR719](https://github.com/CrayLabs/SmartSim/pull/719))
+- The RedisAIBuilder class was completely overhauled to allow users to
+ express a wider range of support for hardware/software stacks. This
+ will be extended to support ROCm, CUDA-11, and CUDA-12.
+ ([SmartSim-PR669](https://github.com/CrayLabs/SmartSim/pull/669))
+- Versions for each of these packages are no longer specified in an
+ internal class. Instead a default set of JSON files specifies the
+ sources and versions. Users can specify their own custom specifications
+ at smart build time.
+ ([SmartSim-PR669](https://github.com/CrayLabs/SmartSim/pull/669))
+- Because all build configuration has been moved to static files and all
+ backends are compiled during `smart build`, SmartSim can now be shipped as a
+ pure python wheel.
+ ([SmartSim-PR728](https://github.com/CrayLabs/SmartSim/pull/728))
+- Two new Dockerfiles are now provided (one each for CUDA 11.8 and CUDA 12.1) that
+ can be used to build a container to run the tutorials. No HPC support
+ should be expected at this time
+ ([SmartSim-PR669](https://github.com/CrayLabs/SmartSim/pull/669))
+- As a result of the previous change, SmartSim now requires C++17 and a
+  minimum CUDA version of 11.8 in order to build Torch 2.1.0.
+ ([SmartSim-PR669](https://github.com/CrayLabs/SmartSim/pull/669))
+- Error messages were not being interpolated correctly. This has been
+ addressed to provide more context when exposing error messages to users.
+ ([SmartSim-PR669](https://github.com/CrayLabs/SmartSim/pull/669))
+- The serializer would fail if a parameter for a Model or Ensemble
+ was specified as a numpy dtype. The constructors for these
+  classes now validate that the input is number-like and convert
+  it to a string
+ ([SmartSim-PR676](https://github.com/CrayLabs/SmartSim/pull/676))
+- Pin watchdog to 4.x because v5 introduces new types and requires
+ updates to the type-checking
+ ([SmartSim-PR690](https://github.com/CrayLabs/SmartSim/pull/690))
- Update codecov to 4.5.0 to mitigate GitHub action failure
([SmartSim-PR657](https://github.com/CrayLabs/SmartSim/pull/657))
- The builder module was included in setup.py to allow us to ship the
diff --git a/doc/conf.py b/doc/conf.py
index 932bce0132..8f3a9ca632 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -29,7 +29,7 @@
import smartsim
version = smartsim.__version__
except ImportError:
- version = "0.7.0"
+ version = "0.8.0"
# The full version, including alpha/beta/rc tags
release = version
diff --git a/doc/installation_instructions/basic.rst b/doc/installation_instructions/basic.rst
index 02c17e1fda..a5db285ca8 100644
--- a/doc/installation_instructions/basic.rst
+++ b/doc/installation_instructions/basic.rst
@@ -18,7 +18,7 @@ Prerequisites
Basic
=====
-The base prerequisites to install SmartSim and SmartRedis are:
+The base prerequisites to install SmartSim and SmartRedis with CPU-only support are:
- Python 3.9-3.11
- Pip
@@ -27,13 +27,11 @@ The base prerequisites to install SmartSim and SmartRedis are:
- C++ compiler
- GNU Make > 4.0
- git
- - `git-lfs`_
-
-.. _git-lfs: https://github.com/git-lfs/git-lfs?utm_source=gitlfs_site&utm_medium=installation_link&utm_campaign=gitlfs
.. note::
- GCC 5-9, 11, and 12 is recommended. There are known bugs with GCC 10.
+   GCC 9 and 11-13 are recommended (there are known issues compiling with GCC 10). For
+ CUDA 11.8, GCC 9 or 11 must be used.
.. warning::
@@ -43,66 +41,146 @@ The base prerequisites to install SmartSim and SmartRedis are:
`which gcc g++` do not point to Apple Clang.
-GPU Support
-===========
+ML Library Support
+==================
-The machine-learning backends have additional requirements in order to
-use GPUs for inference
+We currently support both Nvidia and AMD GPUs when using RedisAI for GPU inference. The support
+for these GPUs often depends on the version of the CUDA or ROCm stack that is available on your
+machine. In *most* cases, the versions are backwards compatible. If you encounter problems, please
+contact us and we can build the backend libraries for your desired version of CUDA and ROCm.
- - `CUDA Toolkit 11 (tested with 11.8) `_
- - `cuDNN 8 (tested with 8.9.1) `_
- - OS: Linux
- - GPU: Nvidia
+CPU backends are provided for Apple (both Intel and Apple Silicon) and Linux (x86_64).
-Be sure to reference the :ref:`installation notes ` for helpful
+Be sure to reference the table below to find which versions of the ML libraries are supported for
+your particular platform. Additionally, see :ref:`installation notes ` for helpful
information regarding various system types before installation.
-==================
-Supported Versions
-==================
+Linux
+-----
+.. tabs::
-.. list-table:: Supported System for Pre-built Wheels
- :widths: 50 50 50 50
- :header-rows: 1
- :align: center
+ .. group-tab:: CUDA 11
+
+ Additional requirements:
+
+ * GCC <= 11
+ * CUDA Toolkit 11.7 or 11.8
+ * cuDNN 8.9
+
+ .. list-table:: Nvidia CUDA 11
+ :widths: 50 50 50 50
+ :header-rows: 1
+ :align: center
+
+ * - Python Versions
+ - Torch
+ - Tensorflow
+ - ONNX Runtime
+ * - 3.9-3.11
+ - 2.3.1
+ - 2.14.1
+ - 1.17.3
+
+ .. group-tab:: CUDA 12
+
+ Additional requirements:
+
+ * CUDA Toolkit 12
+ * cuDNN 8.9
+
+ .. list-table:: Nvidia CUDA 12
+ :widths: 50 50 50 50
+ :header-rows: 1
+ :align: center
+
+ * - Python Versions
+ - Torch
+ - Tensorflow
+ - ONNX Runtime
+ * - 3.9-3.11
+ - 2.3.1
+ - 2.17
+ - 1.17.3
+
+ .. group-tab:: ROCm 6
+
+ .. list-table:: AMD ROCm 6.1
+ :widths: 50 50 50 50
+ :header-rows: 1
+ :align: center
+
+ * - Python Versions
+ - Torch
+ - Tensorflow
+ - ONNX Runtime
+ * - 3.9-3.11
+ - 2.4.1
+ - N/A
+ - N/A
+
+ .. group-tab:: CPU
+
+ .. list-table:: CPU-only
+ :widths: 50 50 50 50
+ :header-rows: 1
+ :align: center
+
+ * - Python Versions
+ - Torch
+ - Tensorflow
+ - ONNX Runtime
+ * - 3.9-3.11
+ - 2.4.0
+ - 2.15
+ - 1.17.3
+
+MacOSX
+------
- * - Platform
- - CPU
- - GPU
- - Python Versions
- * - MacOS
- - x86_64, aarch64
- - Not supported
- - 3.9 - 3.11
- * - Linux
- - x86_64
- - Nvidia
- - 3.9 - 3.11
+.. tabs::
+ .. group-tab:: Apple Silicon
-.. note::
+ .. list-table:: Apple Silicon ARM64 (no Metal support)
+ :widths: 50 50 50 50
+ :header-rows: 1
+ :align: center
- Users have succesfully run SmartSim on Windows using Windows Subsystem for Linux
- with Nvidia support. Generally, users should follow the Linux instructions here,
- however we make no guarantee or offer of support.
+ * - Python Versions
+ - Torch
+ - Tensorflow
+ - ONNX Runtime
+ * - 3.9-3.11
+ - 2.4.0
+ - 2.17
+ - 1.17.3
+ .. group-tab:: Intel Mac (x86)
-Native support for various machine learning libraries and their
-versions is dictated by our dependency on RedisAI_ 1.2.7.
+ .. list-table:: CPU-only
+ :widths: 50 50 50 50
+ :header-rows: 1
+ :align: center
-+------------------+----------+-------------+---------------+
-| RedisAI | PyTorch | Tensorflow | ONNX Runtime |
-+==================+==========+=============+===============+
-| 1.2.7 (default) | 2.0.1 | 2.13.1 | 1.16.3 |
-+------------------+----------+-------------+---------------+
+ * - Python Versions
+ - Torch
+ - Tensorflow
+ - ONNX Runtime
+ * - 3.9-3.11
+ - 2.2.0
+ - 2.15
+ - 1.17.3
-.. warning::
- On Apple Silicon, only the PyTorch backend is supported for now. Please contact us
- if you need support for other backends
+.. note::
+
+ Users have successfully run SmartSim on Windows using Windows Subsystem for Linux
+ with Nvidia support. Generally, users should follow the Linux instructions here,
+ however we make no guarantee or offer of support.
+
-TensorFlow_ 2.0 and Keras_ are supported through `graph freezing`_.
+TensorFlow_ and Keras_ are supported through `graph freezing`_.
ScikitLearn_ and Spark_ models are supported by SmartSim as well
through the use of the ONNX_ runtime (which is not built by
@@ -167,21 +245,8 @@ and install SmartSim from PyPI with the following command:
pip install smartsim
-If you would like SmartSim to also install python machine learning libraries
-that can be used outside SmartSim to build SmartSim-compatible models, you
-can request their installation through the ``[ml]`` optional dependencies,
-as follows:
-
-.. code-block:: bash
-
- # For bash
- pip install smartsim[ml]
- # For zsh
- pip install smartsim\[ml\]
-
-At this point, SmartSim is installed and can be used for more basic features.
-If you want to use the machine learning features of SmartSim, you will need
-to install the ML backends in the section below.
+At this point, SmartSim can be used to describe and launch experiments, but without
+the database/feature store functionality that allows for ML-enabled workflows.
Step 2: Build SmartSim
@@ -198,19 +263,19 @@ To see all the installation options:
smart --help
-CPU Install
------------
-
-To install the default ML backends for CPU, run
-
.. code-block:: bash
# run one of the following
- smart build --device cpu # install PT and TF for cpu
- smart build --device cpu --onnx # install all backends (PT, TF, ONNX) on cpu
+ smart build --device cpu # For unaccelerated AI/ML loads
+ smart build --device cuda118 # Nvidia Accelerator with CUDA 11.8
+ smart build --device cuda125 # Nvidia Accelerator with CUDA 12.5
+ smart build --device rocm57 # AMD Accelerator with ROCm 5.7.0
-By default, ``smart`` will install PyTorch and TensorFlow backends
-for use in SmartSim.
+By default, ``smart`` will install all backends available for the specified accelerator
+*and* the compatible versions of the Python packages associated with those backends. To
+disable support for a specific backend, ``smart build`` accepts the flags
+``--skip-torch``, ``--skip-tensorflow``, and ``--skip-onnx``, which can also be used in
+combination.
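+
+For example, to keep only the PyTorch backend in a CPU-only build, the skip flags can
+be combined (an illustrative invocation):
+
+.. code-block:: bash
+
+   smart build --device cpu --skip-tensorflow --skip-onnx
+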
.. note::
@@ -218,19 +283,6 @@ for use in SmartSim.
all of the previous installs for the ML backends and ``smart clobber`` will
remove all pre-built dependencies as well as the ML backends.
-
-GPU Install
------------
-
-With the proper environment setup (see :ref:`GPU support`) the only difference
-to building SmartSim with GPU support is to specify a different ``device``
-
-.. code-block:: bash
-
- # run one of the following
- smart build --device gpu # install PT and TF for gpu
- smart build --device gpu --onnx # install all backends (PT, TF, ONNX) on gpu
-
.. note::
GPU builds can be troublesome due to the way that RedisAI and the ML-package
@@ -251,9 +303,21 @@ For example, to install dragon alongside the RedisAI CPU backends, you can run
.. code-block:: bash
- # run one of the following
smart build --device cpu --dragon # install Dragon, PT and TF for cpu
- smart build --device cpu --onnx --dragon # install Dragon and all backends (PT, TF, ONNX) on cpu
+
+``smart build`` supports installing a specific version of dragon. It exposes the
+parameters ``--dragon-repo`` and ``--dragon-version``, which can be used alone or
+in combination to customize the Dragon installation. For example:
+
+.. code-block:: bash
+
+ # using the --dragon-repo and --dragon-version flags to customize the Dragon installation
+ smart build --device cpu --dragon-repo userfork/dragon # install Dragon from a specific repo
+ smart build --device cpu --dragon-version 0.10 # install a specific Dragon release
+
+ # combining both flags
+ smart build --device cpu --dragon-repo userfork/dragon --dragon-version 0.91
+
.. note::
Dragon is only supported on Linux systems. For further information, you
@@ -319,35 +383,11 @@ source remains at the site of the clone instead of in site-packages.
.. code-block:: bash
cd smartsim
- pip install -e .[dev,ml] # for bash users
- pip install -e .\[dev,ml\] # for zsh users
-
-Use the now installed ``smart`` cli to install the machine learning runtimes and dragon.
-
-.. tabs::
-
- .. tab:: Linux
-
- .. code-block:: bash
-
- # run one of the following
- smart build --device cpu --onnx --dragon # install with cpu-only support
- smart build --device gpu --onnx --dragon # install with both cpu and gpu support
-
-
- .. tab:: MacOS (Intel x64)
-
- .. code-block:: bash
-
- smart build --device cpu --onnx # install all backends (PT, TF, ONNX) on gpu
-
-
- .. tab:: MacOS (Apple Silicon)
-
- .. code-block:: bash
-
- smart build --device cpu --no_tf # Only install PyTorch (TF/ONNX unsupported)
+ pip install -e .[dev] # for bash users
+ pip install -e ".[dev]" # for zsh users
+
+Use the now-installed ``smart`` CLI to install the machine learning runtimes and
+dragon, referring to "Step 2: Build SmartSim" above.
Build the SmartRedis library
============================
diff --git a/doc/installation_instructions/platform.rst b/doc/installation_instructions/platform.rst
index 086fc2951c..c1eb51df1a 100644
--- a/doc/installation_instructions/platform.rst
+++ b/doc/installation_instructions/platform.rst
@@ -12,12 +12,16 @@ that SmartSim may be used on.
.. include:: platform/frontier.rst
+.. include:: platform/perlmutter.rst
+
.. include:: platform/cray.rst
.. include:: platform/ncar-cheyenne.rst
.. include:: platform/olcf-summit.rst
+.. include:: platform/pml-scylla.rst
+
.. _site_installation:
.. include:: site-install.rst
diff --git a/doc/installation_instructions/platform/frontier.rst b/doc/installation_instructions/platform/frontier.rst
index e238561559..9b05061fe1 100644
--- a/doc/installation_instructions/platform/frontier.rst
+++ b/doc/installation_instructions/platform/frontier.rst
@@ -1,23 +1,15 @@
OLCF Frontier
=============
-Summary
--------
-
-Frontier is an AMD CPU/AMD GPU system.
-
-As of 2023-07-06, users can use the following instructions, however we
-anticipate that all the SmartSim dependencies will be available system-wide via
-the modules system.
-
Known limitations
-----------------
We are continually working on getting all the features of SmartSim working on
Frontier, however we do have some known limitations:
-* For now, only Torch models are supported. We are working to find a recipe to
- install Tensorflow with ROCm support from scratch
+* For now, only Torch models are supported. If you need Tensorflow or ONNX
+  support, please contact us
+* All SmartSim experiments must be run from Lustre, *not* your home directory
* The colocated database will fail without specifying ``custom_pinning``. This
is because the default pinning assumes that processor 0 is available, but the
'low-noise' default on Frontier reserves the processor on each NUMA node.
@@ -30,8 +22,8 @@ Frontier, however we do have some known limitations:
Please raise an issue in the SmartSim Github or contact the developers if the above
issues are affecting your workflow or if you find any other problems.
-Build process
--------------
+One-time Setup
+--------------
To install the SmartRedis and SmartSim python packages on Frontier, please follow
these instructions, being sure to set the following variables
@@ -39,25 +31,20 @@ these instructions, being sure to set the following variables
.. code:: bash
export PROJECT_NAME=CHANGE_ME
- export VENV_NAME=CHANGE_ME
-Then continue with the install:
+**Step 1:** Create and activate a virtual environment for SmartSim:
.. code:: bash
- module load PrgEnv-gnu-amd git-lfs cmake cray-python
- module unload xalt amd-mixed
- module load rocm/4.5.2
- export CC=gcc
- export CXX=g++
+ module load PrgEnv-gnu miniforge3 rocm/6.1.3
export SCRATCH=/lustre/orion/$PROJECT_NAME/scratch/$USER/
- export VENV_HOME=$SCRATCH/$VENV_NAME/
+ conda create -n smartsim python=3.11
+ source activate smartsim
- python3 -m venv $VENV_HOME
- source $VENV_HOME/bin/activate
- pip install torch==1.11.0+rocm4.5.2 torchvision==0.12.0+rocm4.5.2 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/rocm4.5.2
+**Step 2:** Build the SmartRedis C++ and Fortran libraries:
+.. code:: bash
cd $SCRATCH
git clone https://github.com/CrayLabs/SmartRedis.git
@@ -65,57 +52,61 @@ Then continue with the install:
make lib-with-fortran
pip install .
- # Download SmartSim and site-specific files
+**Step 3:** Install SmartSim in the conda environment:
+
+.. code:: bash
+
cd $SCRATCH
- git clone https://github.com/CrayLabs/site-deployments.git
- git clone https://github.com/CrayLabs/SmartSim.git
- cd SmartSim
- pip install -e .[dev]
+ pip install git+https://github.com/CrayLabs/SmartSim.git
-Next to finish the compilation, we need to manually modify one of the auxiliary
-cmake files that comes packaged with Torch
+**Step 4:** Build Redis, RedisAI, the backends, and all the Python packages:
.. code:: bash
- export TORCH_CMAKE_DIR=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')
- # Manual step: modify all references to the 'rocm' directory to rocm-4.5.2
- vim $TORCH_CMAKE_DIR/Caffe2/Caffe2Targets.cmake
+ smart build --device=rocm-6
-Finally, build Redis (or keydb for a more performant solution), RedisAI, and the
-machine-learning backends using:
+**Step 5:** Check that SmartSim has been installed and built correctly:
.. code:: bash
- KEYDB_FLAG="" # set this to --keydb if desired
- smart build --device gpu --torch_dir $TORCH_CMAKE_DIR --no_tf -v $(KEYDB_FLAG)
+ # Optimizations for inference
+ export MIOPEN_USER_DB_PATH="/tmp/${USER}/my-miopen-cache"
+ export MIOPEN_CUSTOM_CACHE_DIR=$MIOPEN_USER_DB_PATH
+ rm -rf $MIOPEN_USER_DB_PATH
+ mkdir -p $MIOPEN_USER_DB_PATH
+
+ # Run the install validation utility
+ smart validate --device gpu
-Set up environment
-------------------
+The following output indicates a successful install:
+
+.. code:: bash
+
+ [SmartSim] INFO Verifying Tensor Transfer
+ [SmartSim] INFO Verifying Torch Backend
+ 16:26:35 login SmartSim[557020:MainThread] INFO Success!
+
+Post-installation
+-----------------
Before running SmartSim, the environment should match the one used to
-build, and some variables should be set to work around some ROCm PyTorch
-issues:
+build, and some variables should be set to optimize performance:
.. code:: bash
# Set these to the same values that were used for install
export PROJECT_NAME=CHANGE_ME
- export VENV_NAME=CHANGE_ME
.. code:: bash
- module load PrgEnv-gnu-amd git-lfs cmake cray-python
- module unload xalt amd-mixed
- module load rocm/4.5.2
+ module load PrgEnv-gnu miniforge3 rocm/6.1.3
+ source activate smartsim
- export SCRATCH=/lustre/orion/$PROJECT_NAME/scratch/$USER/
- export MIOPEN_USER_DB_PATH=/tmp/miopendb/
- export MIOPEN_SYSTEM_DB_PATH=$MIOPEN_USER_DB_PATH
- mkdir -p $MIOPEN_USER_DB_PATH
- export MIOPEN_DISABLE_CACHE=1
- export VENV_HOME=$SCRATCH/$VENV_NAME/
- source $VENV_HOME/bin/activate
- export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$VENV_HOME/lib/python3.9/site-packages/torch/lib
+ # Optimizations for inference
+ export MIOPEN_USER_DB_PATH="/tmp/${USER}/my-miopen-cache"
+ export MIOPEN_CUSTOM_CACHE_DIR=${MIOPEN_USER_DB_PATH}
+ rm -rf ${MIOPEN_USER_DB_PATH}
+ mkdir -p ${MIOPEN_USER_DB_PATH}
Binding DBs to Slingshot
------------------------
@@ -129,17 +120,3 @@ following way:
exp = Experiment("my_exp", launcher="slurm")
orc = exp.create_database(db_nodes=3, interface=["hsn0","hsn1","hsn2","hsn3"], single_cmd=True)
-
-Running tests
--------------
-
-The same environment set to run SmartSim must be set to run tests. The
-environment variables needed to run the test suite are the following:
-
-.. code:: bash
-
- export SMARTSIM_TEST_ACCOUNT=PROJECT_NAME # Change this to above
- export SMARTSIM_TEST_LAUNCHER=slurm
- export SMARTSIM_TEST_DEVICE=gpu
- export SMARTSIM_TEST_PORT=6789
- export SMARTSIM_TEST_INTERFACE="hsn0,hsn1,hsn2,hsn3"
diff --git a/doc/installation_instructions/platform/olcf-summit.rst b/doc/installation_instructions/platform/olcf-summit.rst
index 7e2ba513da..07be24eec7 100644
--- a/doc/installation_instructions/platform/olcf-summit.rst
+++ b/doc/installation_instructions/platform/olcf-summit.rst
@@ -19,7 +19,7 @@ into problems.
.. code-block:: bash
# setup Python and build environment
- export ENV_NAME=smartsim-0.7.0
+ export ENV_NAME=smartsim-0.8.0
git clone https://github.com/CrayLabs/SmartRedis.git smartredis
git clone https://github.com/CrayLabs/SmartSim.git smartsim
conda config --prepend channels https://ftp.osuosl.org/pub/open-ce/1.6.1/
diff --git a/doc/installation_instructions/platform/perlmutter.rst b/doc/installation_instructions/platform/perlmutter.rst
new file mode 100644
index 0000000000..71f97a4dc9
--- /dev/null
+++ b/doc/installation_instructions/platform/perlmutter.rst
@@ -0,0 +1,64 @@
+NERSC Perlmutter
+================
+
+One-time Setup
+--------------
+
+To install SmartSim on Perlmutter, follow these steps:
+
+**Step 1:** Create and activate a conda environment for SmartSim:
+
+.. code:: bash
+
+ module load conda cudatoolkit/12.2 cudnn/8.9.3_cuda12 PrgEnv-gnu
+ conda create -n smartsim python=3.11
+ conda activate smartsim
+
+**Step 2:** Build the SmartRedis C++ and Fortran libraries:
+
+.. code:: bash
+
+ git clone https://github.com/CrayLabs/SmartRedis.git
+ cd SmartRedis
+ make lib-with-fortran
+ pip install .
+ cd ..
+
+**Step 3:** Install SmartSim in the conda environment:
+
+.. code:: bash
+
+ pip install git+https://github.com/CrayLabs/SmartSim.git
+
+**Step 4:** Build Redis, RedisAI, the backends, and all the Python packages:
+
+.. code:: bash
+
+ smart build --device=cuda-12
+
+**Step 5:** Check that SmartSim has been installed and built correctly:
+
+.. code:: bash
+
+ smart validate --device gpu
+
+The following output indicates a successful install:
+
+.. code:: bash
+
+ [SmartSim] INFO Verifying Tensor Transfer
+ [SmartSim] INFO Verifying Torch Backend
+ [SmartSim] INFO Verifying ONNX Backend
+ [SmartSim] INFO Verifying TensorFlow Backend
+ 16:26:35 login SmartSim[557020:MainThread] INFO Success!
+
+Post-installation
+-----------------
+
+After completing the above steps to install SmartSim in a conda environment, you
+can reload the conda environment by running the following commands:
+
+.. code:: bash
+
+ module load conda cudatoolkit/12.2 cudnn/8.9.3_cuda12 PrgEnv-gnu
+ conda activate smartsim
diff --git a/doc/installation_instructions/platform/pml-scylla.rst b/doc/installation_instructions/platform/pml-scylla.rst
new file mode 100644
index 0000000000..c13f178213
--- /dev/null
+++ b/doc/installation_instructions/platform/pml-scylla.rst
@@ -0,0 +1,84 @@
+PML Scylla
+==========
+
+.. warning::
+ As of September 2024, the software stack on Scylla is still being finalized.
+ Therefore, please consider these instructions as preliminary for now.
+
+One-time Setup
+--------------
+
+To install SmartSim on Scylla, follow these steps:
+
+**Step 1:** Create and activate a Python virtual environment for SmartSim:
+
+.. code:: bash
+
+    module use /scyllapfs/hpe/ashao/smartsim_dependencies/modulefiles
+ module load cudatoolkit cudnn git
+ python -m venv /scyllafps/scratch/$USER/venvs/smartsim
+ source /scyllafps/scratch/$USER/venvs/smartsim/bin/activate
+
+**Step 2:** Build the SmartRedis C++ and Fortran libraries:
+
+.. code:: bash
+
+ git clone https://github.com/CrayLabs/SmartRedis.git
+ cd SmartRedis
+ make lib-with-fortran
+ pip install .
+ cd ..
+
+**Step 3:** Install SmartSim in the virtual environment:
+
+.. code:: bash
+
+ pip install git+https://github.com/CrayLabs/SmartSim.git
+
+**Step 4:** Build Redis, RedisAI, the backends, and all the Python packages:
+
+.. code:: bash
+
+ export TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0" # Workaround for a PyTorch problem
+ smart build --device=cuda-12
+ module unload cudnn # Workaround for a PyTorch problem
+
+
+.. note::
+    The first workaround is needed because the autodetection of CUDA
+    architectures is not internally consistent with one of PyTorch's
+    dependencies. This seems to be unique to this machine, as we do not see
+    this on other platforms.
+
+ The second workaround is needed because PyTorch 2.3 (and possibly 2.2)
+ will attempt to load the version of cuDNN that is in the LD_LIBRARY_PATH
+ instead of the version shipped with PyTorch itself. This results in
+    unresolved symbols.
+
+**Step 5:** Check that SmartSim has been installed and built correctly:
+
+.. code:: bash
+
+ srun -n 1 -p gpu --gpus=1 --pty smart validate --device gpu
+
+The following output indicates a successful install:
+
+.. code:: bash
+
+ [SmartSim] INFO Verifying Tensor Transfer
+ [SmartSim] INFO Verifying Torch Backend
+ [SmartSim] INFO Verifying ONNX Backend
+ [SmartSim] INFO Verifying TensorFlow Backend
+ 16:26:35 login SmartSim[557020:MainThread] INFO Success!
+
+Post-installation
+-----------------
+
+After completing the above steps to install SmartSim in a Python virtual environment,
+you can reload the environment by running the following commands:
+
+.. code:: bash
+
+ module load cudatoolkit/12.4.1 git # cudnn should NOT be loaded
+ source /scyllafps/scratch/$USER/venvs/smartsim/bin/activate
+
diff --git a/doc/installation_instructions/site-install.rst b/doc/installation_instructions/site-install.rst
index 26ecd6c138..53e0ff8bf0 100644
--- a/doc/installation_instructions/site-install.rst
+++ b/doc/installation_instructions/site-install.rst
@@ -11,5 +11,5 @@ from source with the following steps replacing ``COMPILER_VERSION`` and
module use -a /lus/scratch/smartsim/local/modulefiles
module load cudatoolkit/11.8 cudnn smartsim-deps/COMPILER_VERSION/SMARTSIM_VERSION
- pip install smartsim[ml]
- smart build --only_python_packages --device gpu [--onnx]
+ pip install smartsim
+ smart build --skip-backends --device gpu [--onnx]
diff --git a/doc/tutorials/ml_inference/Inference-in-SmartSim.ipynb b/doc/tutorials/ml_inference/Inference-in-SmartSim.ipynb
index 2d19cab138..4afdc38955 100644
--- a/doc/tutorials/ml_inference/Inference-in-SmartSim.ipynb
+++ b/doc/tutorials/ml_inference/Inference-in-SmartSim.ipynb
@@ -44,8 +44,9 @@
],
"source": [
"## Installing the ML backends\n",
- "from smartsim._core.utils.helpers import installed_redisai_backends\n",
- "print(installed_redisai_backends())\n"
+ "# from smartsim._core.utils.helpers import installed_redisai_backends\n",
+ "#print(installed_redisai_backends())\n",
+ "# TODO: replace deprecated installed_redisai_backends"
]
},
{
@@ -132,7 +133,7 @@
"\n",
"ML Backends Requested\n",
"╒════════════╤════════╤══════╕\n",
- "│ PyTorch │ 2.0.1 │ \u001b[32mTrue\u001b[0m │\n",
+ "│ PyTorch │ 2.1.0 │ \u001b[32mTrue\u001b[0m │\n",
"│ TensorFlow │ 2.13.1 │ \u001b[32mTrue\u001b[0m │\n",
"│ ONNX │ 1.14.1 │ \u001b[32mTrue\u001b[0m │\n",
"╘════════════╧════════╧══════╛\n",
diff --git a/docker-compose.yml b/docker-compose.yml
index 0473616560..e652591620 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -14,9 +14,9 @@ services:
- "8888:8888"
tutorials-prod:
- image: smartsim-tutorials:v0.7.0
+ image: smartsim-tutorials:v0.8.0
build:
context: .
dockerfile: ./docker/prod/Dockerfile
ports:
- - "8888:8888"
\ No newline at end of file
+ - "8888:8888"
diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile
index bc92e2fd79..faeeae8f37 100644
--- a/docker/dev/Dockerfile
+++ b/docker/dev/Dockerfile
@@ -50,7 +50,7 @@ COPY . /home/craylabs/SmartSim
RUN chown craylabs:root -R SmartSim
USER craylabs
-RUN cd SmartSim && SMARTSIM_SUFFIX=dev python -m pip install .[ml]
+RUN cd SmartSim && SMARTSIM_SUFFIX=dev python -m pip install .
RUN export PATH=/home/craylabs/.local/bin:$PATH && \
echo "export PATH=/home/craylabs/.local/bin:$PATH" >> /home/craylabs/.bashrc && \
diff --git a/docker/prod-cuda11/Dockerfile b/docker/prod-cuda11/Dockerfile
new file mode 100644
index 0000000000..fc27479051
--- /dev/null
+++ b/docker/prod-cuda11/Dockerfile
@@ -0,0 +1,61 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+FROM ubuntu:22.04
+
+LABEL maintainer="Cray Labs"
+LABEL org.opencontainers.image.source https://github.com/CrayLabs/SmartSim
+
+ARG DEBIAN_FRONTEND="noninteractive"
+ENV TZ=US/Seattle
+
+# Make basic dependencies
+RUN apt-get update \
+ && apt-get install --no-install-recommends -y build-essential \
+ git gcc make git-lfs wget libopenmpi-dev openmpi-bin unzip \
+ python3-pip python3 python3-dev cmake wget apt-utils
+
+# Install CUDA Toolkit 11.8
+ENV TERM="xterm"
+RUN wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run && \
+ chmod +x ./cuda_11.8.0_520.61.05_linux.run && \
+ ./cuda_11.8.0_520.61.05_linux.run --silent --toolkit && \
+ rm ./cuda_11.8.0_520.61.05_linux.run
+
+# Install cuDNN 8.9.7
+RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/libcudnn8_8.9.7.29-1+cuda11.8_amd64.deb && \
+ dpkg -i libcudnn8_8.9.7.29-1+cuda11.8_amd64.deb && \
+ rm ./libcudnn8_8.9.7.29-1+cuda11.8_amd64.deb
+
+# Install SmartSim and SmartRedis
+RUN pip install git+https://github.com/CrayLabs/SmartRedis.git && \
+    pip install "smartsim @ git+https://github.com/CrayLabs/SmartSim.git"
+
+ENV CUDA_HOME="/usr/local/cuda/"
+ENV PATH="${PATH}:${CUDA_HOME}/bin"
+
+# Build ML Backends
+RUN smart build --device=gpu --onnx
diff --git a/docker/prod-cuda12/Dockerfile b/docker/prod-cuda12/Dockerfile
new file mode 100644
index 0000000000..bbdfd35131
--- /dev/null
+++ b/docker/prod-cuda12/Dockerfile
@@ -0,0 +1,64 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+FROM ubuntu:22.04
+
+LABEL maintainer="Cray Labs"
+LABEL org.opencontainers.image.source https://github.com/CrayLabs/SmartSim
+
+ARG DEBIAN_FRONTEND="noninteractive"
+ENV TZ=US/Seattle
+
+# Make basic dependencies
+RUN apt-get update \
+ && apt-get install --no-install-recommends -y build-essential \
+ git gcc make git-lfs wget libopenmpi-dev openmpi-bin unzip \
+ python3-pip python3 python3-dev cmake wget
+
+# Install CUDA Toolkit 12.5
+RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
+ dpkg -i cuda-keyring_1.1-1_all.deb && \
+ apt-get update -y && \
+ apt-get install -y cuda-toolkit-12-5
+
+# Install cuDNN 8.9.7
+RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/libcudnn8_8.9.7.29-1+cuda12.2_amd64.deb && \
+ dpkg -i libcudnn8_8.9.7.29-1+cuda12.2_amd64.deb
+
+# Install SmartSim and SmartRedis
+RUN pip install git+https://github.com/CrayLabs/SmartRedis.git && \
+ pip install git+https://github.com/CrayLabs/SmartSim.git@cuda-12-support
+
+ENV CUDA_HOME="/usr/local/cuda/"
+ENV PATH="${PATH}:${CUDA_HOME}/bin"
+
+# Install machine-learning python packages consistent with RedisAI
+# Note: pytorch gets installed in the smart build step
+# This step will be deprecated in a future update
+RUN pip install tensorflow==2.15.0
+
+# Build ML Backends
+RUN smart build --device=cuda121
diff --git a/docker/prod/Dockerfile b/docker/prod/Dockerfile
index 0f5b8dafc5..f8560f7bda 100644
--- a/docker/prod/Dockerfile
+++ b/docker/prod/Dockerfile
@@ -46,7 +46,7 @@ COPY --chown=craylabs:root ./doc/tutorials/ /home/craylabs/tutorials/
USER craylabs
RUN export PATH=/home/craylabs/.local/bin:$PATH && \
echo "export PATH=/home/craylabs/.local/bin:$PATH" >> /home/craylabs/.bashrc && \
- python -m pip install smartsim[ml]==0.7.0 jupyter jupyterlab "ipython<8" matplotlib && \
+ python -m pip install smartsim==0.8.0 jupyter jupyterlab "ipython<8" matplotlib && \
smart build --device cpu -v && \
chown craylabs:root -R /home/craylabs/.local && \
rm -rf ~/.cache/pip
diff --git a/ex/high_throughput_inference/mli_driver.py b/ex/high_throughput_inference/mli_driver.py
new file mode 100644
index 0000000000..36f427937c
--- /dev/null
+++ b/ex/high_throughput_inference/mli_driver.py
@@ -0,0 +1,77 @@
+import os
+import base64
+import cloudpickle
+import sys
+from smartsim import Experiment
+from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
+from smartsim.status import TERMINAL_STATUSES
+from smartsim.settings import DragonRunSettings
+import time
+import typing as t
+
+DEVICE = "gpu"
+NUM_RANKS = 4
+NUM_WORKERS = 1
+filedir = os.path.dirname(__file__)
+worker_manager_script_name = os.path.join(filedir, "standalone_worker_manager.py")
+app_script_name = os.path.join(filedir, "mock_app.py")
+model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt")
+
+transport: t.Literal["hsta", "tcp"] = "hsta"
+
+os.environ["SMARTSIM_DRAGON_TRANSPORT"] = transport
+
+exp_path = os.path.join(filedir, f"MLI_proto_{transport.upper()}")
+os.makedirs(exp_path, exist_ok=True)
+exp = Experiment("MLI_proto", launcher="dragon", exp_path=exp_path)
+
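+# Serialize the TorchWorker class with cloudpickle so the standalone worker manager
+# can reconstruct it from the --worker_class command-line argument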
+torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii")
+
+worker_manager_rs: DragonRunSettings = exp.create_run_settings(
+ sys.executable,
+ [
+ worker_manager_script_name,
+ "--device",
+ DEVICE,
+ "--worker_class",
+ torch_worker_str,
+ "--batch_size",
+ str(NUM_RANKS//NUM_WORKERS),
+ "--batch_timeout",
+ str(0.00),
+ "--num_workers",
+ str(NUM_WORKERS)
+ ],
+)
+
+aff = []
+
+worker_manager_rs.set_cpu_affinity(aff)
+
+worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs)
+worker_manager.attach_generator_files(to_copy=[worker_manager_script_name])
+
+app_rs: DragonRunSettings = exp.create_run_settings(
+ sys.executable,
+ exe_args=[app_script_name, "--device", DEVICE, "--log_max_batchsize", str(6)],
+)
+app_rs.set_tasks_per_node(NUM_RANKS)
+
+
+app = exp.create_model("app", run_settings=app_rs)
+app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])
+
+exp.generate(worker_manager, app, overwrite=True)
+exp.start(worker_manager, app, block=False)
+
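+# Poll job statuses: once either the app or the worker manager reaches a terminal
+# state, stop the other and exit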
+while True:
+ if exp.get_status(app)[0] in TERMINAL_STATUSES:
+ time.sleep(10)
+ exp.stop(worker_manager)
+ break
+ if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES:
+ time.sleep(10)
+ exp.stop(app)
+ break
+
+print("Exiting.")
diff --git a/ex/high_throughput_inference/mock_app.py b/ex/high_throughput_inference/mock_app.py
new file mode 100644
index 0000000000..c3b3eaaf4c
--- /dev/null
+++ b/ex/high_throughput_inference/mock_app.py
@@ -0,0 +1,142 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# isort: off
+import dragon
+from dragon import fli
+from dragon.channels import Channel
+import dragon.channels
+from dragon.data.ddict.ddict import DDict
+from dragon.globalservices.api_setup import connect_to_infrastructure
+from dragon.utils import b64decode, b64encode
+
+# isort: on
+
+import argparse
+import io
+
+import torch
+
+from smartsim.log import get_logger
+
+torch.set_num_interop_threads(16)
+torch.set_num_threads(1)
+
+logger = get_logger("App")
+logger.info("Started app")
+
+from collections import OrderedDict
+
+from smartsim.log import get_logger, log_to_file
+from smartsim._core.mli.client.protoclient import ProtoClient
+
+logger = get_logger("App")
+
+
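+# When True, each remote inference result is re-checked against a local PyTorch run,
+# which slows the benchmark considerably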
+CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
+
+
+class ResNetWrapper:
+    """Wrapper around a pre-trained ResNet model."""
+ def __init__(self, name: str, model: str):
+ """Initialize the instance.
+
+ :param name: The name to use for the model
+ :param model: The path to the pre-trained PyTorch model"""
+ self._model = torch.jit.load(model)
+ self._name = name
+ buffer = io.BytesIO()
+ scripted = torch.jit.trace(self._model, self.get_batch())
+ torch.jit.save(scripted, buffer)
+ self._serialized_model = buffer.getvalue()
+
+ def get_batch(self, batch_size: int = 32):
+ """Create a random batch of data with the correct dimensions to
+ invoke a ResNet model.
+
+ :param batch_size: The desired number of samples to produce
+ :returns: A PyTorch tensor"""
+ return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)
+
+ @property
+ def model(self) -> bytes:
+ """The content of a model file.
+
+ :returns: The model bytes"""
+ return self._serialized_model
+
+ @property
+ def name(self) -> str:
+ """The name applied to the model.
+
+ :returns: The name"""
+ return self._name
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser("Mock application")
+ parser.add_argument("--device", default="cpu", type=str)
+ parser.add_argument("--log_max_batchsize", default=8, type=int)
+ args = parser.parse_args()
+
+ resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt")
+
+ client = ProtoClient(timing_on=True)
+ client.set_model(resnet.name, resnet.model)
+
+ if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
+ # TODO: adapt to non-Nvidia devices
+ torch_device = args.device.replace("gpu", "cuda")
+ pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(
+ torch_device
+ )
+
+ TOTAL_ITERATIONS = 100
+
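+    # Sweep batch sizes in powers of two up to 2**log_max_batchsize, timing each remote inference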
+ for log2_bsize in range(args.log_max_batchsize + 1):
+ b_size: int = 2**log2_bsize
+ logger.info(f"Batch size: {b_size}")
+ for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)):
+ logger.info(f"Iteration: {iteration_number}")
+ sample_batch = resnet.get_batch(b_size)
+ remote_result = client.run_model(resnet.name, sample_batch)
+ logger.info(client.perf_timer.get_last("total_time"))
+ if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
+ local_res = pt_model(sample_batch.to(torch_device))
+ err_norm = torch.linalg.vector_norm(
+ torch.flatten(remote_result).to(torch_device)
+ - torch.flatten(local_res),
+ ord=1,
+ ).cpu()
+ res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
+ local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
+ logger.info(
+ f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}"
+ )
+ torch.cuda.synchronize()
+
+ client.perf_timer.print_timings(to_file=True)
diff --git a/ex/high_throughput_inference/mock_app_redis.py b/ex/high_throughput_inference/mock_app_redis.py
new file mode 100644
index 0000000000..8978bcea23
--- /dev/null
+++ b/ex/high_throughput_inference/mock_app_redis.py
@@ -0,0 +1,90 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import argparse
+import io
+import numpy
+import time
+import torch
+from mpi4py import MPI
+from smartsim.log import get_logger
+from smartsim._core.utils.timings import PerfTimer
+from smartredis import Client
+
+logger = get_logger("App")
+
+class ResNetWrapper():
+ def __init__(self, name: str, model: str):
+ self._model = torch.jit.load(model)
+ self._name = name
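+        # trace the loaded model with a sample batch and serialize the result
+        # to bytes so it can be uploaded to the database via Client.set_model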
+ buffer = io.BytesIO()
+ scripted = torch.jit.trace(self._model, self.get_batch())
+ torch.jit.save(scripted, buffer)
+ self._serialized_model = buffer.getvalue()
+
+    def get_batch(self, batch_size: int = 32):
+ return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)
+
+ @property
+ def model(self):
+ return self._serialized_model
+
+ @property
+ def name(self):
+ return self._name
+
+if __name__ == "__main__":
+
+ comm = MPI.COMM_WORLD
+ rank = comm.Get_rank()
+
+ parser = argparse.ArgumentParser("Mock application")
+ parser.add_argument("--device", default="cpu")
+ args = parser.parse_args()
+
+    resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt")
+
+ client = Client(cluster=False, address=None)
+ client.set_model(resnet.name, resnet.model, backend='TORCH', device=args.device.upper())
+
+    perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=True, prefix=f"redis{rank}_")
+
+ total_iterations = 100
+    timings = []
+    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
+        logger.info(f"Batch size: {batch_size}")
+        for iteration_number in range(total_iterations + int(batch_size == 1)):
+ perf_timer.start_timings("batch_size", batch_size)
+ logger.info(f"Iteration: {iteration_number}")
+ input_name = f"batch_{rank}"
+ output_name = f"result_{rank}"
+ client.put_tensor(name=input_name, data=resnet.get_batch(batch_size).numpy())
+ client.run_model(name=resnet.name, inputs=[input_name], outputs=[output_name])
+ result = client.get_tensor(name=output_name)
+ perf_timer.end_timings()
+
+
+ perf_timer.print_timings(True)
diff --git a/ex/high_throughput_inference/redis_driver.py b/ex/high_throughput_inference/redis_driver.py
new file mode 100644
index 0000000000..ff57725d40
--- /dev/null
+++ b/ex/high_throughput_inference/redis_driver.py
@@ -0,0 +1,66 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import sys
+from smartsim import Experiment
+from smartsim.status import TERMINAL_STATUSES
+import time
+
+DEVICE = "gpu"
+filedir = os.path.dirname(__file__)
+app_script_name = os.path.join(filedir, "mock_app_redis.py")
+model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt")
+
+
+exp_path = os.path.join(filedir, "redis_ai_multi")
+os.makedirs(exp_path, exist_ok=True)
+exp = Experiment("redis_ai_multi", launcher="slurm", exp_path=exp_path)
+
+db = exp.create_database(interface="hsn0")
+
+app_rs = exp.create_run_settings(
+ sys.executable, exe_args = [app_script_name, "--device", DEVICE]
+ )
+app_rs.set_nodes(1)
+app_rs.set_tasks(4)
+app = exp.create_model("app", run_settings=app_rs)
+app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])
+
+exp.generate(db, app, overwrite=True)
+
+exp.start(db, app, block=False)
+
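+# poll until either the application or the database reaches a terminal state,
+# then stop the other before exiting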
+while True:
+ if exp.get_status(app)[0] in TERMINAL_STATUSES:
+ exp.stop(db)
+ break
+ if exp.get_status(db)[0] in TERMINAL_STATUSES:
+ exp.stop(app)
+ break
+ time.sleep(5)
+
+print("Exiting.")
\ No newline at end of file
diff --git a/ex/high_throughput_inference/standalone_worker_manager.py b/ex/high_throughput_inference/standalone_worker_manager.py
new file mode 100644
index 0000000000..b4527bc5d2
--- /dev/null
+++ b/ex/high_throughput_inference/standalone_worker_manager.py
@@ -0,0 +1,218 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+import dragon
+
+# pylint: disable=import-error
+import dragon.infrastructure.policy as dragon_policy
+import dragon.infrastructure.process_desc as dragon_process_desc
+import dragon.native.process as dragon_process
+from dragon import fli
+from dragon.channels import Channel
+from dragon.data.ddict.ddict import DDict
+from dragon.globalservices.api_setup import connect_to_infrastructure
+from dragon.managed_memory import MemoryPool
+from dragon.utils import b64decode, b64encode
+
+# pylint: enable=import-error
+
+# isort: off
+# isort: on
+
+import argparse
+import base64
+import multiprocessing as mp
+import os
+import socket
+import time
+import typing as t
+
+import cloudpickle
+
+from smartsim._core.entrypoints.service import Service
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.control.request_dispatcher import (
+ RequestDispatcher,
+)
+from smartsim._core.mli.infrastructure.control.worker_manager import WorkerManager
+from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
+ DragonFeatureStore,
+)
+from smartsim.log import get_logger
+
+logger = get_logger("Worker Manager Entry Point")
+
+mp.set_start_method("dragon")
+
+pid = os.getpid()
+affinity = os.sched_getaffinity(pid)
+logger.info(f"Entry point: {socket.gethostname()}, {affinity}")
+logger.info(f"CPUS: {os.cpu_count()}")
+
+
+def service_as_dragon_proc(
+ service: Service, cpu_affinity: list[int], gpu_affinity: list[int]
+) -> dragon_process.Process:
+
+ options = dragon_process_desc.ProcessOptions(make_inf_channels=True)
+ local_policy = dragon_policy.Policy(
+ placement=dragon_policy.Policy.Placement.HOST_NAME,
+ host_name=socket.gethostname(),
+ cpu_affinity=cpu_affinity,
+ gpu_affinity=gpu_affinity,
+ )
+ return dragon_process.Process(
+ target=service.execute,
+ args=[],
+ cwd=os.getcwd(),
+ policy=local_policy,
+ options=options,
+ stderr=dragon_process.Popen.STDOUT,
+ stdout=dragon_process.Popen.STDOUT,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("Worker Manager")
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="gpu",
+ choices="gpu cpu".split(),
+ help="Device on which the inference takes place",
+ )
+ parser.add_argument(
+ "--worker_class",
+ type=str,
+ required=True,
+ help="Serialized class of worker to run",
+ )
+ parser.add_argument(
+ "--num_workers", type=int, default=1, help="Number of workers to run"
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="How many requests the workers will try to aggregate before processing them",
+ )
+ parser.add_argument(
+ "--batch_timeout",
+ type=float,
+ default=0.001,
+ help="How much time (in seconds) should be waited before processing an incomplete aggregated request",
+ )
+ args = parser.parse_args()
+
+ connect_to_infrastructure()
+ ddict_str = os.environ[BackboneFeatureStore.MLI_BACKBONE]
+
+ backbone = BackboneFeatureStore.from_descriptor(ddict_str)
+
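+    # create the FLI channel used to receive inference requests and publish
+    # its descriptor to the backbone so clients can discover it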
+ to_worker_channel = create_local()
+ to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None)
+ to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli)
+
+ backbone.worker_queue = to_worker_fli_comm_ch.descriptor
+
+ os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor
+ os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor
+
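+    # the worker class arrives on the command line as a base64-encoded
+    # cloudpickle payload; decode it back into the class to instantiate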
+ arg_worker_type = cloudpickle.loads(
+ base64.b64decode(args.worker_class.encode("ascii"))
+ )
+
+ config_loader = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=DragonCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+
+ dispatcher = RequestDispatcher(
+ batch_timeout=args.batch_timeout,
+ batch_size=args.batch_size,
+ config_loader=config_loader,
+ worker_type=arg_worker_type,
+ )
+
+ wms = []
+ worker_device = args.device
+ for wm_idx in range(args.num_workers):
+
+ worker_manager = WorkerManager(
+ config_loader=config_loader,
+ worker_type=arg_worker_type,
+ as_service=True,
+ cooldown=10,
+ device=worker_device,
+ dispatcher_queue=dispatcher.task_queue,
+ )
+
+ wms.append(worker_manager)
+
+ wm_affinity: list[int] = []
+ disp_affinity: list[int] = []
+
+ # This is hardcoded for a specific type of node:
+ # the GPU-to-CPU mapping is taken from the nvidia-smi tool
+ # TODO can this be computed on the fly?
+ gpu_to_cpu_aff: dict[int, list[int]] = {}
+ gpu_to_cpu_aff[0] = list(range(48, 64)) + list(range(112, 128))
+ gpu_to_cpu_aff[1] = list(range(32, 48)) + list(range(96, 112))
+ gpu_to_cpu_aff[2] = list(range(16, 32)) + list(range(80, 96))
+ gpu_to_cpu_aff[3] = list(range(0, 16)) + list(range(64, 80))
+
+ worker_manager_procs = []
+ for worker_idx in range(args.num_workers):
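+        # assign each worker manager all but four of the CPUs mapped to its GPU;
+        # the remaining four per GPU are pooled for the request dispatcher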
+ wm_cpus = len(gpu_to_cpu_aff[worker_idx]) - 4
+ wm_affinity = gpu_to_cpu_aff[worker_idx][:wm_cpus]
+ disp_affinity.extend(gpu_to_cpu_aff[worker_idx][wm_cpus:])
+ worker_manager_procs.append(
+ service_as_dragon_proc(
+ worker_manager, cpu_affinity=wm_affinity, gpu_affinity=[worker_idx]
+ )
+ )
+
+ dispatcher_proc = service_as_dragon_proc(
+ dispatcher, cpu_affinity=disp_affinity, gpu_affinity=[]
+ )
+
+ # TODO: use ProcessGroup and restart=True?
+ all_procs = [dispatcher_proc, *worker_manager_procs]
+
+ print(f"Dispatcher proc: {dispatcher_proc}")
+ for proc in all_procs:
+ proc.start()
+
+ while all(proc.is_alive for proc in all_procs):
+ time.sleep(1)
diff --git a/pyproject.toml b/pyproject.toml
index 5b81676a35..bf721b0c99 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,7 @@ markers = [
"group_a: fast test subset a",
"group_b: fast test subset b",
"slow_tests: tests that take a long duration to complete",
+ "dragon: tests that must be executed in a dragon runtime",
]
[tool.isort]
diff --git a/setup.py b/setup.py
index 328bf1ffb6..cd5ace55db 100644
--- a/setup.py
+++ b/setup.py
@@ -119,6 +119,7 @@
class BuildError(Exception):
pass
+
# Define needed dependencies for the installation
extras_require = {
@@ -137,7 +138,7 @@ class BuildError(Exception):
"types-redis",
"types-tabulate",
"types-tqdm",
- "types-tensorflow==2.12.0.9",
+ "types-tensorflow",
"types-setuptools",
"typing_extensions>=4.1.0",
],
@@ -151,7 +152,7 @@ class BuildError(Exception):
"nbsphinx==0.9.3",
"docutils==0.18.1",
"torch==2.0.1",
- "tensorflow==2.13.1",
+ "tensorflow>=2.14,<3.0",
"ipython",
"jinja2==3.1.2",
"sphinx-design",
@@ -159,8 +160,6 @@ class BuildError(Exception):
"sphinx-autodoc-typehints",
"myst_parser",
],
- # see smartsim/_core/_install/buildenv.py for more details
- **versions.ml_extras_required(),
}
@@ -175,14 +174,16 @@ class BuildError(Exception):
"redis>=4.5",
"tqdm>=4.50.2",
"filelock>=3.4.2",
- "protobuf~=3.20",
+ "GitPython<=3.1.43",
+ "protobuf<=3.20.3",
"jinja2>=3.1.2",
- "watchdog>=4.0.0",
- "pydantic==1.10.14",
+ "pycapnp==2.0.0",
+ "watchdog>4,<5",
+ "pydantic>2",
"pyzmq>=25.1.2",
"pygithub>=2.3.0",
"numpy<2",
- "smartredis>=0.5,<0.6",
+ "smartredis>=0.6,<0.7",
],
zip_safe=False,
extras_require=extras_require,
diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py
index 8e6f94722c..a0dc489f6a 100644
--- a/smartsim/_core/_cli/build.py
+++ b/smartsim/_core/_cli/build.py
@@ -25,23 +25,38 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
+import importlib.metadata
+import operator
import os
-import platform
+from pathlib import Path
+import re
+import textwrap
import typing as t
from tabulate import tabulate
-from smartsim._core._cli.scripts.dragon_install import install_dragon
-from smartsim._core._cli.utils import SMART_LOGGER_FORMAT, pip
-from smartsim._core._install import builder
-from smartsim._core._install.buildenv import (
- BuildEnv,
- SetupError,
- Version_,
- VersionConflictError,
- Versioner,
+from smartsim._core._cli.scripts.dragon_install import (
+ DEFAULT_DRAGON_REPO,
+ DEFAULT_DRAGON_VERSION,
+ DragonInstallRequest,
+ display_post_install_logs,
+ install_dragon,
)
-from smartsim._core._install.builder import BuildError, Device
+from smartsim._core._cli.utils import SMART_LOGGER_FORMAT
+from smartsim._core._install.buildenv import BuildEnv, Version_, Versioner
+from smartsim._core._install.mlpackages import (
+ DEFAULT_MLPACKAGE_PATH,
+ DEFAULT_MLPACKAGES,
+ MLPackageCollection,
+ load_platform_configs,
+)
+from smartsim._core._install.platform import (
+ Architecture,
+ Device,
+ OperatingSystem,
+ Platform,
+)
+
from smartsim._core.config import CONFIG
from smartsim.log import get_logger
@@ -55,154 +70,66 @@
# NOTE: all smartsim modules need full paths as the smart cli
# may be installed into a different directory.
-_TPinningStr = t.Literal["==", "!=", ">=", ">", "<=", "<", "~="]
-
-
-def check_py_onnx_version(versions: Versioner) -> None:
- """Check Python environment for ONNX installation"""
- _check_packages_in_python_env(
- {
- "onnx": Version_(versions.ONNX),
- "skl2onnx": "1.16.0",
- "onnxmltools": "1.12.0",
- "scikit-learn": "1.3.2",
- },
- )
-
-
-def check_py_tf_version(versions: Versioner) -> None:
- """Check Python environment for TensorFlow installation"""
- _check_packages_in_python_env({"tensorflow": Version_(versions.TENSORFLOW)})
-
-def build_feature_store(build_env: BuildEnv, verbose: bool) -> None:
- # check feature store installation
- feature_store_builder = builder.FeatureStoreBuilder(
- build_env(),
- jobs=build_env.JOBS,
- _os=builder.OperatingSystem.from_str(platform.system()),
- architecture=builder.Architecture.from_str(platform.machine()),
- malloc=build_env.MALLOC,
- verbose=verbose,
- )
-
- if not feature_store_builder.is_built:
- logger.info("No feature store is currently being built by 'smart build'")
-
- feature_store_builder.cleanup()
- logger.info("No feature store is currently being built by 'smart build'")
-
-
-def check_py_torch_version(versions: Versioner, device: Device = Device.CPU) -> None:
- """Check Python environment for TensorFlow installation"""
- if BuildEnv.is_macos():
- if device == Device.GPU:
- raise BuildError("SmartSim does not support GPU on MacOS")
- device_suffix = ""
- else: # linux
- if device == Device.CPU:
- device_suffix = versions.TORCH_CPU_SUFFIX
- elif device == Device.GPU:
- device_suffix = versions.TORCH_CUDA_SUFFIX
- else:
- raise BuildError("Unrecognized device requested")
-
- torch_deps = {
- "torch": Version_(f"{versions.TORCH}{device_suffix}"),
- "torchvision": Version_(f"{versions.TORCHVISION}{device_suffix}"),
+def parse_requirement(
+ requirement: str,
+) -> t.Tuple[str, t.Optional[str], t.Callable[[Version_], bool]]:
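+    """Parse a pip-style requirement string (e.g. ``torch==2.0.1+cpu``) into the
+    package name, an optional pinning string (e.g. ``==2.0.1+cpu``), and a
+    predicate that reports whether an installed version satisfies the pin.
+
+    :param requirement: the requirement string to parse
+    :returns: a tuple of (package name, pinning string or None, compatibility check)
+    :raises ValueError: if the requirement string cannot be parsed
+    """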
+ operators = {
+ "==": operator.eq,
+ "<=": operator.le,
+ ">=": operator.ge,
+ "<": operator.lt,
+ ">": operator.gt,
}
- missing, conflicts = _assess_python_env(
- torch_deps,
- package_pinning="==",
- validate_installed_version=_create_torch_version_validator(
- with_suffix=device_suffix
- ),
+ semantic_version_pattern = r"\d+(?:\.\d+(?:\.\d+)?)?([^\s]*)"
+ pattern = (
+ r"^" # Start
+ r"([a-zA-Z0-9_\-]+)" # Package name
+ r"(?:\[[a-zA-Z0-9_\-,]+\])?" # Any extras
+ r"(?:([<>=!~]{1,2})" # Pinning string
+ rf"({semantic_version_pattern}))?" # A version number
+ r"$" # End
)
-
- if len(missing) == len(torch_deps) and not conflicts:
- # All PyTorch deps are not installed and there are no conflicting
- # python packages. We can try to install torch deps into the current env.
- logger.info(
- "Torch version not found in python environment. "
- "Attempting to install via `pip`"
- )
- wheel_device = (
- device.value if device == Device.CPU else device_suffix.replace("+", "")
- )
- pip(
- "install",
- "--extra-index-url",
- f"https://download.pytorch.org/whl/{wheel_device}",
- *(f"{package}=={version}" for package, version in torch_deps.items()),
- )
- elif missing or conflicts:
- logger.warning(_format_incompatible_python_env_message(missing, conflicts))
-
-
-def _create_torch_version_validator(
- with_suffix: str,
-) -> t.Callable[[str, t.Optional[Version_]], bool]:
- def check_torch_version(package: str, version: t.Optional[Version_]) -> bool:
- if not BuildEnv.check_installed(package, version):
- return False
- # Default check only looks at major/minor version numbers,
- # Torch requires we look at the patch as well
- installed = BuildEnv.get_py_package_version(package)
- if with_suffix and with_suffix not in installed.patch:
- raise VersionConflictError(
- package,
- installed,
- version or Version_(f"X.X.X{with_suffix}"),
- msg=(
- f"{package}=={installed} does not satisfy device "
- f"suffix requirement: {with_suffix}"
- ),
+ match = re.match(pattern, requirement)
+ if match is None:
+ raise ValueError(f"Invalid requirement string: {requirement}")
+ module_name, cmp_op, version_str, suffix = match.groups()
+ version = Version_(version_str) if version_str is not None else None
+ if cmp_op is None:
+ is_compatible = lambda _: True # pylint: disable=unnecessary-lambda-assignment
+ elif (cmp := operators.get(cmp_op, None)) is None:
+ raise ValueError(f"Unrecognized comparison operator: {cmp_op}")
+ else:
+
+ def is_compatible(other: Version_) -> bool:
+ assert version is not None # For type check, always should be true
+ match_ = re.match(rf"^{semantic_version_pattern}$", other)
+ return (
+ cmp(other, version) and match_ is not None and match_.group(1) == suffix
)
- return True
-
- return check_torch_version
-
-
-def _check_packages_in_python_env(
- packages: t.Mapping[str, t.Optional[Version_]],
- package_pinning: _TPinningStr = "==",
- validate_installed_version: t.Optional[
- t.Callable[[str, t.Optional[Version_]], bool]
- ] = None,
-) -> None:
- # TODO: Do not like how the default validation function will always look for
- # a `==` pinning. Maybe turn `BuildEnv.check_installed` into a factory
- # that takes a pinning and returns an appropriate validation fn?
- validate_installed_version = validate_installed_version or BuildEnv.check_installed
- missing, conflicts = _assess_python_env(
- packages,
- package_pinning,
- validate_installed_version,
- )
- if missing or conflicts:
- logger.warning(_format_incompatible_python_env_message(missing, conflicts))
+ return module_name, f"{cmp_op}{version}" if version else None, is_compatible
-def _assess_python_env(
- packages: t.Mapping[str, t.Optional[Version_]],
- package_pinning: _TPinningStr,
- validate_installed_version: t.Callable[[str, t.Optional[Version_]], bool],
-) -> t.Tuple[t.List[str], t.List[str]]:
- missing: t.List[str] = []
- conflicts: t.List[str] = []
+def check_ml_python_packages(packages: MLPackageCollection) -> None:
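+    """Compare the python packages required by each ML backend against the
+    active python environment and warn about any that are missing or
+    installed at an incompatible version.
+
+    :param packages: the collection of ML packages selected for this platform
+    """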
+ missing = []
+ conflicts = []
- for name, version in packages.items():
- spec = f"{name}{package_pinning}{version}" if version else name
- try:
- if not validate_installed_version(name, version):
- # Not installed!
- missing.append(spec)
- except VersionConflictError:
- # Incompatible version found
- conflicts.append(spec)
+ for package in packages.values():
+ for requirement in package.python_packages:
+ module_name, version_spec, is_compatible = parse_requirement(requirement)
+ try:
+ installed = BuildEnv.get_py_package_version(module_name)
+ if not is_compatible(installed):
+ conflicts.append(
+ f"{module_name}: {installed} is installed, "
+ f"but {version_spec or 'Any'} is required"
+ )
+ except importlib.metadata.PackageNotFoundError:
+ missing.append(module_name)
- return missing, conflicts
+ if missing or conflicts:
+ logger.warning(_format_incompatible_python_env_message(missing, conflicts))
def _format_incompatible_python_env_message(
@@ -215,13 +142,19 @@ def _format_incompatible_python_env_message(
missing_str = fmt_list("Missing", missing)
conflict_str = fmt_list("Conflicting", conflicting)
sep = "\n" if missing_str and conflict_str else ""
- return (
- "Python Env Status Warning!\n"
- "Requested Packages are Missing or Conflicting:\n\n"
- f"{missing_str}{sep}{conflict_str}\n\n"
- "Consider installing packages at the requested versions via `pip` or "
- "uninstalling them, installing SmartSim with optional ML dependencies "
- "(`pip install smartsim[ml]`), and running `smart clean && smart build ...`"
+
+ return textwrap.dedent(
+ f"""\
+ Python Package Warning:
+
+ Requested packages are missing or have a version mismatch with
+ their respective backend:
+
+ {missing_str}{sep}{conflict_str}
+
+ Consider uninstalling any conflicting packages and rerunning
+ `smart build` if you encounter issues.
+ """
)
@@ -229,13 +162,30 @@ def _format_incompatible_python_env_message(
def execute(
args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, /
) -> int:
+
+ # Unpack various arguments
verbose = args.v
device = Device(args.device.lower())
is_dragon_requested = args.dragon
- # torch and tf build by default
- pt = not args.no_pt # pylint: disable=invalid-name
- tf = not args.no_tf # pylint: disable=invalid-name
- onnx = args.onnx
+ dragon_repo = args.dragon_repo
+ dragon_version = args.dragon_version
+
+ # The user should never have to specify the OS and Architecture
+ current_platform = Platform(
+ OperatingSystem.autodetect(), Architecture.autodetect(), device
+ )
+
+ # Configure the ML Packages
+ configs = load_platform_configs(Path(args.config_dir))
+ mlpackages = configs[current_platform]
+
+ # Build all backends by default, pop off the ones that user wants skipped
+ if args.skip_torch and "libtorch" in mlpackages:
+ mlpackages.pop("libtorch")
+ if args.skip_tensorflow and "libtensorflow" in mlpackages:
+ mlpackages.pop("libtensorflow")
+ if args.skip_onnx and "onnxruntime" in mlpackages:
+ mlpackages.pop("onnxruntime")
build_env = BuildEnv(checks=True)
logger.info("Running SmartSim build process...")
@@ -257,41 +207,40 @@ def execute(
version_names = list(vers.keys())
print(tabulate(vers, headers=version_names, tablefmt="github"), "\n")
- if is_dragon_requested:
+ logger.info("ML Packages")
+ print(mlpackages)
+
+ if is_dragon_requested or dragon_repo or dragon_version:
install_to = CONFIG.core_path / ".dragon"
- return_code = install_dragon(install_to)
+
+ try:
+ request = DragonInstallRequest(
+ install_to,
+ dragon_repo,
+ dragon_version,
+ )
+ return_code = install_dragon(request)
+ except ValueError as ex:
+ return_code = 2
+ logger.error(" ".join(ex.args))
if return_code == 0:
- logger.info("Dragon installation complete")
+ display_post_install_logs()
+
elif return_code == 1:
logger.info("Dragon installation not supported on platform")
else:
logger.warning("Dragon installation failed")
- try:
- if not args.only_python_packages:
- ...
-
- except (SetupError, BuildError) as e:
- logger.error(str(e))
- return os.EX_SOFTWARE
-
backends = []
backends_str = ", ".join(s.capitalize() for s in backends) if backends else "No"
- logger.info(f"{backends_str} backend(s) built")
-
- try:
- # TODO: always installing torch, otherwise tests will fail.
- # Should revert once torch install has been revamped
- if "torch" in backends or True:
- check_py_torch_version(versions, device)
- if "tensorflow" in backends:
- check_py_tf_version(versions)
- if "onnxruntime" in backends:
- check_py_onnx_version(versions)
- except (SetupError, BuildError) as e:
- logger.error(str(e))
- return os.EX_SOFTWARE
+ logger.info(f"{backends_str} backend(s) available")
+
+ if not args.skip_python_packages:
+ for package in mlpackages.values():
+ logger.info(f"Installing python packages for {package.name}")
+ package.pip_install(quiet=not verbose)
+ check_ml_python_packages(mlpackages)
logger.info("SmartSim build complete!")
return os.EX_OK
@@ -299,7 +248,14 @@ def execute(
def configure_parser(parser: argparse.ArgumentParser) -> None:
"""Builds the parser for the command"""
- warn_usage = "(ONLY USE IF NEEDED)"
+
+ available_devices = []
+ for platform in DEFAULT_MLPACKAGES:
+ if (platform.operating_system == OperatingSystem.autodetect()) and (
+ platform.architecture == Architecture.autodetect()
+ ):
+ available_devices.append(platform.device.value)
+
parser.add_argument(
"-v",
action="store_true",
@@ -310,7 +266,7 @@ def configure_parser(parser: argparse.ArgumentParser) -> None:
"--device",
type=str.lower,
default=Device.CPU.value,
- choices=[device.value for device in Device],
+ choices=available_devices,
help="Device to build ML runtimes for",
)
parser.add_argument(
@@ -320,44 +276,48 @@ def configure_parser(parser: argparse.ArgumentParser) -> None:
help="Install the dragon runtime",
)
parser.add_argument(
- "--only_python_packages",
- action="store_true",
- default=False,
- help="Only evaluate the python packages (i.e. skip building backends)",
+ "--dragon-repo",
+ default=None,
+ type=str,
+ help=(
+ "Specify a git repo containing dragon release assets "
+ f"(e.g. {DEFAULT_DRAGON_REPO})"
+ ),
+ )
+ parser.add_argument(
+ "--dragon-version",
+ default=None,
+ type=str,
+ help=f"Specify the dragon version to install (e.g. {DEFAULT_DRAGON_VERSION})",
)
parser.add_argument(
- "--no_pt",
+ "--skip-python-packages",
action="store_true",
- default=False,
- help="Do not build PyTorch backend",
+ help="Do not install the python packages that match the backends",
)
parser.add_argument(
- "--no_tf",
+ "--skip-backends",
action="store_true",
- default=False,
- help="Do not build TensorFlow backend",
+ help="Do not compile RedisAI and the backends",
)
parser.add_argument(
- "--onnx",
+ "--skip-torch",
action="store_true",
- default=False,
- help="Build ONNX backend (off by default)",
+ help="Do not build PyTorch backend",
)
parser.add_argument(
- "--torch_dir",
- default=None,
- type=str,
- help=f"Path to custom /torch/share/cmake/Torch/ directory {warn_usage}",
+ "--skip-tensorflow",
+ action="store_true",
+ help="Do not build TensorFlow backend",
)
parser.add_argument(
- "--libtensorflow_dir",
- default=None,
- type=str,
- help=f"Path to custom libtensorflow directory {warn_usage}",
+ "--skip-onnx",
+ action="store_true",
+ help="Do not build the ONNX backend",
)
parser.add_argument(
- "--no_torch_with_mkl",
- dest="torch_with_mkl",
- action="store_false",
- help="Do not build Torch with Intel MKL",
+ "--config-dir",
+ default=str(DEFAULT_MLPACKAGE_PATH),
+ type=str,
+ help="Path to directory with JSON files describing platform and packages",
)
diff --git a/smartsim/_core/_cli/info.py b/smartsim/_core/_cli/info.py
index ec50e151aa..7fa094fbdc 100644
--- a/smartsim/_core/_cli/info.py
+++ b/smartsim/_core/_cli/info.py
@@ -6,9 +6,7 @@
from tabulate import tabulate
-import smartsim._core._cli.utils as _utils
import smartsim._core.utils.helpers as _helpers
-from smartsim._core._cli.scripts.dragon_install import dragon_pin
from smartsim._core._install.buildenv import BuildEnv as _BuildEnv
_MISSING_DEP = _helpers.colorize("Not Installed", "red")
@@ -30,7 +28,8 @@ def execute(
)
print("Dragon Installation:")
- dragon_version = dragon_pin()
+ # TODO: Fix hardcoded dragon version
+ dragon_version = "0.10"
fs_table = [["Version", str(dragon_version)]]
print(tabulate(fs_table, tablefmt="fancy_outline"), end="\n\n")
diff --git a/smartsim/_core/_cli/scripts/dragon_install.py b/smartsim/_core/_cli/scripts/dragon_install.py
index a2e8ed36ff..7a7d75f1d2 100644
--- a/smartsim/_core/_cli/scripts/dragon_install.py
+++ b/smartsim/_core/_cli/scripts/dragon_install.py
@@ -1,13 +1,19 @@
import os
import pathlib
+import re
+import shutil
import sys
import typing as t
+from urllib.request import Request, urlopen
from github import Github
+from github.Auth import Token
+from github.GitRelease import GitRelease
from github.GitReleaseAsset import GitReleaseAsset
+from github.Repository import Repository
from smartsim._core._cli.utils import pip
-from smartsim._core._install.builder import WebTGZ
+from smartsim._core._install.utils import retrieve
from smartsim._core.config import CONFIG
from smartsim._core.utils.helpers import check_platform, is_crayex_platform
from smartsim.error.errors import SmartSimCLIActionCancelled
@@ -15,20 +21,90 @@
logger = get_logger(__name__)
+DEFAULT_DRAGON_REPO = "DragonHPC/dragon"
+DEFAULT_DRAGON_VERSION = "0.10"
+DEFAULT_DRAGON_VERSION_TAG = f"v{DEFAULT_DRAGON_VERSION}"
+_GH_TOKEN = "SMARTSIM_DRAGON_TOKEN"
-def create_dotenv(dragon_root_dir: pathlib.Path) -> None:
+
+class DragonInstallRequest:
+ """Encapsulates a request to install the dragon package"""
+
+ def __init__(
+ self,
+ working_dir: pathlib.Path,
+ repo_name: t.Optional[str] = None,
+ version: t.Optional[str] = None,
+ ) -> None:
+ """Initialize an install request.
+
+ :param working_dir: A path to store temporary files used during installation
+ :param repo_name: The name of a repository to install from, e.g. DragonHPC/dragon
+ :param version: The version to install, e.g. v0.10
+ """
+
+ self.working_dir = working_dir
+ """A path to store temporary files used during installation"""
+
+ self.repo_name = repo_name or DEFAULT_DRAGON_REPO
+ """The name of a repository to install from, e.g. DragonHPC/dragon"""
+
+ self.pkg_version = version or DEFAULT_DRAGON_VERSION
+ """The version to install, e.g. 0.10"""
+
+ self._check()
+
+ def _check(self) -> None:
+ """Perform validation of this instance
+
+ :raises ValueError: if any value fails validation"""
+ if not self.repo_name or len(self.repo_name.split("/")) != 2:
+ raise ValueError(
+ f"Invalid dragon repository name. Example: `dragonhpc/dragon`"
+ )
+
+ # version must match standard dragon tag & filename format `vX.YZ`
+ match = re.match(r"^\d\.\d+$", self.pkg_version)
+ if not self.pkg_version or not match:
+ raise ValueError("Invalid dragon version. Examples: `0.9, 0.91, 0.10`")
+
+ # attempting to retrieve from a non-default repository requires an auth token
+ if self.repo_name.lower() != DEFAULT_DRAGON_REPO.lower() and not self.raw_token:
+ raise ValueError(
+ f"An access token must be available to access {self.repo_name}. "
+ f"Set the `{_GH_TOKEN}` env var to pass your access token."
+ )
+
+ @property
+ def raw_token(self) -> t.Optional[str]:
+ """Returns the raw access token from the environment, if available"""
+ return os.environ.get(_GH_TOKEN, None)
+
+
+def get_auth_token(request: DragonInstallRequest) -> t.Optional[Token]:
+ """Create a Github.Auth.Token if an access token can be found
+ in the environment
+
+ :param request: details of a request for the installation of the dragon package
+ :returns: an auth token if one can be built, otherwise `None`"""
+ if gh_token := request.raw_token:
+ return Token(gh_token)
+ return None
+
+
+def create_dotenv(dragon_root_dir: pathlib.Path, dragon_version: str) -> None:
"""Create a .env file with required environment variables for the Dragon runtime"""
dragon_root = str(dragon_root_dir)
- dragon_inc_dir = str(dragon_root_dir / "include")
- dragon_lib_dir = str(dragon_root_dir / "lib")
- dragon_bin_dir = str(dragon_root_dir / "bin")
+ dragon_inc_dir = dragon_root + "/include"
+ dragon_lib_dir = dragon_root + "/lib"
+ dragon_bin_dir = dragon_root + "/bin"
dragon_vars = {
"DRAGON_BASE_DIR": dragon_root,
- "DRAGON_ROOT_DIR": dragon_root, # note: same as base_dir
+ "DRAGON_ROOT_DIR": dragon_root,
"DRAGON_INCLUDE_DIR": dragon_inc_dir,
"DRAGON_LIB_DIR": dragon_lib_dir,
- "DRAGON_VERSION": dragon_pin(),
+ "DRAGON_VERSION": dragon_version,
"PATH": dragon_bin_dir,
"LD_LIBRARY_PATH": dragon_lib_dir,
}
@@ -48,12 +124,6 @@ def python_version() -> str:
return f"py{sys.version_info.major}.{sys.version_info.minor}"
-def dragon_pin() -> str:
- """Return a string indicating the pinned major/minor version of the dragon
- package to install"""
- return "0.9"
-
-
def _platform_filter(asset_name: str) -> bool:
"""Return True if the asset name matches naming standard for current
platform (Cray or non-Cray). Otherwise, returns False.
@@ -75,67 +145,125 @@ def _version_filter(asset_name: str) -> bool:
return python_version() in asset_name
-def _pin_filter(asset_name: str) -> bool:
+def _pin_filter(asset_name: str, dragon_version: str) -> bool:
"""Return true if the supplied value contains a dragon version pin match
- :param asset_name: A value to inspect for keywords indicating a dragon version
+ :param asset_name: the asset name to inspect for keywords indicating a dragon version
+ :param dragon_version: the dragon version to match
:returns: True if supplied value is correct for current dragon version"""
- return f"dragon-{dragon_pin()}" in asset_name
+ return f"dragon-{dragon_version}" in asset_name
+
+
+def _get_all_releases(dragon_repo: Repository) -> t.Collection[GitRelease]:
+ """Retrieve all available releases for the configured dragon repository
+
+ :param dragon_repo: A GitHub repository object for the dragon package
+ :returns: A list of GitRelease"""
+    all_releases = list(dragon_repo.get_releases())
+ return all_releases
-def _get_release_assets() -> t.Collection[GitReleaseAsset]:
+def _get_release_assets(request: DragonInstallRequest) -> t.Collection[GitReleaseAsset]:
"""Retrieve a collection of available assets for all releases that satisfy
the dragon version pin
+ :param request: details of a request for the installation of the dragon package
:returns: A collection of release assets"""
- git = Github()
-
- dragon_repo = git.get_repo("DragonHPC/dragon")
+ auth = get_auth_token(request)
+ git = Github(auth=auth)
+ dragon_repo = git.get_repo(request.repo_name)
if dragon_repo is None:
raise SmartSimCLIActionCancelled("Unable to locate dragon repo")
- # find any releases matching our pinned version requirement
- tags = [tag for tag in dragon_repo.get_tags() if dragon_pin() in tag.name]
- # repo.get_latest_release fails if only pre-release results are returned
- pin_releases = list(dragon_repo.get_release(tag.name) for tag in tags)
- releases = sorted(pin_releases, key=lambda r: r.published_at, reverse=True)
+ all_releases = sorted(
+ _get_all_releases(dragon_repo), key=lambda r: r.published_at, reverse=True
+ )
- # take the most recent release for the given pin
- assets = releases[0].assets
+ # filter the list of releases to include only the target version
+ releases = [
+ release
+ for release in all_releases
+        if request.pkg_version in release.title or request.pkg_version in release.tag_name
+ ]
+
+ releases = sorted(releases, key=lambda r: r.published_at, reverse=True)
+
+ if not releases:
+ release_titles = ", ".join(release.title for release in all_releases)
+ raise SmartSimCLIActionCancelled(
+ f"Unable to find a release for dragon version {request.pkg_version}. "
+ f"Available releases: {release_titles}"
+ )
+
+ assets: t.List[GitReleaseAsset] = []
+
+ # install the latest release of the target version (including pre-release)
+ for release in releases:
+ # delay in attaching release assets may leave us with an empty list, retry
+ # with the next available release
+ if assets := list(release.get_assets()):
+ logger.debug(f"Found assets for dragon release {release.title}")
+ break
+ else:
+ logger.debug(f"No assets for dragon release {release.title}. Retrying.")
+
+ if not assets:
+ raise SmartSimCLIActionCancelled(
+ f"Unable to find assets for dragon release {release.title}"
+ )
return assets
-def filter_assets(assets: t.Collection[GitReleaseAsset]) -> t.Optional[GitReleaseAsset]:
+def filter_assets(
+ request: DragonInstallRequest, assets: t.Collection[GitReleaseAsset]
+) -> t.Optional[GitReleaseAsset]:
"""Filter the available release assets so that HSTA agents are used
when run on a Cray EX platform
+ :param request: details of a request for the installation of the dragon package
:param assets: The collection of dragon release assets to filter
:returns: An asset meeting platform & version filtering requirements"""
# Expect cray & non-cray assets that require a filter, e.g.
# 'dragon-0.8-py3.9.4.1-bafaa887f.tar.gz',
# 'dragon-0.8-py3.9.4.1-CRAYEX-ac132fe95.tar.gz'
- asset = next(
- (
- asset
- for asset in assets
- if _version_filter(asset.name)
- and _platform_filter(asset.name)
- and _pin_filter(asset.name)
- ),
- None,
+ all_assets = [asset.name for asset in assets]
+
+ assets = list(
+ asset
+ for asset in assets
+ if _version_filter(asset.name) and _pin_filter(asset.name, request.pkg_version)
)
+
+ if len(assets) == 0:
+ available = "\n\t".join(all_assets)
+ logger.warning(
+ f"Please specify a dragon version (e.g. {DEFAULT_DRAGON_VERSION}) "
+ f"of an asset available in the repository:\n\t{available}"
+ )
+ return None
+
+ asset: t.Optional[GitReleaseAsset] = None
+
+ # Apply platform filter if we have multiple matches for python/dragon version
+ if len(assets) > 0:
+ asset = next((asset for asset in assets if _platform_filter(asset.name)), None)
+
+ if not asset:
+ asset = assets[0]
+ logger.warning(f"Platform-specific package not found. Using {asset.name}")
+
return asset
-def retrieve_asset_info() -> GitReleaseAsset:
+def retrieve_asset_info(request: DragonInstallRequest) -> GitReleaseAsset:
"""Find a release asset that meets all necessary filtering criteria
- :param dragon_pin: identify the dragon version to install (e.g. dragon-0.8)
+ :param request: details of a request for the installation of the dragon package
:returns: A GitHub release asset"""
- assets = _get_release_assets()
- asset = filter_assets(assets)
+ assets = _get_release_assets(request)
+ asset = filter_assets(request, assets)
platform_result = check_platform()
if not platform_result.is_cray:
@@ -150,43 +278,79 @@ def retrieve_asset_info() -> GitReleaseAsset:
return asset
-def retrieve_asset(working_dir: pathlib.Path, asset: GitReleaseAsset) -> pathlib.Path:
+def retrieve_asset(
+ request: DragonInstallRequest, asset: GitReleaseAsset
+) -> pathlib.Path:
"""Retrieve the physical file associated to a given GitHub release asset
- :param working_dir: location in file system where assets should be written
+ :param request: details of a request for the installation of the dragon package
:param asset: GitHub release asset to retrieve
- :returns: path to the downloaded asset"""
- if working_dir.exists() and list(working_dir.rglob("*.whl")):
- return working_dir
+ :returns: path to the directory containing the extracted release asset
+ :raises SmartSimCLIActionCancelled: if the asset cannot be downloaded or extracted
+ """
+ download_dir = request.working_dir / str(asset.id)
+
+    # remove any files left over from a previous download of this asset
+    # before downloading and extracting it again
+    cleanup(download_dir)
+ download_dir.mkdir(parents=True, exist_ok=True)
+
+ # grab a copy of the complete asset
+ asset_path = download_dir / str(asset.name)
+
+ # use the asset URL instead of the browser_download_url to enable
+ # using auth for private repositories
+ headers: t.Dict[str, str] = {"Accept": "application/octet-stream"}
+
+ if request.raw_token:
+ headers["Authorization"] = f"Bearer {request.raw_token}"
+
+ try:
+ # a github asset endpoint causes a redirect. the first request
+ # receives a pre-signed URL to the asset to pass on to retrieve
+ dl_request = Request(asset.url, headers=headers)
+ response = urlopen(dl_request)
+ presigned_url = response.url
+
+ logger.debug(f"Retrieved asset {asset.name} metadata from {asset.url}")
+ except Exception:
+ logger.exception(f"Unable to download {asset.name} from: {asset.url}")
+ presigned_url = asset.url
+
+ # extract the asset
+ try:
+ retrieve(presigned_url, asset_path)
- archive = WebTGZ(asset.browser_download_url)
- archive.extract(working_dir)
+ logger.debug(f"Extracted {asset.name} to {download_dir}")
+ except Exception as ex:
+ raise SmartSimCLIActionCancelled(
+ f"Unable to extract {asset.name} from {download_dir}"
+ ) from ex
- logger.debug(f"Retrieved {asset.browser_download_url} to {working_dir}")
- return working_dir
+ return download_dir
-def install_package(asset_dir: pathlib.Path) -> int:
+def install_package(request: DragonInstallRequest, asset_dir: pathlib.Path) -> int:
"""Install the package found in `asset_dir` into the current python environment
- :param asset_dir: path to a decompressed archive contents for a release asset"""
- wheels = asset_dir.rglob("*.whl")
- wheel_path = next(wheels, None)
- if not wheel_path:
- logger.error(f"No wheel found for package in {asset_dir}")
+ :param request: details of a request for the installation of the dragon package
+ :param asset_dir: path to a decompressed archive contents for a release asset
+ :returns: Integer return code, 0 for success, non-zero on failures"""
+ found_wheels = list(asset_dir.rglob("*.whl"))
+ if not found_wheels:
+ logger.error(f"No wheel(s) found for package in {asset_dir}")
return 1
- create_dotenv(wheel_path.parent)
-
- while wheel_path is not None:
- logger.info(f"Installing package: {wheel_path.absolute()}")
+ create_dotenv(found_wheels[0].parent, request.pkg_version)
- try:
- pip("install", "--force-reinstall", str(wheel_path), "numpy<2")
- wheel_path = next(wheels, None)
- except Exception:
- logger.error(f"Unable to install from {asset_dir}")
- return 1
+ try:
+ wheels = list(map(str, found_wheels))
+ for wheel_path in wheels:
+ logger.info(f"Installing package: {wheel_path}")
+ pip("install", wheel_path)
+ except Exception:
+ logger.error(f"Unable to install from {asset_dir}")
+ return 1
return 0
@@ -197,36 +361,83 @@ def cleanup(
"""Delete the downloaded asset and any files extracted during installation
:param archive_path: path to a downloaded archive for a release asset"""
- if archive_path:
- archive_path.unlink(missing_ok=True)
- logger.debug(f"Deleted archive: {archive_path}")
+ if not archive_path:
+ return
+
+ if archive_path.exists() and archive_path.is_file():
+ archive_path.unlink()
+ archive_path = archive_path.parent
+ if archive_path.exists() and archive_path.is_dir():
+ shutil.rmtree(archive_path, ignore_errors=True)
+ logger.debug(f"Deleted temporary files in: {archive_path}")
-def install_dragon(extraction_dir: t.Union[str, os.PathLike[str]]) -> int:
+
+def install_dragon(request: DragonInstallRequest) -> int:
"""Retrieve a dragon runtime appropriate for the current platform
and install to the current python environment
- :param extraction_dir: path for download and extraction of assets
+
+ :param request: details of a request for the installation of the dragon package
:returns: Integer return code, 0 for success, non-zero on failures"""
if sys.platform == "darwin":
logger.debug(f"Dragon not supported on platform: {sys.platform}")
return 1
- extraction_dir = pathlib.Path(extraction_dir)
- filename: t.Optional[pathlib.Path] = None
asset_dir: t.Optional[pathlib.Path] = None
try:
- asset_info = retrieve_asset_info()
- asset_dir = retrieve_asset(extraction_dir, asset_info)
+ asset_info = retrieve_asset_info(request)
+ if asset_info is not None:
+ asset_dir = retrieve_asset(request, asset_info)
+ return install_package(request, asset_dir)
- return install_package(asset_dir)
+ except SmartSimCLIActionCancelled as ex:
+ logger.warning(*ex.args)
except Exception as ex:
- logger.error("Unable to install dragon runtime", exc_info=ex)
- finally:
- cleanup(filename)
+ logger.error("Unable to install dragon runtime", exc_info=True)
return 2
+def display_post_install_logs() -> None:
+ """Display post-installation instructions for the user"""
+
+ examples = {
+ "ofi-include": "/opt/cray/include",
+ "ofi-build-lib": "/opt/cray/lib64",
+ "ofi-runtime-lib": "/opt/cray/lib64",
+ }
+
+ config = ":".join(f"{k}={v}" for k, v in examples.items())
+ example_msg1 = f"dragon-config -a \\"
+ example_msg2 = f' "{config}"'
+
+ logger.info(
+ "************************** Dragon Package Installed *****************************"
+ )
+ logger.info("To enable Dragon to use HSTA (default: TCP), configure the following:")
+
+ for key in examples:
+ logger.info(f"\t{key}")
+
+ logger.info("Example:")
+ logger.info(example_msg1)
+ logger.info(example_msg2)
+ logger.info(
+ "*********************************************************************************"
+ )
+
+
if __name__ == "__main__":
- sys.exit(install_dragon(CONFIG.core_path / ".dragon"))
+ # path for download and extraction of assets
+ extraction_dir = CONFIG.core_path / ".dragon"
+ dragon_repo = DEFAULT_DRAGON_REPO
+ dragon_version = DEFAULT_DRAGON_VERSION
+
+ request = DragonInstallRequest(
+ extraction_dir,
+ dragon_repo,
+ dragon_version,
+ )
+
+ sys.exit(install_dragon(request))
diff --git a/smartsim/_core/_cli/validate.py b/smartsim/_core/_cli/validate.py
index 16b6ec4ea8..a87642e49f 100644
--- a/smartsim/_core/_cli/validate.py
+++ b/smartsim/_core/_cli/validate.py
@@ -33,7 +33,7 @@
from types import TracebackType
from smartsim._core._cli.utils import SMART_LOGGER_FORMAT
-from smartsim._core._install.builder import Device
+from smartsim._core._install.platform import Device
from smartsim.log import get_logger
logger = get_logger("Smart", fmt=SMART_LOGGER_FORMAT)
@@ -69,7 +69,9 @@ def __exit__(
self._finalizer.detach() # type: ignore[attr-defined]
-def execute(args: argparse.Namespace) -> int:
+def execute(
+ args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None
+) -> int:
"""Validate the SmartSim installation works as expected given a
simple experiment
"""
diff --git a/smartsim/_core/_install/buildenv.py b/smartsim/_core/_install/buildenv.py
index ca52520695..552f9e28b0 100644
--- a/smartsim/_core/_install/buildenv.py
+++ b/smartsim/_core/_install/buildenv.py
@@ -53,30 +53,6 @@ class SetupError(Exception):
"""
-class VersionConflictError(SetupError):
- """An error for when version numbers of some library/package/program/etc
- do not match and build may not be able to continue
- """
-
- def __init__(
- self,
- name: str,
- current_version: "Version_",
- target_version: "Version_",
- msg: t.Optional[str] = None,
- ) -> None:
- if msg is None:
- msg = (
- f"Incompatible version for {name} detected: "
- f"{name} {target_version} requested but {name} {current_version} "
- "installed."
- )
- super().__init__(msg)
- self.name = name
- self.current_version = current_version
- self.target_version = target_version
-
-
# so as to not conflict with pkg_resources.packaging.version.Version
# pylint: disable-next=invalid-name
class Version_(str):
@@ -183,58 +159,29 @@ class Versioner:
PYTHON_MIN = Version_("3.9.0")
# Versions
- SMARTSIM = Version_(get_env("SMARTSIM_VERSION", "0.7.0"))
+ SMARTSIM = Version_(get_env("SMARTSIM_VERSION", "0.8.0"))
SMARTSIM_SUFFIX = get_env("SMARTSIM_SUFFIX", "")
- # ML/DL
- # torch can be set by the user because we download that for them
- TORCH = Version_(get_env("SMARTSIM_TORCH", "2.0.1"))
- TORCHVISION = Version_(get_env("SMARTSIM_TORCHVIS", "0.15.2"))
- TORCH_CPU_SUFFIX = Version_(get_env("TORCH_CPU_SUFFIX", "+cpu"))
- TORCH_CUDA_SUFFIX = Version_(get_env("TORCH_CUDA_SUFFIX", "+cu117"))
-
- # TensorFlow and ONNX only use the defaults
+ # Redis
+ REDIS = Version_(get_env("SMARTSIM_REDIS", "7.2.4"))
+ REDIS_URL = get_env("SMARTSIM_REDIS_URL", "https://github.com/redis/redis.git")
+ REDIS_BRANCH = get_env("SMARTSIM_REDIS_BRANCH", REDIS)
- TENSORFLOW = Version_("2.13.1")
- ONNX = Version_("1.14.1")
+ # RedisAI
+ REDISAI = "1.2.7"
+ REDISAI_URL = get_env(
+ "SMARTSIM_REDISAI_URL", "https://github.com/RedisAI/RedisAI.git"
+ )
+ REDISAI_BRANCH = get_env("SMARTSIM_REDISAI_BRANCH", f"v{REDISAI}")
- def as_dict(self) -> t.Dict[str, t.Tuple[str, ...]]:
+ def as_dict(self, db_name: DbEngine = "REDIS") -> t.Dict[str, t.Tuple[str, ...]]:
pkg_map = {
"SMARTSIM": self.SMARTSIM,
- "TORCH": self.TORCH,
- "TENSORFLOW": self.TENSORFLOW,
- "ONNX": self.ONNX,
+ db_name: self.REDIS,
+ "REDISAI": self.REDISAI,
}
return {"Packages": tuple(pkg_map), "Versions": tuple(pkg_map.values())}
- # TODO add a backend for ml libraries
- def ml_extras_required(self) -> t.Dict[str, t.List[str]]:
- """Optional ML/DL dependencies we suggest for the user."""
- ml_defaults = {
- "torch": self.TORCH,
- "tensorflow": self.TENSORFLOW,
- "onnx": self.ONNX,
- "skl2onnx": "1.16.0",
- "onnxmltools": "1.12.0",
- "scikit-learn": "1.3.2",
- "torchvision": "0.15.2",
- "torch_cpu_suffix": "+cpu",
- "torch_cuda_suffix": "+cu117",
- }
-
- # remove torch-related fields as they are subject to change
- # by having the user change hardware (cpu/gpu)
- _torch_fields = [
- "torch",
- "torchvision",
- "torch_cpu_suffix",
- "torch_cuda_suffix",
- ]
- for field in _torch_fields:
- ml_defaults.pop(field)
-
- return {"ml": [f"{lib}=={vers}" for lib, vers in ml_defaults.items()]}
-
@staticmethod
def get_sha(setup_py_dir: Path) -> str:
"""Get the git sha of the current branch"""
@@ -304,7 +251,7 @@ def __init__(self, checks: bool = True) -> None:
self.check_dependencies()
def check_dependencies(self) -> None:
- deps = ["git", "git-lfs", "make", "wget", "cmake", self.CC, self.CXX]
+ deps = ["git", "make", "wget", "cmake", self.CC, self.CXX]
if int(self.CHECKS) == 0:
for dep in deps:
self.check_build_dependency(dep)
@@ -417,23 +364,6 @@ def check_build_dependency(command: str) -> None:
except OSError:
raise SetupError(f"{command} must be installed to build SmartSim") from None
- @classmethod
- def check_installed(
- cls, package: str, version: t.Optional[Version_] = None
- ) -> bool:
- """Check if a package is installed. If version is provided, check if
- it's a compatible version. (major and minor the same)
- """
- try:
- installed = cls.get_py_package_version(package)
- except importlib.metadata.PackageNotFoundError:
- return False
- if version:
- # detect if major or minor versions differ
- if installed.major != version.major or installed.minor != version.minor:
- raise VersionConflictError(package, installed, version)
- return True
-
@staticmethod
def get_py_package_version(package: str) -> Version_:
return Version_(importlib.metadata.version(package))
diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py
index 8cda07ede5..a1a4cb93b5 100644
--- a/smartsim/_core/_install/builder.py
+++ b/smartsim/_core/_install/builder.py
@@ -26,28 +26,15 @@
# pylint: disable=too-many-lines
-import enum
-import fileinput
-import itertools
import os
-import platform
import re
import shutil
import stat
import subprocess
-import tarfile
-import tempfile
import typing as t
-import urllib.request
-import zipfile
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
from pathlib import Path
-from shutil import which
from subprocess import SubprocessError
-# NOTE: This will be imported by setup.py and hence no smartsim related
-# items should be imported into this file.
# TODO: check cmake version and use system if possible to avoid conflicts
@@ -56,66 +43,10 @@
_U = t.TypeVar("_U")
-def expand_exe_path(exe: str) -> str:
- """Takes an executable and returns the full path to that executable
-
- :param exe: executable or file
- :raises TypeError: if file is not an executable
- :raises FileNotFoundError: if executable cannot be found
- """
-
- # which returns none if not found
- in_path = which(exe)
- if not in_path:
- if os.path.isfile(exe) and os.access(exe, os.X_OK):
- return os.path.abspath(exe)
- if os.path.isfile(exe) and not os.access(exe, os.X_OK):
- raise TypeError(f"File, {exe}, is not an executable")
- raise FileNotFoundError(f"Could not locate executable {exe}")
- return os.path.abspath(in_path)
-
-
class BuildError(Exception):
pass
-class Architecture(enum.Enum):
- X64 = ("x86_64", "amd64")
- ARM64 = ("arm64",)
-
- @classmethod
- def from_str(cls, string: str, /) -> "Architecture":
- string = string.lower()
- for type_ in cls:
- if string in type_.value:
- return type_
- raise BuildError(f"Unrecognized or unsupported architecture: {string}")
-
-
-class Device(enum.Enum):
- CPU = "cpu"
- GPU = "gpu"
-
-
-class OperatingSystem(enum.Enum):
- LINUX = ("linux", "linux2")
- DARWIN = ("darwin",)
-
- @classmethod
- def from_str(cls, string: str, /) -> "OperatingSystem":
- string = string.lower()
- for type_ in cls:
- if string in type_.value:
- return type_
- raise BuildError(f"Unrecognized or unsupported operating system: {string}")
-
-
-class Platform(t.NamedTuple):
- os: OperatingSystem
- architecture: Architecture
-
-
-# TODO: Add FeatureStoreBuilder member
class Builder:
"""Base class for building third-party libraries"""
@@ -133,13 +64,10 @@ def __init__(
self,
env: t.Dict[str, str],
jobs: int = 1,
- _os: OperatingSystem = OperatingSystem.from_str(platform.system()),
- architecture: Architecture = Architecture.from_str(platform.machine()),
verbose: bool = False,
) -> None:
# build environment from buildenv
self.env = env
- self._platform = Platform(_os, architecture)
# Find _core directory and set up paths
_core_dir = Path(os.path.abspath(__file__)).parent.parent
@@ -174,11 +102,6 @@ def out(self) -> t.Optional[int]:
def is_built(self) -> bool:
raise NotImplementedError
- def build_from_git(
- self, git_url: str, branch: str, device: Device = Device.CPU
- ) -> None:
- raise NotImplementedError
-
@staticmethod
def binary_path(binary: str) -> str:
binary_ = shutil.which(binary)
@@ -239,281 +162,3 @@ def run_command(
raise BuildError(error)
except (OSError, SubprocessError) as e:
raise BuildError(e) from e
-
-
-class _WebLocation(ABC):
- @property
- @abstractmethod
- def url(self) -> str: ...
-
-
-class _WebGitRepository(_WebLocation):
- def clone(
- self,
- target: _PathLike,
- depth: t.Optional[int] = None,
- branch: t.Optional[str] = None,
- ) -> None:
- depth_ = ("--depth", str(depth)) if depth is not None else ()
- branch_ = ("--branch", branch) if branch is not None else ()
- _git("clone", "-q", *depth_, *branch_, self.url, os.fspath(target))
-
-
-@t.final
-@dataclass(frozen=True)
-class _DLPackRepository(_WebGitRepository):
- version: str
-
- @staticmethod
- def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
- return (
- (OperatingSystem.LINUX, Architecture.X64),
- (OperatingSystem.DARWIN, Architecture.X64),
- (OperatingSystem.DARWIN, Architecture.ARM64),
- )
-
- @property
- def url(self) -> str:
- return ""
-
-
-class _WebArchive(_WebLocation):
- @property
- def name(self) -> str:
- _, name = self.url.rsplit("/", 1)
- return name
-
- def download(self, target: _PathLike) -> Path:
- target = Path(target)
- if target.is_dir():
- target = target / self.name
- file, _ = urllib.request.urlretrieve(self.url, target)
- return Path(file).resolve()
-
-
-class _ExtractableWebArchive(_WebArchive, ABC):
- @abstractmethod
- def _extract_download(self, download_path: Path, target: _PathLike) -> None: ...
-
- def extract(self, target: _PathLike) -> None:
- with tempfile.TemporaryDirectory() as tmp_dir:
- arch_path = self.download(tmp_dir)
- self._extract_download(arch_path, target)
-
-
-class _WebTGZ(_ExtractableWebArchive):
- def _extract_download(self, download_path: Path, target: _PathLike) -> None:
- with tarfile.open(download_path, "r") as tgz_file:
- tgz_file.extractall(target)
-
-
-class _WebZip(_ExtractableWebArchive):
- def _extract_download(self, download_path: Path, target: _PathLike) -> None:
- with zipfile.ZipFile(download_path, "r") as zip_file:
- zip_file.extractall(target)
-
-
-class WebTGZ(_WebTGZ):
- def __init__(self, url: str) -> None:
- self._url = url
-
- @property
- def url(self) -> str:
- return self._url
-
-
-@dataclass(frozen=True)
-class _PTArchive(_WebZip):
- architecture: Architecture
- device: Device
- version: str
- with_mkl: bool
-
- @staticmethod
- def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
- # TODO: This will need to be revisited if the inheritance tree gets deeper
- return tuple(
- itertools.chain.from_iterable(
- var.supported_platforms() for var in _PTArchive.__subclasses__()
- )
- )
-
- @staticmethod
- def _patch_out_mkl(libtorch_root: Path) -> None:
- _modify_source_files(
- libtorch_root / "share/cmake/Caffe2/public/mkl.cmake",
- r"find_package\(MKL QUIET\)",
- "# find_package(MKL QUIET)",
- )
-
- def extract(self, target: _PathLike) -> None:
- super().extract(target)
- if not self.with_mkl:
- self._patch_out_mkl(Path(target))
-
-
-@t.final
-class _PTArchiveLinux(_PTArchive):
- @staticmethod
- def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
- return ((OperatingSystem.LINUX, Architecture.X64),)
-
- @property
- def url(self) -> str:
- if self.device == Device.GPU:
- pt_build = "cu117"
- else:
- pt_build = Device.CPU.value
- # pylint: disable-next=line-too-long
- libtorch_archive = (
- f"libtorch-cxx11-abi-shared-without-deps-{self.version}%2B{pt_build}.zip"
- )
- return f"https://download.pytorch.org/libtorch/{pt_build}/{libtorch_archive}"
-
-
-@t.final
-class _PTArchiveMacOSX(_PTArchive):
- @staticmethod
- def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
- return (
- (OperatingSystem.DARWIN, Architecture.ARM64),
- (OperatingSystem.DARWIN, Architecture.X64),
- )
-
- @property
- def url(self) -> str:
- if self.architecture == Architecture.X64:
- pt_build = Device.CPU.value
- libtorch_archive = f"libtorch-macos-{self.version}.zip"
- root_url = "https://download.pytorch.org/libtorch"
- return f"{root_url}/{pt_build}/{libtorch_archive}"
- if self.architecture == Architecture.ARM64:
- libtorch_archive = f"libtorch-macos-arm64-{self.version}.zip"
- # pylint: disable-next=line-too-long
- root_url = (
- "https://github.com/CrayLabs/ml_lib_builder/releases/download/v0.1/"
- )
- return f"{root_url}/{libtorch_archive}"
-
- raise BuildError(f"Unsupported architecture for Pytorch: {self.architecture}")
-
-
-def _choose_pt_variant(
- os_: OperatingSystem,
-) -> t.Union[t.Type[_PTArchiveLinux], t.Type[_PTArchiveMacOSX]]:
- if os_ == OperatingSystem.DARWIN:
- return _PTArchiveMacOSX
- if os_ == OperatingSystem.LINUX:
- return _PTArchiveLinux
-
- raise BuildError(f"Unsupported OS for PyTorch: {os_}")
-
-
-@t.final
-@dataclass(frozen=True)
-class _TFArchive(_WebTGZ):
- os_: OperatingSystem
- architecture: Architecture
- device: Device
- version: str
-
- @staticmethod
- def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
- return (
- (OperatingSystem.LINUX, Architecture.X64),
- (OperatingSystem.DARWIN, Architecture.X64),
- )
-
- @property
- def url(self) -> str:
- if self.architecture == Architecture.X64:
- tf_arch = "x86_64"
- else:
- raise BuildError(
- f"Unexpected Architecture for TF Archive: {self.architecture}"
- )
-
- if self.os_ == OperatingSystem.LINUX:
- tf_os = "linux"
- tf_device = self.device
- elif self.os_ == OperatingSystem.DARWIN:
- tf_os = "darwin"
- tf_device = Device.CPU
- else:
- raise BuildError(f"Unexpected OS for TF Archive: {self.os_}")
- return (
- "https://storage.googleapis.com/tensorflow/libtensorflow/"
- f"libtensorflow-{tf_device.value}-{tf_os}-{tf_arch}-{self.version}.tar.gz"
- )
-
-
-@t.final
-@dataclass(frozen=True)
-class _ORTArchive(_WebTGZ):
- os_: OperatingSystem
- device: Device
- version: str
-
- @staticmethod
- def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
- return (
- (OperatingSystem.LINUX, Architecture.X64),
- (OperatingSystem.DARWIN, Architecture.X64),
- )
-
- @property
- def url(self) -> str:
- ort_url_base = (
- "https://github.com/microsoft/onnxruntime/releases/"
- f"download/v{self.version}"
- )
- if self.os_ == OperatingSystem.LINUX:
- ort_os = "linux"
- ort_arch = "x64"
- ort_build = "-gpu" if self.device == Device.GPU else ""
- elif self.os_ == OperatingSystem.DARWIN:
- ort_os = "osx"
- ort_arch = "x86_64"
- ort_build = ""
- else:
- raise BuildError(f"Unexpected OS for TF Archive: {self.os_}")
- ort_archive = f"onnxruntime-{ort_os}-{ort_arch}{ort_build}-{self.version}.tgz"
- return f"{ort_url_base}/{ort_archive}"
-
-
-def _git(*args: str) -> None:
- git = Builder.binary_path("git")
- cmd = (git,) + args
- with subprocess.Popen(cmd) as proc:
- proc.wait()
- if proc.returncode != 0:
- raise BuildError(
- f"Command `{' '.join(cmd)}` failed with exit code {proc.returncode}"
- )
-
-
-def config_git_command(plat: Platform, cmd: t.Sequence[str]) -> t.List[str]:
- """Modify git commands to include autocrlf when on a platform that needs
- autocrlf enabled to behave correctly
- """
- cmd = list(cmd)
- where = next((i for i, tok in enumerate(cmd) if tok.endswith("git")), len(cmd)) + 2
- if where >= len(cmd):
- raise ValueError(f"Failed to locate git command in '{' '.join(cmd)}'")
- if plat == Platform(OperatingSystem.DARWIN, Architecture.ARM64):
- cmd = (
- cmd[:where]
- + ["--config", "core.autocrlf=false", "--config", "core.eol=lf"]
- + cmd[where:]
- )
- return cmd
-
-
-def _modify_source_files(
- files: t.Union[_PathLike, t.Iterable[_PathLike]], regex: str, replacement: str
-) -> None:
- compiled_regex = re.compile(regex)
- with fileinput.input(files=files, inplace=True) as handles:
- for line in handles:
- line = compiled_regex.sub(replacement, line)
- print(line, end="")
diff --git a/smartsim/_core/_install/configs/mlpackages/DarwinARM64CPU.json b/smartsim/_core/_install/configs/mlpackages/DarwinARM64CPU.json
new file mode 100644
index 0000000000..5109cf376c
--- /dev/null
+++ b/smartsim/_core/_install/configs/mlpackages/DarwinARM64CPU.json
@@ -0,0 +1,59 @@
+{
+ "platform": {
+ "operating_system":"darwin",
+ "architecture":"arm64",
+ "device":"cpu"
+ },
+ "ml_packages": [
+ {
+ "name": "dlpack",
+ "version": "v0.5_RAI",
+ "pip_index": "",
+ "python_packages": [],
+ "lib_source": "https://github.com/RedisAI/dlpack.git"
+ },
+ {
+ "name": "libtorch",
+ "version": "2.4.0",
+ "pip_index": "",
+ "python_packages": [
+ "torch==2.4.0",
+ "torchvision==0.19.0",
+ "torchaudio==2.4.0"
+ ],
+ "lib_source": "https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.4.0.zip",
+ "rai_patches": [
+ {
+ "description": "Patch RedisAI module to require C++17 standard instead of C++14",
+ "source_file": "src/backends/libtorch_c/CMakeLists.txt",
+ "regex": "set_property\\(TARGET\\storch_c\\sPROPERTY\\sCXX_STANDARD\\s(98|11|14)\\)",
+ "replacement": "set_property(TARGET torch_c PROPERTY CXX_STANDARD 17)"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input inputs",
+ "replacement": "TF_Output inputs"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input port",
+ "replacement": "TF_Output port"
+ }
+ ]
+ },
+ {
+ "name": "onnxruntime",
+ "version": "1.17.3",
+ "pip_index": "",
+ "python_packages": [
+ "onnx==1.15",
+ "skl2onnx",
+ "scikit-learn",
+ "onnxmltools"
+ ],
+ "lib_source": "https://github.com/microsoft/onnxruntime/releases/download/v1.17.3/onnxruntime-osx-arm64-1.17.3.tgz"
+ }
+ ]
+}
diff --git a/smartsim/_core/_install/configs/mlpackages/DarwinX64CPU.json b/smartsim/_core/_install/configs/mlpackages/DarwinX64CPU.json
new file mode 100644
index 0000000000..06e30cbf8b
--- /dev/null
+++ b/smartsim/_core/_install/configs/mlpackages/DarwinX64CPU.json
@@ -0,0 +1,68 @@
+{
+ "platform": {
+ "operating_system":"darwin",
+ "architecture":"x86_64",
+ "device":"cpu"
+ },
+ "ml_packages": [
+ {
+ "name": "dlpack",
+ "version": "v0.5_RAI",
+ "pip_index": "",
+ "python_packages": [],
+ "lib_source": "https://github.com/RedisAI/dlpack.git"
+ },
+ {
+ "name": "libtorch",
+ "version": "2.2.2",
+ "pip_index": "",
+ "python_packages": [
+ "torch==2.2.2",
+ "torchvision==0.17.2",
+ "torchaudio==2.2.2"
+ ],
+ "lib_source": "https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.2.zip",
+ "rai_patches": [
+ {
+ "description": "Patch RedisAI module to require C++17 standard instead of C++14",
+ "source_file": "src/backends/libtorch_c/CMakeLists.txt",
+ "regex": "set_property\\(TARGET\\storch_c\\sPROPERTY\\sCXX_STANDARD\\s(98|11|14)\\)",
+ "replacement": "set_property(TARGET torch_c PROPERTY CXX_STANDARD 17)"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input inputs",
+ "replacement": "TF_Output inputs"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input port",
+ "replacement": "TF_Output port"
+ }
+ ]
+ },
+ {
+ "name": "libtensorflow",
+ "version": "2.15",
+ "pip_index": "",
+ "python_packages": [
+ "tensorflow==2.15"
+ ],
+ "lib_source": "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.15.0.tar.gz"
+ },
+ {
+ "name": "onnxruntime",
+ "version": "1.17.3",
+ "pip_index": "",
+ "python_packages": [
+ "onnx==1.15",
+ "skl2onnx",
+ "scikit-learn",
+ "onnxmltools"
+ ],
+ "lib_source": "https://github.com/microsoft/onnxruntime/releases/download/v1.17.3/onnxruntime-osx-x86_64-1.17.3.tgz"
+ }
+ ]
+}
diff --git a/smartsim/_core/_install/configs/mlpackages/LinuxX64CPU.json b/smartsim/_core/_install/configs/mlpackages/LinuxX64CPU.json
new file mode 100644
index 0000000000..2b1224df46
--- /dev/null
+++ b/smartsim/_core/_install/configs/mlpackages/LinuxX64CPU.json
@@ -0,0 +1,68 @@
+{
+ "platform": {
+ "operating_system":"linux",
+ "architecture":"x86_64",
+ "device":"cpu"
+ },
+ "ml_packages": [
+ {
+ "name": "dlpack",
+ "version": "v0.5_RAI",
+ "pip_index": "",
+ "python_packages": [],
+ "lib_source": "https://github.com/RedisAI/dlpack.git"
+ },
+ {
+ "name": "libtorch",
+ "version": "2.4.0",
+ "pip_index": "https://download.pytorch.org/whl/cpu",
+ "python_packages": [
+ "torch==2.4.0+cpu",
+ "torchvision==0.19.0+cpu",
+ "torchaudio==2.4.0+cpu"
+ ],
+ "lib_source": "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcpu.zip",
+ "rai_patches": [
+ {
+ "description": "Patch RedisAI module to require C++17 standard instead of C++14",
+ "source_file": "src/backends/libtorch_c/CMakeLists.txt",
+ "regex": "set_property\\(TARGET\\storch_c\\sPROPERTY\\sCXX_STANDARD\\s(98|11|14)\\)",
+ "replacement": "set_property(TARGET torch_c PROPERTY CXX_STANDARD 17)"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input inputs",
+ "replacement": "TF_Output inputs"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input port",
+ "replacement": "TF_Output port"
+ }
+ ]
+ },
+ {
+ "name": "libtensorflow",
+ "version": "2.15",
+ "pip_index": "",
+ "python_packages": [
+ "tensorflow==2.15"
+ ],
+ "lib_source": "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.15.0.tar.gz"
+ },
+ {
+ "name": "onnxruntime",
+ "version": "1.17.3",
+ "pip_index": "",
+ "python_packages": [
+ "onnx<=1.15",
+ "skl2onnx",
+ "scikit-learn",
+ "onnxmltools"
+ ],
+ "lib_source": "https://github.com/microsoft/onnxruntime/releases/download/v1.17.3/onnxruntime-linux-x64-1.17.3.tgz"
+ }
+ ]
+}
diff --git a/smartsim/_core/_install/configs/mlpackages/LinuxX64CUDA11.json b/smartsim/_core/_install/configs/mlpackages/LinuxX64CUDA11.json
new file mode 100644
index 0000000000..30d9cbf516
--- /dev/null
+++ b/smartsim/_core/_install/configs/mlpackages/LinuxX64CUDA11.json
@@ -0,0 +1,68 @@
+{
+ "platform": {
+ "operating_system":"linux",
+ "architecture":"x86_64",
+ "device":"cuda-11"
+ },
+ "ml_packages": [
+ {
+ "name": "dlpack",
+ "version": "v0.5_RAI",
+ "pip_index": "",
+ "python_packages": [],
+ "lib_source": "https://github.com/RedisAI/dlpack.git"
+ },
+ {
+ "name": "libtorch",
+ "version": "2.3.1",
+ "pip_index": "https://download.pytorch.org/whl/cu118",
+ "python_packages": [
+ "torch==2.3.1+cu118",
+ "torchvision==0.18.1+cu118",
+ "torchaudio==2.3.1+cu118"
+ ],
+ "lib_source": "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.3.1%2Bcu118.zip",
+ "rai_patches": [
+ {
+ "description": "Patch RedisAI module to require C++17 standard instead of C++14",
+ "source_file": "src/backends/libtorch_c/CMakeLists.txt",
+ "regex": "set_property\\(TARGET\\storch_c\\sPROPERTY\\sCXX_STANDARD\\s(98|11|14)\\)",
+ "replacement": "set_property(TARGET torch_c PROPERTY CXX_STANDARD 17)"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input inputs",
+ "replacement": "TF_Output inputs"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input port",
+ "replacement": "TF_Output port"
+ }
+ ]
+ },
+ {
+ "name": "libtensorflow",
+ "version": "2.14.1",
+ "pip_index": "",
+ "python_packages": [
+ "tensorflow==2.14.1"
+ ],
+ "lib_source": "https://github.com/CrayLabs/ml_lib_builder/releases/download/v0.2/libtensorflow-2.14.1-linux-x64-cuda-11.8.0.tgz"
+ },
+ {
+ "name": "onnxruntime",
+ "version": "1.17.3",
+ "pip_index": "",
+ "python_packages": [
+ "onnx==1.15",
+ "skl2onnx",
+ "scikit-learn",
+ "onnxmltools"
+ ],
+ "lib_source": "https://github.com/microsoft/onnxruntime/releases/download/v1.17.3/onnxruntime-linux-x64-gpu-1.17.3.tgz"
+ }
+ ]
+}
diff --git a/smartsim/_core/_install/configs/mlpackages/LinuxX64CUDA12.json b/smartsim/_core/_install/configs/mlpackages/LinuxX64CUDA12.json
new file mode 100644
index 0000000000..a8bf330b4f
--- /dev/null
+++ b/smartsim/_core/_install/configs/mlpackages/LinuxX64CUDA12.json
@@ -0,0 +1,76 @@
+{
+ "platform": {
+ "operating_system":"linux",
+ "architecture":"x86_64",
+ "device":"cuda-12"
+ },
+ "ml_packages": [
+ {
+ "name": "dlpack",
+ "version": "v0.5_RAI",
+ "pip_index": "",
+ "python_packages": [],
+ "lib_source": "https://github.com/RedisAI/dlpack.git"
+ },
+ {
+ "name": "libtorch",
+ "version": "2.3.1",
+ "pip_index": "https://download.pytorch.org/whl/cu121",
+ "python_packages": [
+ "torch==2.3.1+cu121",
+ "torchvision==0.18.1+cu121",
+ "torchaudio==2.3.1+cu121"
+ ],
+ "lib_source": "https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.3.1%2Bcu121.zip",
+ "rai_patches": [
+ {
+ "description": "Patch RedisAI module to require C++17 standard instead of C++14",
+ "source_file": "src/backends/libtorch_c/CMakeLists.txt",
+ "regex": "set_property\\(TARGET\\storch_c\\sPROPERTY\\sCXX_STANDARD\\s(98|11|14)\\)",
+ "replacement": "set_property(TARGET torch_c PROPERTY CXX_STANDARD 17)"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input inputs",
+ "replacement": "TF_Output inputs"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input port",
+ "replacement": "TF_Output port"
+ }
+ ]
+ },
+ {
+ "name": "libtensorflow",
+ "version": "2.15",
+ "pip_index": "",
+ "python_packages": [
+ "tensorflow==2.15"
+ ],
+ "lib_source": "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.15.0.tar.gz",
+ "rai_patches": [
+ {
+ "description": "Patch RedisAI to point to correct tsl directory",
+ "source_file": "CMakeLists.txt",
+ "regex": "INCLUDE_DIRECTORIES\\(\\$\\{depsAbs\\}/libtensorflow/include\\)",
+ "replacement": "INCLUDE_DIRECTORIES(${depsAbs}/libtensorflow/include ${depsAbs}/libtensorflow/include/external/local_tsl)"
+ }
+ ]
+ },
+ {
+ "name": "onnxruntime",
+ "version": "1.17.3",
+ "pip_index": "",
+ "python_packages": [
+ "onnx==1.15",
+ "skl2onnx",
+ "scikit-learn",
+ "onnxmltools"
+ ],
+ "lib_source": "https://github.com/microsoft/onnxruntime/releases/download/v1.17.3/onnxruntime-linux-x64-gpu-cuda12-1.17.3.tgz"
+ }
+ ]
+}
diff --git a/smartsim/_core/_install/configs/mlpackages/LinuxX64ROCM6.json b/smartsim/_core/_install/configs/mlpackages/LinuxX64ROCM6.json
new file mode 100644
index 0000000000..ba3c9a0bfb
--- /dev/null
+++ b/smartsim/_core/_install/configs/mlpackages/LinuxX64ROCM6.json
@@ -0,0 +1,59 @@
+{
+ "platform": {
+ "operating_system":"linux",
+ "architecture":"x86_64",
+ "device":"rocm-6"
+ },
+ "ml_packages": [
+ {
+ "name": "dlpack",
+ "version": "v0.5_RAI",
+ "pip_index": "",
+ "python_packages": [],
+ "lib_source": "https://github.com/RedisAI/dlpack.git"
+ },
+ {
+ "name": "libtorch",
+ "version": "2.4.0",
+ "pip_index": "https://download.pytorch.org/whl/rocm6.1",
+ "python_packages": [
+ "torch==2.4.0+rocm6.1",
+ "torchvision==0.19.0+rocm6.1",
+ "torchaudio==2.4.0+rocm6.1"
+ ],
+ "lib_source": "https://download.pytorch.org/libtorch/rocm6.1/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Brocm6.1.zip",
+ "rai_patches": [
+ {
+ "description": "Patch RedisAI module to require C++17 standard instead of C++14",
+ "source_file": "src/backends/libtorch_c/CMakeLists.txt",
+ "regex": "set_property\\(TARGET\\storch_c\\sPROPERTY\\sCXX_STANDARD\\s(98|11|14)\\)",
+ "replacement": "set_property(TARGET torch_c PROPERTY CXX_STANDARD 17)"
+ },
+ {
+ "description": "Fix Regex, Load HIP",
+ "source_file": "../package/libtorch/share/cmake/Caffe2/public/LoadHIP.cmake",
+ "regex": ".*string.*",
+ "replacement": ""
+ },
+ {
+ "description": "Replace `/opt/rocm` with `$ENV{ROCM_PATH}`",
+ "source_file": "../package/libtorch/share/cmake/Caffe2/Caffe2Targets.cmake",
+ "regex": "/opt/rocm",
+ "replacement": "$ENV{ROCM_PATH}"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input inputs",
+ "replacement": "TF_Output inputs"
+ },
+ {
+ "description": "Fix the type in a Tensorflow function signature",
+ "source_file": "src/backends/tensorflow.c",
+ "regex": "TF_Input port",
+ "replacement": "TF_Output port"
+ }
+ ]
+ }
+ ]
+}
diff --git a/smartsim/_core/_install/mlpackages.py b/smartsim/_core/_install/mlpackages.py
new file mode 100644
index 0000000000..04e3798d35
--- /dev/null
+++ b/smartsim/_core/_install/mlpackages.py
@@ -0,0 +1,198 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import os
+import pathlib
+import re
+import subprocess
+import sys
+import typing as t
+from collections.abc import MutableMapping
+from dataclasses import dataclass
+
+from tabulate import tabulate
+
+from .platform import Platform
+from .types import PathLike
+from .utils import retrieve
+
+
+class RequireRelativePath(Exception):
+ pass
+
+
+@dataclass
+class RAIPatch:
+ """Holds information about how to patch a RedisAI source file
+
+ :param description: Human-readable description of the patch's purpose
+ :param replacement: The replacement for the line found by the regex
+ :param source_file: A relative path to the chosen file
+ :param regex: A regex pattern to match in the given file
+
+ """
+
+ description: str
+ replacement: str
+ source_file: pathlib.Path
+ regex: re.Pattern[str]
+
+ def __post_init__(self) -> None:
+ self.source_file = pathlib.Path(self.source_file)
+ self.regex = re.compile(self.regex)
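+
+ # A hedged illustration (values copied from the bundled platform JSON files):
+ # the dictionaries under "rai_patches" map directly onto this dataclass via
+ # keyword arguments, and __post_init__ compiles the regex, e.g.
+ #
+ #   patch = RAIPatch(
+ #       description="Fix the type in a Tensorflow function signature",
+ #       replacement="TF_Output inputs",
+ #       source_file="src/backends/tensorflow.c",
+ #       regex="TF_Input inputs",
+ #   )
+ #   patch.regex.sub(patch.replacement, "TF_Input inputs")  # -> "TF_Output inputs"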
+
+
+@dataclass
+class MLPackage:
+ """Describes the python and C/C++ library for an ML package"""
+
+ name: str
+ version: str
+ pip_index: str
+ python_packages: t.List[str]
+ lib_source: PathLike
+ rai_patches: t.Tuple[RAIPatch, ...] = ()
+
+ def retrieve(self, destination: PathLike) -> None:
+ """Retrieve an archive and/or repository for the package
+
+ :param destination: Path to place the extracted package or repository
+ """
+ retrieve(self.lib_source, pathlib.Path(destination))
+
+ def pip_install(self, quiet: bool = False) -> None:
+ """Install associated python packages
+
+ :param quiet: If True, suppress most of the pip output, defaults to False
+ """
+ if self.python_packages:
+ install_command = [sys.executable, "-m", "pip", "install"]
+ if self.pip_index:
+ install_command += ["--index-url", self.pip_index]
+ if quiet:
+ install_command += ["--quiet", "--no-warn-conflicts"]
+ install_command += self.python_packages
+ subprocess.check_call(install_command)
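+
+ # Illustrative sketch only (package pins taken from the bundled
+ # LinuxX64CPU.json config): for that platform the assembled command is roughly
+ #   python -m pip install --index-url https://download.pytorch.org/whl/cpu \
+ #       torch==2.4.0+cpu torchvision==0.19.0+cpu torchaudio==2.4.0+cpu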
+
+
+class MLPackageCollection(MutableMapping[str, MLPackage]):
+ """Collects multiple MLPackages
+
+ Defines a collection of MLPackages available for a specific platform
+ """
+
+ def __init__(self, platform: Platform, ml_packages: t.Sequence[MLPackage]):
+ self.platform = platform
+ self._ml_packages = {pkg.name: pkg for pkg in ml_packages}
+
+ @classmethod
+ def from_json_file(cls, json_file: PathLike) -> "MLPackageCollection":
+ """Create an MLPackageCollection specified from a JSON file
+
+ :param json_file: path to the JSON file
+ :return: An instance of MLPackageCollection for a platform
+ """
+ with open(json_file, "r", encoding="utf-8") as file_handle:
+ config_json = json.load(file_handle)
+ platform = Platform.from_strs(**config_json["platform"])
+
+ for ml_package in config_json["ml_packages"]:
+ # Convert the dictionary representation to a RAIPatch
+ if "rai_patches" in ml_package:
+ patch_list = ml_package.pop("rai_patches")
+ ml_package["rai_patches"] = [RAIPatch(**patch) for patch in patch_list]
+
+ ml_packages = [
+ MLPackage(**ml_package) for ml_package in config_json["ml_packages"]
+ ]
+ return cls(platform, ml_packages)
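+
+ # A usage sketch (the file name references one of the configs shipped in
+ # this change; any platform JSON with the same schema works):
+ #   packages = MLPackageCollection.from_json_file(
+ #       DEFAULT_MLPACKAGE_PATH / "LinuxX64CPU.json"
+ #   )
+ #   packages["libtorch"].version  # -> "2.4.0"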
+
+ def __iter__(self) -> t.Iterator[str]:
+ """Iterate over the mlpackages in the collection
+
+ :return: Iterator over mlpackages
+ """
+ return iter(self._ml_packages)
+
+ def __getitem__(self, key: str) -> MLPackage:
+ """Retrieve an MLPackage based on its name
+
+ :param key: Name of the python package (e.g. libtorch)
+ :return: MLPackage with all requirements
+ """
+ return self._ml_packages[key]
+
+ def __len__(self) -> int:
+ return len(self._ml_packages)
+
+ def __delitem__(self, key: str) -> None:
+ del self._ml_packages[key]
+
+ def __setitem__(self, key: t.Any, value: t.Any) -> t.NoReturn:
+ raise TypeError(f"{type(self).__name__} does not support item assignment")
+
+ def __contains__(self, key: object) -> bool:
+ return key in self._ml_packages
+
+ def __str__(self, tablefmt: str = "github") -> str:
+ """Display package names and versions as a table
+
+ :param tablefmt: Tabulate format, defaults to "github"
+ """
+
+ return tabulate(
+ [[k, v.version] for k, v in self._ml_packages.items()],
+ headers=["Package", "Version"],
+ tablefmt=tablefmt,
+ )
+
+
+def load_platform_configs(
+ config_file_path: pathlib.Path,
+) -> t.Dict[Platform, MLPackageCollection]:
+ """Create MLPackageCollections from JSON files in directory
+
+ :param config_file_path: Directory with JSON files describing the
+ configuration by platform
+ :return: Dictionary whose keys are the supported platforms and whose
+ values are their associated MLPackageCollections
+ """
+ if not config_file_path.is_dir():
+ path = os.fspath(config_file_path)
+ msg = f"Platform configuration directory `{path}` does not exist"
+ raise FileNotFoundError(msg)
+ configs = {}
+ for config_file in config_file_path.glob("*.json"):
+ dependencies = MLPackageCollection.from_json_file(config_file)
+ configs[dependencies.platform] = dependencies
+ return configs
+
+
+DEFAULT_MLPACKAGE_PATH: t.Final = (
+ pathlib.Path(__file__).parent / "configs" / "mlpackages"
+)
+DEFAULT_MLPACKAGES: t.Final = load_platform_configs(DEFAULT_MLPACKAGE_PATH)
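+
+# A hedged example of how these defaults are expected to be consumed; the
+# platform values mirror the "platform" block of the bundled JSON configs:
+#
+#   plat = Platform.from_strs("linux", "x86_64", "cpu")
+#   collection = DEFAULT_MLPACKAGES[plat]
+#   print(collection)  # package/version table rendered by tabulate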
diff --git a/smartsim/_core/_install/platform.py b/smartsim/_core/_install/platform.py
new file mode 100644
index 0000000000..bef13c6a0a
--- /dev/null
+++ b/smartsim/_core/_install/platform.py
@@ -0,0 +1,226 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import enum
+import json
+import os
+import pathlib
+import platform
+import typing as t
+from dataclasses import dataclass
+
+from typing_extensions import Self
+
+
+class PlatformError(Exception):
+ pass
+
+
+class UnsupportedError(PlatformError):
+ pass
+
+
+class Architecture(enum.Enum):
+ """Identifiers for supported CPU architectures
+
+ :return: An enum representing the CPU architecture
+ """
+
+ X64 = "x86_64"
+ ARM64 = "arm64"
+
+ @classmethod
+ def from_str(cls, string: str) -> "Architecture":
+ """Return enum associated with the architecture
+
+ :param string: String representing the architecture, see platform.machine
+ :return: Enum for a specific architecture
+ """
+ string = string.lower()
+ return cls(string)
+
+ @classmethod
+ def autodetect(cls) -> "Architecture":
+ """Automatically return the architecture of the current machine
+
+ :return: enum of this platform's architecture
+ """
+ return cls.from_str(platform.machine())
+
+
+class Device(enum.Enum):
+ """Identifiers for the device stack
+
+ :return: Enum associated with the device stack
+ """
+
+ CPU = "cpu"
+ CUDA11 = "cuda-11"
+ CUDA12 = "cuda-12"
+ ROCM5 = "rocm-5"
+ ROCM6 = "rocm-6"
+
+ @classmethod
+ def from_str(cls, str_: str) -> "Device":
+ """Return enum associated with the device
+
+ :param str_: String representing the device and version
+ :return: Enum for a specific device
+ """
+ str_ = str_.lower()
+ if str_ == "gpu":
+ # TODO: auto detect which device to use
+ # currently hard coded to `cuda11`
+ return cls.CUDA11
+ return cls(str_)
+
+ @classmethod
+ def detect_cuda_version(cls) -> t.Optional["Device"]:
+ """Find the enum based on environment CUDA
+
+ :return: Enum for the version of CUDA currently available
+ """
+ if cuda_home := os.environ.get("CUDA_HOME"):
+ cuda_path = pathlib.Path(cuda_home)
+ with open(cuda_path / "version.json", "r", encoding="utf-8") as file_handle:
+ cuda_versions = json.load(file_handle)
+ major = cuda_versions["cuda"]["version"].split(".")[0]
+ return cls.from_str(f"cuda-{major}")
+ return None
+
+ @classmethod
+ def detect_rocm_version(cls) -> t.Optional["Device"]:
+ """Find the enum based on environment ROCm
+
+ :return: Enum for the version of ROCm currently available
+ """
+ if rocm_home := os.environ.get("ROCM_HOME"):
+ rocm_path = pathlib.Path(rocm_home)
+ fname = rocm_path / ".info" / "version"
+ with open(fname, "r", encoding="utf-8") as file_handle:
+ major = file_handle.readline().split("-")[0].split(".")[0]
+ return cls.from_str(f"rocm-{major}")
+ return None
+
+ def is_gpu(self) -> bool:
+ """Whether the enum is categorized as a GPU
+
+ :return: True if GPU
+ """
+ return self != type(self).CPU
+
+ def is_cuda(self) -> bool:
+ """Whether the enum is associated with a CUDA device
+
+ :return: True for any supported CUDA enums
+ """
+ cls = type(self)
+ return self in cls.cuda_enums()
+
+ def is_rocm(self) -> bool:
+ """Whether the enum is associated with a ROCm device
+
+ :return: True for any supported ROCm enums
+ """
+ cls = type(self)
+ return self in cls.rocm_enums()
+
+ @classmethod
+ def cuda_enums(cls) -> t.Tuple["Device", ...]:
+ """Detect all CUDA devices supported by SmartSim
+
+ :return: all enums associated with CUDA
+ """
+ return tuple(device for device in cls if "cuda" in device.value)
+
+ @classmethod
+ def rocm_enums(cls) -> t.Tuple["Device", ...]:
+ """Detect all ROCm devices supported by SmartSim
+
+ :return: all enums associated with ROCm
+ """
+ return tuple(device for device in cls if "rocm" in device.value)
+
+
+class OperatingSystem(enum.Enum):
+ """Enum for all supported operating systems"""
+
+ LINUX = "linux"
+ DARWIN = "darwin"
+
+ @classmethod
+ def from_str(cls, string: str, /) -> "OperatingSystem":
+ """Return enum associated with the OS
+
+ :param string: String representing the OS
+ :return: Enum for a specific OS
+ """
+ string = string.lower()
+ return cls(string)
+
+ @classmethod
+ def autodetect(cls) -> "OperatingSystem":
+ """Automatically return the OS of the current machine
+
+ :return: enum of this platform's OS
+ """
+ return cls.from_str(platform.system())
+
+
+@dataclass(frozen=True)
+class Platform:
+ """Container describing relevant identifiers for a platform"""
+
+ operating_system: OperatingSystem
+ architecture: Architecture
+ device: Device
+
+ @classmethod
+ def from_strs(cls, operating_system: str, architecture: str, device: str) -> Self:
+ """Factory method for Platform from string onput
+
+ :param os: String identifier for the OS
+ :param architecture: String identifier for the architecture
+ :param device: String identifer for the device and version
+ :return: Instance of Platform
+ """
+ return cls(
+ OperatingSystem.from_str(operating_system),
+ Architecture.from_str(architecture),
+ Device.from_str(device),
+ )
+
+ def __str__(self) -> str:
+ """Human-readable representation of Platform
+
+ :return: String created from the values of the enums for each property
+ """
+ output = [
+ self.operating_system.name,
+ self.architecture.name,
+ self.device.name,
+ ]
+ return "-".join(output)
diff --git a/smartsim/_core/_install/redisaiBuilder.py b/smartsim/_core/_install/redisaiBuilder.py
new file mode 100644
index 0000000000..1dce6ddb45
--- /dev/null
+++ b/smartsim/_core/_install/redisaiBuilder.py
@@ -0,0 +1,301 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import fileinput
+import os
+import pathlib
+import shutil
+import stat
+import subprocess
+import typing as t
+from collections import deque
+
+from smartsim._core._cli.utils import SMART_LOGGER_FORMAT
+from smartsim._core._install.buildenv import BuildEnv
+from smartsim._core._install.mlpackages import MLPackageCollection, RAIPatch
+from smartsim._core._install.platform import OperatingSystem, Platform
+from smartsim._core._install.utils import retrieve
+from smartsim._core.config import CONFIG
+from smartsim.log import get_logger
+
+logger = get_logger("Smart", fmt=SMART_LOGGER_FORMAT)
+_SUPPORTED_ROCM_ARCH = "gfx90a"
+
+
+class RedisAIBuildError(Exception):
+ pass
+
+
+class RedisAIBuilder:
+ """Class to build RedisAI from Source"""
+
+ def __init__(
+ self,
+ platform: Platform,
+ mlpackages: MLPackageCollection,
+ build_env: BuildEnv,
+ main_build_path: pathlib.Path,
+ verbose: bool = False,
+ source: t.Union[str, pathlib.Path] = "https://github.com/RedisAI/RedisAI.git",
+ version: str = "v1.2.7",
+ ) -> None:
+
+ self.platform = platform
+ self.mlpackages = mlpackages
+ self.build_env = build_env
+ self.verbose = verbose
+ self.source = source
+ self.version = version
+ self._root_path = main_build_path / "RedisAI"
+
+ self.cleanup_build()
+
+ @property
+ def src_path(self) -> pathlib.Path:
+ return pathlib.Path(self._root_path / "src")
+
+ @property
+ def build_path(self) -> pathlib.Path:
+ return pathlib.Path(self._root_path / "build")
+
+ @property
+ def package_path(self) -> pathlib.Path:
+ return pathlib.Path(self._root_path / "package")
+
+ def cleanup_build(self) -> None:
+ """Removes all directories associated with the build"""
+ shutil.rmtree(self.src_path, ignore_errors=True)
+ shutil.rmtree(self.build_path, ignore_errors=True)
+ shutil.rmtree(self.package_path, ignore_errors=True)
+
+ @property
+ def is_built(self) -> bool:
+ """Determine whether RedisAI and backends were built
+
+ :return: True if all backends and RedisAI module are in
+ the expected location
+ """
+ backend_dir = CONFIG.lib_path / "backends"
+ rai_exists = [
+ (backend_dir / f"redisai_{backend_name}").is_dir()
+ for backend_name in self.mlpackages
+ ]
+ rai_exists.append((CONFIG.lib_path / "redisai.so").is_file())
+ return all(rai_exists)
+
+ @property
+ def build_torch(self) -> bool:
+ """Whether to build torch backend
+
+ :return: True if torch backend should be built
+ """
+ return "libtorch" in self.mlpackages
+
+ @property
+ def build_tensorflow(self) -> bool:
+ """Whether to build tensorflow backend
+
+ :return: True if tensorflow backend should be built
+ """
+ return "libtensorflow" in self.mlpackages
+
+ @property
+ def build_onnxruntime(self) -> bool:
+ """Whether to build onnx backend
+
+ :return: True if onnx backend should be built
+ """
+ return "onnxruntime" in self.mlpackages
+
+ def build(self) -> None:
+ """Build RedisAI
+
+ :param git_url: url from which to retrieve RedisAI
+ :param branch: branch to checkout
+ :param device: cpu or gpu
+ """
+
+ # Following is needed to make sure that the clone/checkout is not
+ # impeded by git LFS limits imposed by RedisAI
+ os.environ["GIT_LFS_SKIP_SMUDGE"] = "1"
+
+ self.src_path.mkdir(parents=True)
+ self.build_path.mkdir(parents=True)
+ self.package_path.mkdir(parents=True)
+
+ retrieve(self.source, self.src_path, depth=1, branch=self.version)
+
+ self._prepare_packages()
+
+ for package in self.mlpackages.values():
+ self._patch_source_files(package.rai_patches)
+ cmake_command = self._rai_cmake_cmd()
+ build_command = self._rai_build_cmd
+
+ if self.platform.device.is_rocm() and "libtorch" in self.mlpackages:
+ pytorch_rocm_arch = os.environ.get("PYTORCH_ROCM_ARCH")
+ if not pytorch_rocm_arch:
+ logger.info(
+ f"PYTORCH_ROCM_ARCH not set. Defaulting to '{_SUPPORTED_ROCM_ARCH}'"
+ )
+ os.environ["PYTORCH_ROCM_ARCH"] = _SUPPORTED_ROCM_ARCH
+ elif pytorch_rocm_arch != _SUPPORTED_ROCM_ARCH:
+ logger.warning(
+ f"PYTORCH_ROCM_ARCH is not {_SUPPORTED_ROCM_ARCH} which is the "
+ "only officially supported architecture. This may still work "
+ "if you are supplying your own version of libtensorflow."
+ )
+
+ logger.info("Configuring CMake Build")
+ if self.verbose:
+ print(" ".join(cmake_command))
+ self.run_command(cmake_command, self.build_path)
+
+ logger.info("Building RedisAI")
+ if self.verbose:
+ print(" ".join(build_command))
+ self.run_command(build_command, self.build_path)
+
+ if self.platform.operating_system == OperatingSystem.LINUX:
+ self._set_execute(CONFIG.lib_path / "redisai.so")
+
+ @staticmethod
+ def _set_execute(target: pathlib.Path) -> None:
+ """Set execute permissions for file
+
+ :param target: The target file to add execute permission
+ """
+ permissions = os.stat(target).st_mode | stat.S_IXUSR
+ os.chmod(target, permissions)
+
+ @staticmethod
+ def _find_closest_object(
+ start_path: pathlib.Path, target_obj: str
+ ) -> t.Optional[pathlib.Path]:
+ queue = deque([start_path])
+ while queue:
+ current_dir = queue.popleft()
+ current_target = current_dir / target_obj
+ if current_target.exists():
+ return current_target.parent
+ for sub_dir in current_dir.iterdir():
+ if sub_dir.is_dir():
+ queue.append(sub_dir)
+ return None
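+
+ # Sketch of the intent (the directory layout is a hypothetical example):
+ # given package/onnxruntime-linux-x64-1.17.3/include/..., calling
+ # _find_closest_object(package_dir, "include") returns the
+ # .../onnxruntime-linux-x64-1.17.3 directory, i.e. the parent of the first
+ # "include" found by the breadth-first search.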
+
+ def _prepare_packages(self) -> None:
+ """Ensure that retrieved archives/packages are in the expected location
+
+ RedisAI requires that the root directory of the backend is at
+ DEP_PATH/example_backend. Due to difficulties in retrieval methods and
+ naming conventions from different sources, this cannot be standardized.
+ Instead we try to find the parent of the "include" directory and assume
+ this is the root.
+ """
+
+ for package in self.mlpackages.values():
+ logger.info(f"Retrieving package: {package.name} {package.version}")
+ target_dir = self.package_path / package.name
+ package.retrieve(target_dir)
+ # Move actual contents to root of the expected location
+ actual_root = self._find_closest_object(target_dir, "include")
+ if actual_root and actual_root != target_dir:
+ logger.debug(
+ f"Non-standard location found: {actual_root} -> {target_dir}"
+ )
+ for file in actual_root.iterdir():
+ file.rename(target_dir / file.name)
+
+ def run_command(self, cmd: t.Union[str, t.List[str]], cwd: pathlib.Path) -> None:
+ """Executor of commands usedi in the build
+
+ :param cmd: The actual command to execute
+ :param cwd: The working directory to execute in
+ """
+ stdout = None if self.verbose else subprocess.DEVNULL
+ stderr = None if self.verbose else subprocess.PIPE
+ proc = subprocess.run(
+ cmd, cwd=str(cwd), stdout=stdout, stderr=stderr, check=False
+ )
+ if proc.returncode != 0:
+ if stderr:
+ print(proc.stderr.decode("utf-8"))
+ raise RedisAIBuildError(
+ f"RedisAI build failed during command: {' '.join(cmd)}"
+ )
+
+ def _rai_cmake_cmd(self) -> t.List[str]:
+ """Build the CMake configuration command
+
+ :return: CMake command with correct options
+ """
+
+ def on_off(expression: bool) -> t.Literal["ON", "OFF"]:
+ return "ON" if expression else "OFF"
+
+ cmake_args = {
+ "BUILD_TF": on_off(self.build_tensorflow),
+ "BUILD_ORT": on_off(self.build_onnxruntime),
+ "BUILD_TORCH": on_off(self.build_torch),
+ "BUILD_TFLITE": "OFF",
+ "DEPS_PATH": str(self.package_path),
+ "DEVICE": "gpu" if self.platform.device.is_gpu() else "cpu",
+ "INSTALL_PATH": str(CONFIG.lib_path),
+ "CMAKE_C_COMPILER": self.build_env.CC,
+ "CMAKE_CXX_COMPILER": self.build_env.CXX,
+ }
+ if self.platform.device.is_rocm():
+ cmake_args["Torch_DIR"] = str(self.package_path / "libtorch")
+ cmd = ["cmake"]
+ cmd += (f"-D{key}={value}" for key, value in cmake_args.items())
+ cmd.append(str(self.src_path))
+ return cmd
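+
+ # For a CPU-only Linux build with all three backends enabled, the resulting
+ # command looks roughly like the following (a sketch; paths and compilers
+ # depend on CONFIG and the detected build environment):
+ #   cmake -DBUILD_TF=ON -DBUILD_ORT=ON -DBUILD_TORCH=ON -DBUILD_TFLITE=OFF \
+ #       -DDEPS_PATH=<root>/RedisAI/package -DDEVICE=cpu \
+ #       -DINSTALL_PATH=<smartsim lib path> -DCMAKE_C_COMPILER=<CC> \
+ #       -DCMAKE_CXX_COMPILER=<CXX> <root>/RedisAI/src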
+
+ @property
+ def _rai_build_cmd(self) -> t.List[str]:
+ """Shell command to build RedisAI and modules
+
+ With the CMake based install, very little needs to be done here.
+ "make install" is used to ensure that all resulting RedisAI backends
+ and their dependencies end up in the same location with the correct
+ RPATH if applicable.
+
+ :return: Command used to compile RedisAI and backends
+ """
+ return "make install -j VERBOSE=1".split(" ")
+
+ def _patch_source_files(self, patches: t.Tuple[RAIPatch, ...]) -> None:
+ """Apply specified RedisAI patches"""
+ for patch in patches:
+ with fileinput.input(
+ str(self.src_path / patch.source_file), inplace=True
+ ) as file_handle:
+ for line in file_handle:
+ line = patch.regex.sub(patch.replacement, line)
+ print(line, end="")
diff --git a/smartsim/_core/_install/types.py b/smartsim/_core/_install/types.py
new file mode 100644
index 0000000000..0266ace341
--- /dev/null
+++ b/smartsim/_core/_install/types.py
@@ -0,0 +1,30 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pathlib
+import typing as t
+
+PathLike = t.Union[str, pathlib.Path]
diff --git a/smartsim/_core/_install/utils/__init__.py b/smartsim/_core/_install/utils/__init__.py
new file mode 100644
index 0000000000..4e47cf282b
--- /dev/null
+++ b/smartsim/_core/_install/utils/__init__.py
@@ -0,0 +1,27 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from .retrieve import retrieve
diff --git a/smartsim/_core/_install/utils/retrieve.py b/smartsim/_core/_install/utils/retrieve.py
new file mode 100644
index 0000000000..fcac565d4b
--- /dev/null
+++ b/smartsim/_core/_install/utils/retrieve.py
@@ -0,0 +1,185 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import pathlib
+import shutil
+import tarfile
+import typing as t
+import zipfile
+from urllib.parse import urlparse
+from urllib.request import urlretrieve
+
+import git
+from tqdm import tqdm
+
+from smartsim._core._install.platform import Architecture, OperatingSystem
+from smartsim._core._install.types import PathLike
+
+
+class UnsupportedArchive(Exception):
+ pass
+
+
+class _TqdmUpTo(tqdm): # type: ignore[type-arg]
+ """Provides `update_to(n)` which uses `tqdm.update(delta_n)`
+
+ From the tqdm documentation for a progress bar when downloading
+ """
+
+ def update_to(
+ self, num_blocks: int = 1, bsize: int = 1, tsize: t.Optional[int] = None
+ ) -> t.Optional[bool]:
+ """Update progress in tqdm-like way
+
+ :param num_blocks: number of blocks transferred so far, defaults to 1
+ :param bsize: size of each block (in tqdm units), defaults to 1
+ :param tsize: total size (in tqdm units), defaults to None
+ :return: Update
+ """
+
+ if tsize is not None:
+ self.total = tsize
+ return self.update(num_blocks * bsize - self.n) # also sets self.n = num_blocks * bsize
+
+
+def _from_local_archive(
+ source: PathLike,
+ destination: pathlib.Path,
+ **kwargs: t.Any,
+) -> None:
+ """Decompress a local archive
+
+ :param source: Path to the archive on a local system
+ :param destination: Where to unpack the archive
+ """
+ if tarfile.is_tarfile(source):
+ with tarfile.open(source) as archive:
+ archive.extractall(path=destination, **kwargs)
+ if zipfile.is_zipfile(source):
+ with zipfile.ZipFile(source) as archive:
+ archive.extractall(path=destination, **kwargs)
+
+
+def _from_local_directory(
+ source: PathLike,
+ destination: pathlib.Path,
+ **kwargs: t.Any,
+) -> None:
+ """Copy the contents of a directory
+
+ :param source: source directory
+ :param destination: destination directory
+ """
+ shutil.copytree(source, destination, **kwargs)
+
+
+def _from_http(
+ source: str,
+ destination: pathlib.Path,
+ **kwargs: t.Any,
+) -> None:
+ """Download and decompress a package
+
+ :param source: URL to a particular package
+ :param destination: Where to unpack the archive
+ """
+ with _TqdmUpTo(
+ unit="B",
+ unit_scale=True,
+ unit_divisor=1024,
+ miniters=1,
+ desc=source.split("/")[-1],
+ ) as _t: # all optional kwargs
+ local_file, _ = urlretrieve(source, reporthook=_t.update_to, **kwargs)
+ _t.total = _t.n
+
+ _from_local_archive(local_file, destination)
+ os.remove(local_file)
+
+
+def _from_git(source: str, destination: pathlib.Path, **clone_kwargs: t.Any) -> None:
+ """Clone a repository
+
+ :param source: Path to the remote (URL or local) repository
+ :param destination: where to clone the repository
+ :param clone_kwargs: various options to send to the clone command
+ """
+ is_mac = OperatingSystem.autodetect() == OperatingSystem.DARWIN
+ is_arm64 = Architecture.autodetect() == Architecture.ARM64
+ if is_mac and is_arm64:
+ config_options = ["--config core.autocrlf=false", "--config core.eol=lf"]
+ allow_unsafe_options = True
+ else:
+ config_options = None
+ allow_unsafe_options = False
+ git.Repo.clone_from(
+ source,
+ destination,
+ multi_options=config_options,
+ allow_unsafe_options=allow_unsafe_options,
+ **clone_kwargs,
+ )
+
+
+def retrieve(
+ source: PathLike, destination: pathlib.Path, **retrieve_kwargs: t.Any
+) -> None:
+ """Primary method for retrieval
+
+ Automatically chooses the correct method based on the extension and/or
+ source of the archive. If downloaded, this will also decompress and
+ extract the archive.
+
+ :param source: URL or path to find the package
+ :param destination: where to place the package
+ :raises UnsupportedArchive: Unknown archive type
+ :raises FileNotFoundError: Path to archive does not exist
+ """
+ parsed_url = urlparse(str(source))
+ url_scheme = parsed_url.scheme
+ if parsed_url.path.endswith(".git"):
+ _from_git(str(source), destination, **retrieve_kwargs)
+ elif url_scheme == "http":
+ _from_http(str(source), destination, **retrieve_kwargs)
+ elif url_scheme == "https":
+ _from_http(str(source), destination, **retrieve_kwargs)
+ else: # This is probably a path
+ source_path = pathlib.Path(source)
+ if not source_path.exists():
+ raise FileNotFoundError(f"Package path or file does not exist: {source}")
+ if source_path.is_dir():
+ _from_local_directory(source, destination, **retrieve_kwargs)
+ elif source_path.is_file() and source_path.suffix in (
+ ".gz",
+ ".zip",
+ ".tgz",
+ ):
+ _from_local_archive(source, destination, **retrieve_kwargs)
+ else:
+ raise UnsupportedArchive(
+ f"Source ({source}) is not a supported archive or directory "
+ )
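+
+# Dispatch sketch; the URLs come from the bundled configs, while the local
+# path and destination are assumptions for illustration:
+#   dest = pathlib.Path("/tmp/smartsim-deps/dlpack")
+#   retrieve("https://github.com/RedisAI/dlpack.git", dest)    # git clone
+#   retrieve("https://github.com/microsoft/onnxruntime/releases/download/"
+#            "v1.17.3/onnxruntime-linux-x64-1.17.3.tgz", dest) # download + extract
+#   retrieve("/path/to/local/backend_dir", dest)                # directory copy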
diff --git a/smartsim/_core/commands/command_list.py b/smartsim/_core/commands/command_list.py
index 9554776e8d..d3d6eace4d 100644
--- a/smartsim/_core/commands/command_list.py
+++ b/smartsim/_core/commands/command_list.py
@@ -83,8 +83,8 @@ def __setitem__(
isinstance(item, str) for item in sublist.command
):
raise TypeError(
- "Value sublists must be a list of Commands when \
-assigning to a slice"
+ "Value sublists must be a list of Commands when assigning \
+to a slice"
)
self._commands[idx] = (deepcopy(val) for val in value)
diff --git a/smartsim/_core/config/config.py b/smartsim/_core/config/config.py
index af4bca6a79..478ab02da3 100644
--- a/smartsim/_core/config/config.py
+++ b/smartsim/_core/config/config.py
@@ -32,6 +32,7 @@
import psutil
+
# Configuration Values
#
# These values can be set through environment variables to
@@ -94,11 +95,14 @@ def database_file_parse_interval(self) -> int:
@property
def dragon_dotenv(self) -> Path:
"""Returns the path to a .env file containing dragon environment variables"""
- return self.conf_dir / "dragon" / ".env"
+ return Path(self.conf_dir / "dragon" / ".env")
@property
def dragon_server_path(self) -> t.Optional[str]:
- return os.getenv("SMARTSIM_DRAGON_SERVER_PATH", None)
+ return os.getenv(
+ "SMARTSIM_DRAGON_SERVER_PATH",
+ os.getenv("_SMARTSIM_DRAGON_SERVER_PATH_EXP", None),
+ )
@property
def dragon_server_timeout(self) -> int:
@@ -225,10 +229,6 @@ def smartsim_key_path(self) -> str:
default_path = Path.home() / ".smartsim" / "keys"
return os.environ.get("SMARTSIM_KEY_PATH", str(default_path))
- @property
- def dragon_pin(self) -> str:
- return "0.9"
-
@lru_cache(maxsize=128, typed=False)
def get_config() -> Config:
diff --git a/smartsim/_core/control/job.py b/smartsim/_core/control/job.py
index 5cf5aea8b6..bb8ab31ea5 100644
--- a/smartsim/_core/control/job.py
+++ b/smartsim/_core/control/job.py
@@ -76,8 +76,8 @@ def __init__(self) -> None:
@property
def is_fs(self) -> bool:
- """Returns `True` if the entity represents a feature store or
- feature store shard"""
+ """Returns `True` if the entity represents a feature store or feature
+ store shard"""
return self.type in ["featurestore", "fsnode"]
@property
diff --git a/smartsim/_core/entrypoints/service.py b/smartsim/_core/entrypoints/service.py
new file mode 100644
index 0000000000..719c2a60fe
--- /dev/null
+++ b/smartsim/_core/entrypoints/service.py
@@ -0,0 +1,185 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import datetime
+import time
+import typing as t
+from abc import ABC, abstractmethod
+
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class Service(ABC):
+ """Core API for standalone entrypoint scripts. Makes use of overridable hook
+ methods to modify behaviors (event loop, automatic shutdown, cooldown) as
+ well as simple hooks for status changes"""
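+
+ # A minimal, hypothetical subclass (not part of this module) sketching how
+ # the hooks compose:
+ #
+ #   class CounterService(Service):
+ #       def __init__(self) -> None:
+ #           super().__init__(as_service=True, cooldown=5, loop_delay=1)
+ #           self._count = 0
+ #
+ #       def _on_iteration(self) -> None:
+ #           self._count += 1
+ #
+ #       def _can_shutdown(self) -> bool:
+ #           return self._count >= 10
+ #
+ #   CounterService().execute()  # loops until _can_shutdown, then honors cooldown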
+
+ def __init__(
+ self,
+ as_service: bool = False,
+ cooldown: float = 0,
+ loop_delay: float = 0,
+ health_check_frequency: float = 0,
+ ) -> None:
+ """Initialize the Service
+
+ :param as_service: Determines lifetime of the service. When `True`, calling
+ execute on the service will run continuously until shutdown criteria are met.
+ Otherwise, `execute` performs a single pass through the service lifecycle and
+ automatically exits (regardless of the result of `_can_shutdown`).
+ :param cooldown: Period of time (in seconds) to allow the service to run
+ after a shutdown is permitted. Enables the service to avoid restarting if
+ new work is discovered. A value of 0 disables the cooldown.
+ :param loop_delay: Duration (in seconds) of a forced delay between
+ iterations of the event loop
+ :param health_check_frequency: Time (in seconds) between calls to a
+ health check handler. A value of 0 triggers the health check on every
+ iteration.
+ """
+ self._as_service = as_service
+ """Determines lifetime of the service. When `True`, calling
+ `execute` on the service will run continuously until shutdown criteria are met.
+ Otherwise, `execute` performs a single pass through the service lifecycle and
+ automatically exits (regardless of the result of `_can_shutdown`)."""
+ self._cooldown = abs(cooldown)
+ """Period of time (in seconds) to allow the service to run
+ after a shutdown is permitted. Enables the service to avoid restarting if
+ new work is discovered. A value of 0 disables the cooldown."""
+ self._loop_delay = abs(loop_delay)
+ """Duration (in seconds) of a forced delay between
+ iterations of the event loop"""
+ self._health_check_frequency = health_check_frequency
+ """Time (in seconds) between calls to a
+ health check handler. A value of 0 triggers the health check on every
+ iteration."""
+ self._last_health_check = time.time()
+ """The timestamp of the latest health check"""
+
+ @abstractmethod
+ def _on_iteration(self) -> None:
+ """The user-defined event handler. Executed repeatedly until shutdown
+ conditions are satisfied and cooldown is elapsed.
+ """
+
+ @abstractmethod
+ def _can_shutdown(self) -> bool:
+ """Return true when the criteria to shut down the service are met."""
+
+ def _on_start(self) -> None:
+ """Empty hook method for use by subclasses. Called on initial entry into
+ Service `execute` event loop before `_on_iteration` is invoked."""
+ logger.debug(f"Starting {self.__class__.__name__}")
+
+ def _on_shutdown(self) -> None:
+ """Empty hook method for use by subclasses. Called immediately after exiting
+ the main event loop during automatic shutdown."""
+ logger.debug(f"Shutting down {self.__class__.__name__}")
+
+ def _on_health_check(self) -> None:
+ """Empty hook method for use by subclasses. Invoked based on the
+ value of `self._health_check_frequency`."""
+ logger.debug(f"Performing health check for {self.__class__.__name__}")
+
+ def _on_cooldown_elapsed(self) -> None:
+ """Empty hook method for use by subclasses. Called on every event loop
+ iteration immediately upon exceeding the cooldown period"""
+ logger.debug(f"Cooldown exceeded by {self.__class__.__name__}")
+
+ def _on_delay(self) -> None:
+ """Empty hook method for use by subclasses. Called on every event loop
+ iteration immediately before executing a delay before the next iteration"""
+ logger.debug(f"Service iteration waiting for {self.__class__.__name__}s")
+
+ def _log_cooldown(self, elapsed: float) -> None:
+ """Log the remaining cooldown time, if any"""
+ remaining = self._cooldown - elapsed
+ if remaining > 0:
+ logger.debug(f"{abs(remaining):.2f}s remains of {self._cooldown}s cooldown")
+ else:
+ logger.info(f"exceeded cooldown {self._cooldown}s by {abs(remaining):.2f}s")
+
+ def execute(self) -> None:
+ """The main event loop of a service host. Evaluates shutdown criteria and
+ combines with a cooldown period to allow automatic service termination.
+ Responsible for executing calls to subclass implementation of `_on_iteration`"""
+
+ try:
+ self._on_start()
+ except Exception:
+ logger.exception("Unable to start service.")
+ return
+
+ running = True
+ cooldown_start: t.Optional[datetime.datetime] = None
+
+ while running:
+ try:
+ self._on_iteration()
+ except Exception:
+ running = False
+ logger.exception(
+ "Failure in event loop resulted in service termination"
+ )
+
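+ # run the health check hook once the configured interval has elapsed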
+ if self._health_check_frequency >= 0:
+ hc_elapsed = time.time() - self._last_health_check
+ if hc_elapsed >= self._health_check_frequency:
+ self._on_health_check()
+ self._last_health_check = time.time()
+
+ # allow immediate shutdown if not set to run as a service
+ if not self._as_service:
+ running = False
+ continue
+
+ # reset cooldown period if shutdown criteria are not met
+ if not self._can_shutdown():
+ cooldown_start = None
+
+ # start tracking cooldown elapsed once eligible to quit
+ if cooldown_start is None:
+ cooldown_start = datetime.datetime.now()
+
+ # change running state if cooldown period is exceeded
+ if self._cooldown > 0:
+ elapsed = datetime.datetime.now() - cooldown_start
+ running = elapsed.total_seconds() < self._cooldown
+ self._log_cooldown(elapsed.total_seconds())
+ if not running:
+ self._on_cooldown_elapsed()
+ elif self._cooldown < 1 and self._can_shutdown():
+ running = False
+
+ if self._loop_delay:
+ self._on_delay()
+ time.sleep(self._loop_delay)
+
+ try:
+ self._on_shutdown()
+ except Exception:
+ logger.exception("Service shutdown may not have completed.")
diff --git a/smartsim/_core/launcher/dragon/dragon_backend.py b/smartsim/_core/launcher/dragon/dragon_backend.py
index 7d77aaaacc..82863d73b5 100644
--- a/smartsim/_core/launcher/dragon/dragon_backend.py
+++ b/smartsim/_core/launcher/dragon/dragon_backend.py
@@ -26,6 +26,8 @@
import collections
import functools
import itertools
+import os
+import socket
import time
import typing as t
from dataclasses import dataclass, field
@@ -34,15 +36,27 @@
from tabulate import tabulate
-# pylint: disable=import-error
+# pylint: disable=import-error,C0302,R0915
# isort: off
+
import dragon.infrastructure.connection as dragon_connection
import dragon.infrastructure.policy as dragon_policy
-import dragon.native.group_state as dragon_group_state
+import dragon.infrastructure.process_desc as dragon_process_desc
+
import dragon.native.process as dragon_process
import dragon.native.process_group as dragon_process_group
import dragon.native.machine as dragon_machine
+from smartsim._core.launcher.dragon.pqueue import NodePrioritizer, PrioritizerFilter
+from smartsim._core.mli.infrastructure.control.listener import (
+ ConsumerRegistrationListener,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.dragon_util import create_ddict
+from smartsim.error.errors import SmartSimError
+
# pylint: enable=import-error
# isort: on
from ....log import get_logger
@@ -68,8 +82,8 @@
class DragonStatus(str, Enum):
- ERROR = str(dragon_group_state.Error())
- RUNNING = str(dragon_group_state.Running())
+ ERROR = "Error"
+ RUNNING = "Running"
def __str__(self) -> str:
return self.value
@@ -86,7 +100,7 @@ class ProcessGroupInfo:
return_codes: t.Optional[t.List[int]] = None
"""List of return codes of completed processes"""
hosts: t.List[str] = field(default_factory=list)
- """List of hosts on which the Process Group """
+ """List of hosts on which the Process Group should be executed"""
redir_workers: t.Optional[dragon_process_group.ProcessGroup] = None
"""Workers used to redirect stdout and stderr to file"""
@@ -143,6 +157,11 @@ class DragonBackend:
by threads spawned by it.
"""
+ _DEFAULT_NUM_MGR_PER_NODE = 2
+ """The default number of manager processes for each feature store node"""
+ _DEFAULT_MEM_PER_NODE = 512 * 1024**2
+ """The default memory capacity (in bytes) to allocate for a feaure store node"""
+
def __init__(self, pid: int) -> None:
self._pid = pid
"""PID of dragon executable which launched this server"""
@@ -153,7 +172,6 @@ def __init__(self, pid: int) -> None:
self._step_ids = (f"{create_short_id_str()}-{id}" for id in itertools.count())
"""Incremental ID to assign to new steps prior to execution"""
- self._initialize_hosts()
self._queued_steps: "collections.OrderedDict[str, DragonRunRequest]" = (
collections.OrderedDict()
)
@@ -177,16 +195,26 @@ def __init__(self, pid: int) -> None:
"""Whether the server frontend should shut down when the backend does"""
self._shutdown_initiation_time: t.Optional[float] = None
"""The time at which the server initiated shutdown"""
- smartsim_config = get_config()
- self._cooldown_period = (
- smartsim_config.telemetry_frequency * 2 + 5
- if smartsim_config.telemetry_enabled
- else 5
- )
- """Time in seconds needed to server to complete shutdown"""
+ self._cooldown_period = self._initialize_cooldown()
+ """Time in seconds needed by the server to complete shutdown"""
+ self._backbone: t.Optional[BackboneFeatureStore] = None
+ """The backbone feature store"""
+ self._listener: t.Optional[dragon_process.Process] = None
+ """The standalone process executing the event consumer"""
+
+ self._nodes: t.List["dragon_machine.Node"] = []
+ """Node capability information for hosts in the allocation"""
+ self._hosts: t.List[str] = []
+ """List of hosts available in allocation"""
+ self._cpus: t.List[int] = []
+ """List of cpu-count by node"""
+ self._gpus: t.List[int] = []
+ """List of gpu-count by node"""
+ self._allocated_hosts: t.Dict[str, t.Set[str]] = {}
+ """Mapping with hostnames as keys and a set of running step IDs as the value"""
- self._view = DragonBackendView(self)
- logger.debug(self._view.host_desc)
+ self._initialize_hosts()
+ self._prioritizer = NodePrioritizer(self._nodes, self._queue_lock)
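+ """Prioritizer used to select the least-referenced nodes for new steps"""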
@property
def hosts(self) -> list[str]:
@@ -194,34 +222,39 @@ def hosts(self) -> list[str]:
return self._hosts
@property
- def allocated_hosts(self) -> dict[str, str]:
+ def allocated_hosts(self) -> dict[str, t.Set[str]]:
+ """A map of host names to the step id executing on a host
+
+ :returns: Dictionary with host name as key and step id as value"""
with self._queue_lock:
return self._allocated_hosts
@property
- def free_hosts(self) -> t.Deque[str]:
+ def free_hosts(self) -> t.Sequence[str]:
+ """Find hosts that do not have a step assigned
+
+ :returns: List of host names"""
with self._queue_lock:
- return self._free_hosts
+ return list(map(lambda x: x.hostname, self._prioritizer.unassigned()))
@property
def group_infos(self) -> dict[str, ProcessGroupInfo]:
+ """Find information pertaining to process groups executing on a host
+
+ :returns: Dictionary with host name as key and group information as value"""
with self._queue_lock:
return self._group_infos
def _initialize_hosts(self) -> None:
+ """Prepare metadata about the allocation"""
with self._queue_lock:
self._nodes = [
dragon_machine.Node(node) for node in dragon_machine.System().nodes
]
- self._hosts: t.List[str] = sorted(node.hostname for node in self._nodes)
+ self._hosts = sorted(node.hostname for node in self._nodes)
self._cpus = [node.num_cpus for node in self._nodes]
self._gpus = [node.num_gpus for node in self._nodes]
-
- """List of hosts available in allocation"""
- self._free_hosts: t.Deque[str] = collections.deque(self._hosts)
- """List of hosts on which steps can be launched"""
- self._allocated_hosts: t.Dict[str, str] = {}
- """Mapping of hosts on which a step is already running to step ID"""
+ self._allocated_hosts = collections.defaultdict(set)
def __str__(self) -> str:
return self.status_message
@@ -230,21 +263,19 @@ def __str__(self) -> str:
def status_message(self) -> str:
"""Message with status of available nodes and history of launched jobs.
- :returns: Status message
+ :returns: a status message
"""
- return (
- "Dragon server backend update\n"
- f"{self._view.host_table}\n{self._view.step_table}"
- )
+ view = DragonBackendView(self)
+ return "Dragon server backend update\n" f"{view.host_table}\n{view.step_table}"
def _heartbeat(self) -> None:
+ """Update the value of the last heartbeat to the current time."""
self._last_beat = self.current_time
@property
def cooldown_period(self) -> int:
- """Time (in seconds) the server will wait before shutting down
-
- when exit conditions are met (see ``should_shutdown()`` for further details).
+ """Time (in seconds) the server will wait before shutting down when
+ exit conditions are met (see ``should_shutdown()`` for further details).
"""
return self._cooldown_period
@@ -278,6 +309,8 @@ def should_shutdown(self) -> bool:
and it requested immediate shutdown, or if it did not request immediate
shutdown, but all jobs have been executed.
In both cases, a cooldown period may need to be waited before shutdown.
+
+ :returns: `True` if the server should terminate, otherwise `False`
"""
if self._shutdown_requested and self._can_shutdown:
return self._has_cooled_down
@@ -285,7 +318,9 @@ def should_shutdown(self) -> bool:
@property
def current_time(self) -> float:
- """Current time for DragonBackend object, in seconds since the Epoch"""
+ """Current time for DragonBackend object, in seconds since the Epoch
+
+ :returns: the current timestamp"""
return time.time()
def _can_honor_policy(
@@ -293,63 +328,149 @@ def _can_honor_policy(
) -> t.Tuple[bool, t.Optional[str]]:
"""Check if the policy can be honored with resources available
in the allocation.
- :param request: DragonRunRequest containing policy information
+
+ :param request: `DragonRunRequest` to validate
:returns: Tuple indicating if the policy can be honored and
an optional error message"""
# ensure the policy can be honored
if request.policy:
+ logger.debug(f"{request.policy=}{self._cpus=}{self._gpus=}")
+
if request.policy.cpu_affinity:
# make sure some node has enough CPUs
- available = max(self._cpus)
+ last_available = max(self._cpus or [-1])
requested = max(request.policy.cpu_affinity)
-
- if requested >= available:
+ if not any(self._cpus) or requested >= last_available:
return False, "Cannot satisfy request, not enough CPUs available"
-
if request.policy.gpu_affinity:
# make sure some node has enough GPUs
- available = max(self._gpus)
+ last_available = max(self._gpus or [-1])
requested = max(request.policy.gpu_affinity)
-
- if requested >= available:
+ if not any(self._gpus) or requested >= last_available:
+ logger.warning(
+ f"failed check w/{self._gpus=}, {requested=}, {last_available=}"
+ )
return False, "Cannot satisfy request, not enough GPUs available"
-
return True, None
def _can_honor(self, request: DragonRunRequest) -> t.Tuple[bool, t.Optional[str]]:
- """Check if request can be honored with resources available in the allocation.
-
- Currently only checks for total number of nodes,
- in the future it will also look at other constraints
- such as memory, accelerators, and so on.
+ """Check if request can be honored with resources available in
+ the allocation. Currently only checks for total number of nodes,
+ in the future it will also look at other constraints such as memory,
+ accelerators, and so on.
+
+ :param request: `DragonRunRequest` to validate
+ :returns: Tuple indicating if the request can be honored and
+ an optional error message
"""
- if request.nodes > len(self._hosts):
- message = f"Cannot satisfy request. Requested {request.nodes} nodes, "
- message += f"but only {len(self._hosts)} nodes are available."
- return False, message
- if self._shutdown_requested:
- message = "Cannot satisfy request, server is shutting down."
- return False, message
+ honorable, err = self._can_honor_state(request)
+ if not honorable:
+ return False, err
honorable, err = self._can_honor_policy(request)
if not honorable:
return False, err
+ honorable, err = self._can_honor_hosts(request)
+ if not honorable:
+ return False, err
+
+ return True, None
+
+ def _can_honor_hosts(
+ self, request: DragonRunRequest
+ ) -> t.Tuple[bool, t.Optional[str]]:
+ """Check if the current state of the backend process inhibits executing
+ the request.
+
+ :param request: `DragonRunRequest` to validate
+ :returns: Tuple indicating if the request can be honored and
+ an optional error message"""
+ all_hosts = frozenset(self._hosts)
+ num_nodes = request.nodes
+
+ # fail if requesting more nodes than the total number available
+ if num_nodes > len(all_hosts):
+ message = f"Cannot satisfy request. {num_nodes} requested nodes"
+ message += f" exceeds {len(all_hosts)} available."
+ return False, message
+
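+ # default to all hosts in the allocation unless specific hosts were requested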
+ requested_hosts = all_hosts
+ if request.hostlist:
+ requested_hosts = frozenset(
+ {host.strip() for host in request.hostlist.split(",")}
+ )
+
+ valid_hosts = all_hosts.intersection(requested_hosts)
+ invalid_hosts = requested_hosts - valid_hosts
+
+ logger.debug(f"{num_nodes=}{valid_hosts=}{invalid_hosts=}")
+
+ if invalid_hosts:
+ logger.warning(f"Some invalid hostnames were requested: {invalid_hosts}")
+
+ # fail if requesting specific hostnames and there aren't enough available
+ if num_nodes > len(valid_hosts):
+ message = f"Cannot satisfy request. Requested {num_nodes} nodes, "
+ message += f"but only {len(valid_hosts)} named hosts are available."
+ return False, message
+
+ return True, None
+
+ def _can_honor_state(
+ self, _request: DragonRunRequest
+ ) -> t.Tuple[bool, t.Optional[str]]:
+ """Check if the current state of the backend process inhibits executing
+ the request.
+
+ :param _request: the `DragonRunRequest` to verify
+ :returns: Tuple indicating if the request can be honored and
+ an optional error message"""
+ if self._shutdown_requested:
+ message = "Cannot satisfy request, server is shutting down."
+ return False, message
+
return True, None
def _allocate_step(
self, step_id: str, request: DragonRunRequest
) -> t.Optional[t.List[str]]:
+ """Identify the hosts on which the request will be executed
+ :param step_id: The identifier of a step that will be executed on the host
+ :param request: The request to be executed
+ :returns: A list of selected hostnames"""
+ # ensure at least one host is selected
num_hosts: int = request.nodes
+
with self._queue_lock:
- if num_hosts <= 0 or num_hosts > len(self._free_hosts):
+ if num_hosts <= 0 or num_hosts > len(self._hosts):
+ logger.debug(
+ f"The number of requested hosts ({num_hosts}) is invalid or"
+ f" cannot be satisfied with {len(self._hosts)} available nodes"
+ )
return None
- to_allocate = []
- for _ in range(num_hosts):
- host = self._free_hosts.popleft()
- self._allocated_hosts[host] = step_id
- to_allocate.append(host)
+
+ hosts = []
+ if request.hostlist:
+ # convert the comma-separated argument into a real list
+ hosts = [host for host in request.hostlist.split(",") if host]
+
+ filter_on: t.Optional[PrioritizerFilter] = None
+ if request.policy and request.policy.gpu_affinity:
+ filter_on = PrioritizerFilter.GPU
+
+ nodes = self._prioritizer.next_n(num_hosts, filter_on, step_id, hosts)
+
+ if len(nodes) < num_hosts:
+ # exit if the prioritizer can't identify enough nodes
+ return None
+
+ to_allocate = [node.hostname for node in nodes]
+
+ for hostname in to_allocate:
+ # track assigning this step to each node
+ self._allocated_hosts[hostname].add(step_id)
+
return to_allocate
@staticmethod
@@ -389,6 +510,7 @@ def _create_redirect_workers(
return grp_redir
def _stop_steps(self) -> None:
+ """Trigger termination of all currently executing steps"""
self._heartbeat()
with self._queue_lock:
while len(self._stop_requests) > 0:
@@ -427,18 +549,96 @@ def _stop_steps(self) -> None:
self._group_infos[step_id].status = JobStatus.CANCELLED
self._group_infos[step_id].return_codes = [-9]
+ def _create_backbone(self) -> BackboneFeatureStore:
+ """
+ Creates a BackboneFeatureStore if one does not exist. Updates
+ environment variables of this process to include the backbone
+ descriptor.
+
+ :returns: The backbone feature store
+ """
+ if self._backbone is None:
+ backbone_storage = create_ddict(
+ len(self._hosts),
+ self._DEFAULT_NUM_MGR_PER_NODE,
+ self._DEFAULT_MEM_PER_NODE,
+ )
+
+ self._backbone = BackboneFeatureStore(
+ backbone_storage, allow_reserved_writes=True
+ )
+
+ # put the backbone descriptor in the env vars
+ os.environ.update(self._backbone.get_env())
+
+ return self._backbone
+
+ @staticmethod
+ def _initialize_cooldown() -> int:
+ """Load environment configuration and determine the correct cooldown
+ period to apply to the backend process.
+
+ :returns: The calculated cooldown (in seconds)
+ """
+ smartsim_config = get_config()
+ return (
+ smartsim_config.telemetry_frequency * 2 + 5
+ if smartsim_config.telemetry_enabled
+ else 5
+ )
+
+ def start_event_listener(
+ self, cpu_affinity: list[int], gpu_affinity: list[int]
+ ) -> dragon_process.Process:
+ """Start a standalone event listener.
+
+ :param cpu_affinity: The CPU affinity for the process
+ :param gpu_affinity: The GPU affinity for the process
+ :returns: The dragon Process managing the listener process
+ :raises SmartSimError: If the backbone is not provided
+ """
+ if self._backbone is None:
+ raise SmartSimError("Backbone feature store is not available")
+
+ service = ConsumerRegistrationListener(
+ self._backbone, 1.0, 2.0, as_service=True, health_check_frequency=90
+ )
+
+ options = dragon_process_desc.ProcessOptions(make_inf_channels=True)
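+ # pin the listener process to the host running this backend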
+ local_policy = dragon_policy.Policy(
+ placement=dragon_policy.Policy.Placement.HOST_NAME,
+ host_name=socket.gethostname(),
+ cpu_affinity=cpu_affinity,
+ gpu_affinity=gpu_affinity,
+ )
+ process = dragon_process.Process(
+ target=service.execute,
+ args=[],
+ cwd=os.getcwd(),
+ env={
+ **os.environ,
+ **self._backbone.get_env(),
+ },
+ policy=local_policy,
+ options=options,
+ stderr=dragon_process.Popen.STDOUT,
+ stdout=dragon_process.Popen.STDOUT,
+ )
+ process.start()
+ return process
+
@staticmethod
def create_run_policy(
request: DragonRequest, node_name: str
) -> "dragon_policy.Policy":
"""Create a dragon Policy from the request and node name
+
:param request: DragonRunRequest containing policy information
:param node_name: Name of the node on which the process will run
:returns: dragon_policy.Policy object mapped from request properties"""
if isinstance(request, DragonRunRequest):
run_request: DragonRunRequest = request
- affinity = dragon_policy.Policy.Affinity.DEFAULT
cpu_affinity: t.List[int] = []
gpu_affinity: t.List[int] = []
@@ -446,25 +646,20 @@ def create_run_policy(
if run_request.policy is not None:
# Affinities are not mutually exclusive. If specified, both are used
if run_request.policy.cpu_affinity:
- affinity = dragon_policy.Policy.Affinity.SPECIFIC
cpu_affinity = run_request.policy.cpu_affinity
if run_request.policy.gpu_affinity:
- affinity = dragon_policy.Policy.Affinity.SPECIFIC
gpu_affinity = run_request.policy.gpu_affinity
logger.debug(
- f"Affinity strategy: {affinity}, "
f"CPU affinity mask: {cpu_affinity}, "
f"GPU affinity mask: {gpu_affinity}"
)
- if affinity != dragon_policy.Policy.Affinity.DEFAULT:
- return dragon_policy.Policy(
- placement=dragon_policy.Policy.Placement.HOST_NAME,
- host_name=node_name,
- affinity=affinity,
- cpu_affinity=cpu_affinity,
- gpu_affinity=gpu_affinity,
- )
+ return dragon_policy.Policy(
+ placement=dragon_policy.Policy.Placement.HOST_NAME,
+ host_name=node_name,
+ cpu_affinity=cpu_affinity,
+ gpu_affinity=gpu_affinity,
+ )
return dragon_policy.Policy(
placement=dragon_policy.Policy.Placement.HOST_NAME,
@@ -472,7 +667,9 @@ def create_run_policy(
)
def _start_steps(self) -> None:
+ """Start all new steps created since the last update."""
self._heartbeat()
+
with self._queue_lock:
started = []
for step_id, request in self._queued_steps.items():
@@ -482,10 +679,8 @@ def _start_steps(self) -> None:
logger.debug(f"Step id {step_id} allocated on {hosts}")
- global_policy = dragon_policy.Policy(
- placement=dragon_policy.Policy.Placement.HOST_NAME,
- host_name=hosts[0],
- )
+ global_policy = self.create_run_policy(request, hosts[0])
+ options = dragon_process_desc.ProcessOptions(make_inf_channels=True)
grp = dragon_process_group.ProcessGroup(
restart=False, pmi_enabled=request.pmi_enabled, policy=global_policy
)
@@ -498,10 +693,15 @@ def _start_steps(self) -> None:
target=request.exe,
args=request.exe_args,
cwd=request.path,
- env={**request.current_env, **request.env},
+ env={
+ **request.current_env,
+ **request.env,
+ **(self._backbone.get_env() if self._backbone else {}),
+ },
stdout=dragon_process.Popen.PIPE,
stderr=dragon_process.Popen.PIPE,
policy=local_policy,
+ options=options,
)
grp.add_process(nproc=request.tasks_per_node, template=tmp_proc)
@@ -567,9 +767,11 @@ def _start_steps(self) -> None:
logger.error(e)
def _refresh_statuses(self) -> None:
+ """Query underlying management system for step status and update
+ stored assigned and unassigned task information"""
self._heartbeat()
with self._queue_lock:
- terminated = []
+ terminated: t.Set[str] = set()
for step_id in self._running_steps:
group_info = self._group_infos[step_id]
grp = group_info.process_group
@@ -603,11 +805,15 @@ def _refresh_statuses(self) -> None:
)
if group_info.status in TERMINAL_STATUSES:
- terminated.append(step_id)
+ terminated.add(step_id)
if terminated:
logger.debug(f"{terminated=}")
+ # remove all the terminated steps from all hosts
+ for host in list(self._allocated_hosts.keys()):
+ self._allocated_hosts[host].difference_update(terminated)
+
for step_id in terminated:
self._running_steps.remove(step_id)
self._completed_steps.append(step_id)
@@ -615,15 +821,20 @@ def _refresh_statuses(self) -> None:
if group_info is not None:
for host in group_info.hosts:
logger.debug(f"Releasing host {host}")
- try:
- self._allocated_hosts.pop(host)
- except KeyError:
+ if host not in self._allocated_hosts:
logger.error(f"Tried to free a non-allocated host: {host}")
- self._free_hosts.append(host)
+ else:
+ # remove any hosts that have had all their steps terminated
+ if not self._allocated_hosts[host]:
+ self._allocated_hosts.pop(host)
+ self._prioritizer.decrement(host, step_id)
group_info.process_group = None
group_info.redir_workers = None
def _update_shutdown_status(self) -> None:
+ """Query the status of running tasks and update the status
+ of any that have completed.
+ """
self._heartbeat()
with self._queue_lock:
self._can_shutdown |= (
@@ -637,12 +848,18 @@ def _update_shutdown_status(self) -> None:
)
def _should_print_status(self) -> bool:
+ """Determine if status messages should be printed based off the last
+ update. Returns `True` to trigger prints, `False` otherwise.
+ """
if self.current_time - self._last_update_time > 10:
self._last_update_time = self.current_time
return True
return False
def _update(self) -> None:
+ """Trigger all update queries and update local state database"""
+ self._create_backbone()
+
self._stop_steps()
self._start_steps()
self._refresh_statuses()
@@ -650,6 +867,9 @@ def _update(self) -> None:
def _kill_all_running_jobs(self) -> None:
with self._queue_lock:
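+ # terminate the standalone event listener if it is still running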
+ if self._listener and self._listener.is_alive:
+ self._listener.kill()
+
for step_id, group_info in self._group_infos.items():
if group_info.status not in TERMINAL_STATUSES:
self._stop_requests.append(DragonStopRequest(step_id=step_id))
@@ -728,8 +948,14 @@ def _(self, request: DragonShutdownRequest) -> DragonShutdownResponse:
class DragonBackendView:
- def __init__(self, backend: DragonBackend):
+ def __init__(self, backend: DragonBackend) -> None:
+ """Initialize the instance
+
+ :param backend: A dragon backend used to produce the view"""
self._backend = backend
+ """A dragon backend used to produce the view"""
+
+ logger.debug(self.host_desc)
@property
def host_desc(self) -> str:
@@ -791,9 +1017,7 @@ def step_table(self) -> str:
@property
def host_table(self) -> str:
"""Table representation of current state of nodes available
-
- in the allocation.
- """
+ in the allocation."""
headers = ["Host", "Status"]
hosts = self._backend.hosts
free_hosts = self._backend.free_hosts
diff --git a/smartsim/_core/launcher/dragon/dragon_connector.py b/smartsim/_core/launcher/dragon/dragon_connector.py
index 7ff4cdc1c8..9c96592776 100644
--- a/smartsim/_core/launcher/dragon/dragon_connector.py
+++ b/smartsim/_core/launcher/dragon/dragon_connector.py
@@ -76,9 +76,11 @@ class DragonConnector:
def __init__(self, path: str | os.PathLike[str]) -> None:
self._context: zmq.Context[t.Any] = zmq.Context.instance()
+ """ZeroMQ context used to share configuration across requests"""
self._context.setsockopt(zmq.REQ_CORRELATE, 1)
self._context.setsockopt(zmq.REQ_RELAXED, 1)
self._authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None
+ """ZeroMQ authenticator used to secure queue access"""
config = get_config()
self._reset_timeout(config.dragon_server_timeout)
@@ -88,17 +90,21 @@ def __init__(self, path: str | os.PathLike[str]) -> None:
# fine as we expect the that method should only be called once
# without hitting a guard clause.
self._dragon_head_socket: t.Optional[zmq.Socket[t.Any]] = None
+ """ZeroMQ socket exposing the connection to the DragonBackend"""
self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] = None
+ """A handle to the process executing the DragonBackend"""
# Returned by dragon head, useful if shutdown is to be requested
# but process was started by another connector
self._dragon_head_pid: t.Optional[int] = None
+ """Process ID of the process executing the DragonBackend"""
self._dragon_server_path = _resolve_dragon_path(path)
+ """Path to a dragon installation"""
logger.debug(f"Dragon Server path was set to {self._dragon_server_path}")
self._env_vars: t.Dict[str, str] = {}
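+ """Environment variables loaded from the persisted dragon .env file"""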
@property
def is_connected(self) -> bool:
- """Whether the Connector established a connection to the server
+ """Whether the Connector established a connection to the server.
:return: True if connected
"""
@@ -107,12 +113,18 @@ def is_connected(self) -> bool:
@property
def can_monitor(self) -> bool:
"""Whether the Connector knows the PID of the dragon server head process
- and can monitor its status
+ and can monitor its status.
:return: True if the server can be monitored"""
return self._dragon_head_pid is not None
def _handshake(self, address: str) -> None:
+ """Perform the handshake process with the DragonBackend and
+ confirm two-way communication is established.
+
+ :param address: The address of the head node socket to initiate a
+ handshake with
+ """
self._dragon_head_socket = dragon_sockets.get_secure_socket(
self._context, zmq.REQ, False
)
@@ -135,6 +147,11 @@ def _handshake(self, address: str) -> None:
) from e
def _reset_timeout(self, timeout: int = get_config().dragon_server_timeout) -> None:
+ """Reset the timeout applied to the ZMQ context. If an authenticator is
+ enabled, also update the authenticator timeouts.
+
+ :param timeout: The timeout value to apply to ZMQ sockets
+ """
self._context.setsockopt(zmq.SNDTIMEO, value=timeout)
self._context.setsockopt(zmq.RCVTIMEO, value=timeout)
if self._authenticator is not None and self._authenticator.thread is not None:
@@ -186,11 +203,19 @@ def _get_new_authenticator(
@staticmethod
def _get_dragon_log_level() -> str:
+ """Maps the log level from SmartSim to a valid log level
+ for a dragon process.
+
+ :returns: The dragon log level string
+ """
smartsim_to_dragon = defaultdict(lambda: "NONE")
smartsim_to_dragon["developer"] = "INFO"
return smartsim_to_dragon.get(get_config().log_level, "NONE")
def _connect_to_existing_server(self, path: Path) -> None:
+ """Connects to an existing DragonBackend using address information from
+ a persisted dragon log file.
+
+ :param path: Path to the directory containing the dragon configuration log
+ """
config = get_config()
dragon_config_log = path / config.dragon_log_filename
@@ -220,6 +245,11 @@ def _connect_to_existing_server(self, path: Path) -> None:
return
def _start_connector_socket(self, socket_addr: str) -> zmq.Socket[t.Any]:
+ """Instantiate the ZMQ socket to be used by the connector.
+
+ :param socket_addr: The socket address the connector should bind to
+ :returns: The bound socket
+ """
config = get_config()
connector_socket: t.Optional[zmq.Socket[t.Any]] = None
self._reset_timeout(config.dragon_server_startup_timeout)
@@ -250,9 +280,14 @@ def load_persisted_env(self) -> t.Dict[str, str]:
with open(config.dragon_dotenv, encoding="utf-8") as dot_env:
for kvp in dot_env.readlines():
- split = kvp.strip().split("=", maxsplit=1)
- key, value = split[0], split[-1]
- self._env_vars[key] = value
+ if not kvp.strip():
+ continue
+
+ # skip any commented lines
+ if not kvp.startswith("#"):
+ split = kvp.strip().split("=", maxsplit=1)
+ key, value = split[0], split[-1]
+ self._env_vars[key] = value
return self._env_vars
@@ -422,6 +457,15 @@ def send_request(self, request: DragonRequest, flags: int = 0) -> DragonResponse
def _parse_launched_dragon_server_info_from_iterable(
stream: t.Iterable[str], num_dragon_envs: t.Optional[int] = None
) -> t.List[t.Dict[str, str]]:
+ """Parses dragon backend connection information from a stream.
+
+ :param stream: The stream to inspect. Usually the stdout of the
+ DragonBackend process
+ :param num_dragon_envs: The expected number of dragon environments
+ to parse from the stream.
+ :returns: A list of dictionaries, one per environment, containing
+ the parsed server information
+ """
lines = (line.strip() for line in stream)
lines = (line for line in lines if line)
tokenized = (line.split(maxsplit=1) for line in lines)
@@ -448,6 +492,15 @@ def _parse_launched_dragon_server_info_from_files(
file_paths: t.List[t.Union[str, "os.PathLike[str]"]],
num_dragon_envs: t.Optional[int] = None,
) -> t.List[t.Dict[str, str]]:
+ """Read a known log file into a Stream and parse dragon server configuration
+ from the stream.
+
+ :param file_paths: Path to a file containing dragon server configuration
+ :param num_dragon_envs: The expected number of dragon environments to be found
+ in the file
+ :returns: The parsed server configuration, one item per
+ discovered dragon environment
+ """
with fileinput.FileInput(file_paths) as ifstream:
dragon_envs = cls._parse_launched_dragon_server_info_from_iterable(
ifstream, num_dragon_envs
@@ -462,6 +515,15 @@ def _send_req_with_socket(
send_flags: int = 0,
recv_flags: int = 0,
) -> DragonResponse:
+ """Sends a synchronous request through a ZMQ socket.
+
+ :param socket: Socket to send on
+ :param request: The request to send
+ :param send_flags: Configuration to apply to the send operation
+ :param recv_flags: Configuration to apply to the recv operation; used to
+ allow the receiver to immediately respond to the sent request.
+ :returns: The response from the target
+ """
client = dragon_sockets.as_client(socket)
with DRG_LOCK:
logger.debug(f"Sending {type(request).__name__}: {request}")
@@ -473,6 +535,13 @@ def _send_req_with_socket(
def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT:
+ """Verify that objects can be sent as messages acceptable to the target.
+
+ :param obj: The message to test
+ :param typ: The type that is acceptable
+ :returns: The original `obj` if it is of the requested type
+ :raises TypeError: If the object fails the test and is not
+ an instance of the desired type"""
if not isinstance(obj, typ):
raise TypeError(f"Expected schema of type `{typ}`, but got {type(obj)}")
return obj
diff --git a/smartsim/_core/launcher/dragon/dragon_launcher.py b/smartsim/_core/launcher/dragon/dragon_launcher.py
index 5e36b8a3fd..4af93b68ed 100644
--- a/smartsim/_core/launcher/dragon/dragon_launcher.py
+++ b/smartsim/_core/launcher/dragon/dragon_launcher.py
@@ -72,7 +72,7 @@
# ***************************************
# TODO: Remove pylint disable after merge
# ***************************************
-# pylint: disable=protected-access
+# pylint: disable=protected-access,wrong-import-position
class DragonLauncher(WLMLauncher):
@@ -206,6 +206,8 @@ def run(self, step: Step) -> t.Optional[str]:
self._connector.load_persisted_env()
nodes = int(run_args.get("nodes", None) or 1)
tasks_per_node = int(run_args.get("tasks-per-node", None) or 1)
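+ # optional comma-separated list of hostnames to restrict step placement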
+ hosts = run_args.get("host-list", None)
+
policy = DragonRunPolicy.from_run_args(run_args)
step_id = self.start(
(
@@ -219,6 +221,7 @@ def run(self, step: Step) -> t.Optional[str]:
env=req_env,
output_file=out,
error_file=err,
+ hostlist=hosts,
),
policy,
)
@@ -374,15 +377,15 @@ def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT:
return obj
-from smartsim._core.dispatch import dispatch # pylint: disable=wrong-import-position
+from smartsim._core.dispatch import dispatch
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
# TODO: Remove this registry and move back to builder file after fixing
# circular import caused by `DragonLauncher.supported_rs`
# -----------------------------------------------------------------------------
-from smartsim.settings.arguments.launch.dragon import ( # pylint: disable=wrong-import-position
+from smartsim.settings.arguments.launch.dragon import (
DragonLaunchArguments,
-)
+) # pylint: disable=wrong-import-position
def _as_run_request_args_and_policy(
@@ -404,11 +407,6 @@ def _as_run_request_args_and_policy(
DragonRunRequestView(
exe=exe_,
exe_args=args,
- # FIXME: Currently this is hard coded because the schema requires
- # it, but in future, it is almost certainly necessary that
- # this will need to be injected by the user or by us to have
- # the command execute next to any generated files. A similar
- # problem exists for the other settings.
path=path,
env=env,
# TODO: Not sure how this info is injected
diff --git a/smartsim/_core/launcher/dragon/pqueue.py b/smartsim/_core/launcher/dragon/pqueue.py
new file mode 100644
index 0000000000..8c14a828f5
--- /dev/null
+++ b/smartsim/_core/launcher/dragon/pqueue.py
@@ -0,0 +1,461 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import enum
+import heapq
+import threading
+import typing as t
+
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class Node(t.Protocol):
+ """Base Node API required to support the NodePrioritizer"""
+
+ @property
+ def hostname(self) -> str:
+ """The hostname of the node"""
+
+ @property
+ def num_cpus(self) -> int:
+ """The number of CPUs in the node"""
+
+ @property
+ def num_gpus(self) -> int:
+ """The number of GPUs in the node"""
+
+
+class NodeReferenceCount(t.Protocol):
+ """Contains details pertaining to references to a node"""
+
+ @property
+ def hostname(self) -> str:
+ """The hostname of the node"""
+
+ @property
+ def num_refs(self) -> int:
+ """The number of jobs assigned to the node"""
+
+
+class _TrackedNode:
+ """Node API required to have support in the NodePrioritizer"""
+
+ def __init__(self, node: Node) -> None:
+ self._node = node
+ """The node being tracked"""
+ self._num_refs = 0
+ """The number of references to the tracked node"""
+ self._assigned_tasks: t.Set[str] = set()
+ """The unique identifiers of processes using this node"""
+ self._is_dirty = False
+ """Flag indicating that tracking information has been modified"""
+
+ @property
+ def hostname(self) -> str:
+ """Returns the hostname of the node"""
+ return self._node.hostname
+
+ @property
+ def num_cpus(self) -> int:
+ """Returns the number of CPUs in the node"""
+ return self._node.num_cpus
+
+ @property
+ def num_gpus(self) -> int:
+ """Returns the number of GPUs attached to the node"""
+ return self._node.num_gpus
+
+ @property
+ def num_refs(self) -> int:
+ """Returns the number of processes currently running on the node"""
+ return self._num_refs
+
+ @property
+ def is_assigned(self) -> bool:
+ """Returns `True` if no references are currently counted, `False` otherwise"""
+ return self._num_refs > 0
+
+ @property
+ def assigned_tasks(self) -> t.Set[str]:
+ """Returns the set of unique IDs for currently running processes"""
+ return self._assigned_tasks
+
+ @property
+ def is_dirty(self) -> bool:
+ """Returns a flag indicating if the reference counter has changed. `True`
+ if references have been added or removed, `False` otherwise."""
+ return self._is_dirty
+
+ def clean(self) -> None:
+ """Marks the node as unmodified"""
+ self._is_dirty = False
+
+ def add(
+ self,
+ tracking_id: t.Optional[str] = None,
+ ) -> None:
+ """Update the node to indicate the addition of a process that must be
+ reference counted.
+
+ :param tracking_id: a unique task identifier executing on the node
+ to add
+ :raises ValueError: if tracking_id is already assigned to this node"""
+ if tracking_id in self.assigned_tasks:
+ raise ValueError("Attempted adding task more than once")
+
+ self._num_refs = self._num_refs + 1
+ if tracking_id:
+ self._assigned_tasks = self._assigned_tasks.union({tracking_id})
+ self._is_dirty = True
+
+ def remove(
+ self,
+ tracking_id: t.Optional[str] = None,
+ ) -> None:
+ """Update the reference counter to indicate the removal of a process.
+
+ :param tracking_id: a unique task identifier executing on the node
+ to remove"""
+ self._num_refs = max(self._num_refs - 1, 0)
+ if tracking_id:
+ self._assigned_tasks = self._assigned_tasks - {tracking_id}
+ self._is_dirty = True
+
+ def __lt__(self, other: "_TrackedNode") -> bool:
+ """Comparison operator used to evaluate the ordering of nodes within
+ the prioritizer. This comparison only considers reference counts.
+
+ :param other: Another node to compare against
+ :returns: True if this node has fewer references than the other node"""
+ if self.num_refs < other.num_refs:
+ return True
+
+ return False
+
+
+class PrioritizerFilter(str, enum.Enum):
+ """A filter used to select a subset of nodes to be queried"""
+
+ CPU = enum.auto()
+ GPU = enum.auto()
+
+
+class NodePrioritizer:
+ def __init__(self, nodes: t.List[Node], lock: threading.RLock) -> None:
+ """Initialize the prioritizer
+
+ :param nodes: node attribute information for initializing the prioritizer
+ :param lock: a lock used to ensure threadsafe operations
+ :raises SmartSimError: if the nodes collection is empty
+ """
+ if not nodes:
+ raise SmartSimError("Missing nodes to prioritize")
+
+ self._lock = lock
+ """Lock used to ensure thread safe changes of the reference counters"""
+ self._cpu_refs: t.List[_TrackedNode] = []
+ """Track reference counts to CPU-only nodes"""
+ self._gpu_refs: t.List[_TrackedNode] = []
+ """Track reference counts to GPU nodes"""
+ self._nodes: t.Dict[str, _TrackedNode] = {}
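+ """Mapping of hostname to tracked node, used for O(1) lookups by host"""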
+
+ self._initialize_reference_counters(nodes)
+
+ def _initialize_reference_counters(self, nodes: t.List[Node]) -> None:
+ """Perform initialization of reference counters for nodes in the allocation
+
+ :param nodes: node attribute information for initializing the prioritizer"""
+ for node in nodes:
+ # create a set of reference counters for the nodes
+ tracked = _TrackedNode(node)
+
+ self._nodes[node.hostname] = tracked # for O(1) access
+
+ if node.num_gpus:
+ self._gpu_refs.append(tracked)
+ else:
+ self._cpu_refs.append(tracked)
+
+ def increment(
+ self, host: str, tracking_id: t.Optional[str] = None
+ ) -> NodeReferenceCount:
+ """Directly increment the reference count of a given node and ensure the
+ ref counter is marked as dirty to trigger a reordering on retrieval
+
+ :param host: a hostname that should have a reference counter selected
+ :param tracking_id: a unique task identifier executing on the node
+ to add"""
+ with self._lock:
+ tracked_node = self._nodes[host]
+ tracked_node.add(tracking_id)
+ return tracked_node
+
+ def _heapify_all_refs(self) -> t.List[_TrackedNode]:
+ """Combine the CPU and GPU nodes into a single heap
+
+ :returns: list of all reference counters"""
+ refs = [*self._cpu_refs, *self._gpu_refs]
+ heapq.heapify(refs)
+ return refs
+
+ def get_tracking_info(self, host: str) -> NodeReferenceCount:
+ """Returns the reference counter information for a single node
+
+ :param host: a hostname that should have a reference counter selected
+ :returns: a reference counter for the node
+ :raises ValueError: if the hostname is not in the set of managed nodes"""
+ if host not in self._nodes:
+ raise ValueError("The supplied hostname was not found")
+
+ return self._nodes[host]
+
+ def decrement(
+ self, host: str, tracking_id: t.Optional[str] = None
+ ) -> NodeReferenceCount:
+ """Directly decrement the reference count of a given node and ensure the
+ ref counter is marked as dirty to trigger a reordering
+
+ :param host: a hostname that should have a reference counter decremented
+ :param tracking_id: unique task identifier to remove"""
+ with self._lock:
+ tracked_node = self._nodes[host]
+ tracked_node.remove(tracking_id)
+
+ return tracked_node
+
+ def _create_sub_heap(
+ self,
+ hosts: t.Optional[t.List[str]] = None,
+ filter_on: t.Optional[PrioritizerFilter] = None,
+ ) -> t.List[_TrackedNode]:
+ """Create a new heap from the primary heap with user-specified nodes
+
+ :param hosts: a list of hostnames used to filter the available nodes
+ :param filter_on: the subset of nodes (CPU or GPU) to draw from
+ :returns: a heap of reference counters for the selected nodes
+ """
+ nodes_tracking_info: t.List[_TrackedNode] = []
+ heap = self._get_filtered_heap(filter_on)
+
+ # Collect all the tracking info for the requested nodes...
+ for node in heap:
+ if not hosts or node.hostname in hosts:
+ nodes_tracking_info.append(node)
+
+ # ... and use it to create a new heap from a specified subset of nodes
+ heapq.heapify(nodes_tracking_info)
+
+ return nodes_tracking_info
+
+ def unassigned(
+ self, heap: t.Optional[t.List[_TrackedNode]] = None
+ ) -> t.Sequence[Node]:
+ """Select nodes that are currently not assigned a task
+
+ :param heap: a subset of the node heap to consider
+ :returns: a list of reference counts for all unassigned nodes"""
+ if heap is None:
+ heap = list(self._nodes.values())
+
+ nodes: t.List[_TrackedNode] = []
+ for item in heap:
+ if item.num_refs == 0:
+ nodes.append(item)
+ return nodes
+
+ def assigned(
+ self, heap: t.Optional[t.List[_TrackedNode]] = None
+ ) -> t.Sequence[Node]:
+ """Helper method to identify the nodes that are currently assigned
+
+ :param heap: a subset of the node heap to consider
+ :returns: a list of reference counts for all assigned nodes"""
+ if heap is None:
+ heap = list(self._nodes.values())
+
+ nodes: t.List[_TrackedNode] = []
+ for item in heap:
+ if item.num_refs > 0:
+ nodes.append(item)
+ return nodes
+
+ def _check_satisfiable_n(
+ self, num_items: int, heap: t.Optional[t.List[_TrackedNode]] = None
+ ) -> bool:
+ """Validates that a request for some number of nodes `n` can be
+ satisfied by the prioritizer given the set of nodes available
+
+ :param num_items: the desired number of nodes to allocate
+ :param heap: a subset of the node heap to consider
+ :returns: True if the request can be fulfilled, False otherwise"""
+ num_nodes = len(self._nodes.keys())
+
+ if num_items < 1:
+ msg = "Cannot handle request; request requires a positive integer"
+ logger.warning(msg)
+ return False
+
+ if num_nodes < num_items:
+ msg = f"Cannot satisfy request for {num_items} nodes; {num_nodes} in pool"
+ logger.warning(msg)
+ return False
+
+ num_open = len(self.unassigned(heap))
+ if num_open < num_items:
+ msg = f"Cannot satisfy request for {num_items} nodes; {num_open} available"
+ logger.warning(msg)
+ return False
+
+ return True
+
+ def _get_next_unassigned_node(
+ self,
+ heap: t.List[_TrackedNode],
+ tracking_id: t.Optional[str] = None,
+ ) -> t.Optional[Node]:
+ """Finds the next node with no running processes and
+ ensures that any elements that were directly updated are updated in
+ the priority structure before being made available
+
+ :param heap: a subset of the node heap to consider
+ :param tracking_id: unique task identifier to track
+ :returns: a reference counter for an available node if an unassigned node
+ exists, `None` otherwise"""
+ tracking_info: t.Optional[_TrackedNode] = None
+
+ with self._lock:
+ # re-sort the heap to handle any tracking changes
+ if any(node.is_dirty for node in heap):
+ heapq.heapify(heap)
+
+ # grab the min node from the heap
+ tracking_info = heapq.heappop(heap)
+
+ # the node is available if it has no assigned tasks
+ is_assigned = tracking_info.is_assigned
+ if not is_assigned:
+ # track the new process on the node
+ tracking_info.add(tracking_id)
+
+ # add the node that was popped back into the heap
+ heapq.heappush(heap, tracking_info)
+
+ # mark all nodes as clean now that everything is updated & sorted
+ for node in heap:
+ node.clean()
+
+ # next available must only return previously unassigned nodes
+ if is_assigned:
+ return None
+
+ return tracking_info
+
+ def _get_next_n_available_nodes(
+ self,
+ num_items: int,
+ heap: t.List[_TrackedNode],
+ tracking_id: t.Optional[str] = None,
+ ) -> t.List[Node]:
+ """Find the next N available nodes w/least amount of references using
+ the supplied filter to target a specific node capability
+
+ :param num_items: number of nodes to reserve
+ :param heap: a subset of the node heap to consider
+ :param tracking_id: unique task identifier to track
+ :returns: a list of reference counters for available nodes; an empty
+ list if too few unassigned nodes exist
+ :raises ValueError: if the number of requested nodes is not a positive integer
+ """
+ next_nodes: t.List[Node] = []
+
+ if num_items < 1:
+ raise ValueError(f"Number of items requested {num_items} is invalid")
+
+ if not self._check_satisfiable_n(num_items, heap):
+ return next_nodes
+
+ while len(next_nodes) < num_items:
+ if next_node := self._get_next_unassigned_node(heap, tracking_id):
+ next_nodes.append(next_node)
+ continue
+ break
+
+ return next_nodes
+
+ def _get_filtered_heap(
+ self, filter_on: t.Optional[PrioritizerFilter] = None
+ ) -> t.List[_TrackedNode]:
+ """Helper method to select the set of nodes to include in a filtered
+ heap.
+
+ :param filter_on: The filter identifying the subset of nodes to select. If no
+ filter is supplied, all nodes are returned
+ :returns: the tracked nodes matching the filter"""
+ if filter_on == PrioritizerFilter.GPU:
+ return self._gpu_refs
+ if filter_on == PrioritizerFilter.CPU:
+ return self._cpu_refs
+
+ return self._heapify_all_refs()
+
+ def next(
+ self,
+ filter_on: t.Optional[PrioritizerFilter] = None,
+ tracking_id: t.Optional[str] = None,
+ hosts: t.Optional[t.List[str]] = None,
+ ) -> t.Optional[Node]:
+ """Find the next unsassigned node using the supplied filter to target
+ a specific node capability
+
+ :param filter_on: the subset of nodes to query for available nodes
+ :param tracking_id: unique task identifier to track
+ :param hosts: a list of hostnames used to filter the available nodes
+ :returns: a reference counter for an available node if an unassigned node
+ exists, `None` otherwise"""
+ if results := self.next_n(1, filter_on, tracking_id, hosts):
+ return results[0]
+ return None
+
+ def next_n(
+ self,
+ num_items: int = 1,
+ filter_on: t.Optional[PrioritizerFilter] = None,
+ tracking_id: t.Optional[str] = None,
+ hosts: t.Optional[t.List[str]] = None,
+ ) -> t.List[Node]:
+ """Find the next N available nodes w/least amount of references using
+ the supplied filter to target a specific node capability
+
+ :param num_items: number of nodes to reserve
+ :param filter_on: the subset of nodes to query for available nodes
+ :param tracking_id: unique task identifier to track
+ :param hosts: a list of hostnames used to filter the available nodes
+ :returns: Collection of reserved nodes
+ :raises ValueError: if the number of requested nodes is not a positive integer"""
+ heap = self._create_sub_heap(hosts, filter_on)
+ return self._get_next_n_available_nodes(num_items, heap, tracking_id)
diff --git a/smartsim/_core/launcher/step/alps_step.py b/smartsim/_core/launcher/step/alps_step.py
index 047e75d2cf..dc9f3bff61 100644
--- a/smartsim/_core/launcher/step/alps_step.py
+++ b/smartsim/_core/launcher/step/alps_step.py
@@ -126,14 +126,14 @@ def _build_exe(self) -> t.List[str]:
return self._make_mpmd()
exe = self.entity.exe
- args = self.entity.exe_args # pylint: disable=protected-access
+ args = self.entity.exe_args
return exe + args
def _make_mpmd(self) -> t.List[str]:
"""Build Aprun (MPMD) executable"""
exe = self.entity.exe
- exe_args = self.entity._exe_args # pylint: disable=protected-access
+ exe_args = self.entity.exe_args
cmd = exe + exe_args
for mpmd in self._get_mpmd():
diff --git a/smartsim/_core/launcher/step/dragon_step.py b/smartsim/_core/launcher/step/dragon_step.py
index 63e9f65fe8..f1e8662e2a 100644
--- a/smartsim/_core/launcher/step/dragon_step.py
+++ b/smartsim/_core/launcher/step/dragon_step.py
@@ -170,6 +170,7 @@ def _write_request_file(self) -> str:
env = run_settings.env_vars
nodes = int(run_args.get("nodes", None) or 1)
tasks_per_node = int(run_args.get("tasks-per-node", None) or 1)
+ hosts_csv = run_args.get("host-list", None)
policy = DragonRunPolicy.from_run_args(run_args)
@@ -188,6 +189,7 @@ def _write_request_file(self) -> str:
output_file=out,
error_file=err,
policy=policy,
+ hostlist=hosts_csv,
)
requests.append(request_registry.to_string(request))
with open(request_file, "w", encoding="utf-8") as script_file:
diff --git a/smartsim/_core/launcher/step/lsf_step.py b/smartsim/_core/launcher/step/lsf_step.py
index 372e21c81b..80583129c1 100644
--- a/smartsim/_core/launcher/step/lsf_step.py
+++ b/smartsim/_core/launcher/step/lsf_step.py
@@ -217,7 +217,7 @@ def _build_exe(self) -> t.List[str]:
:return: executable list
"""
exe = self.entity.exe
- args = self.entity.exe_args # pylint: disable=protected-access
+ args = self.entity.exe_args
if self._get_mpmd():
erf_file = self.get_step_file(ending=".mpmd")
diff --git a/smartsim/_core/launcher/step/mpi_step.py b/smartsim/_core/launcher/step/mpi_step.py
index 06a94cd4cc..0eb2f34fdb 100644
--- a/smartsim/_core/launcher/step/mpi_step.py
+++ b/smartsim/_core/launcher/step/mpi_step.py
@@ -136,13 +136,13 @@ def _build_exe(self) -> t.List[str]:
return self._make_mpmd()
exe = self.entity.exe
- args = self.entity.exe_args # pylint: disable=protected-access
+ args = self.entity.exe_args
return exe + args
def _make_mpmd(self) -> t.List[str]:
"""Build mpiexec (MPMD) executable"""
exe = self.entity.exe
- args = self.entity.exe_args # pylint: disable=protected-access
+ args = self.entity.exe_args
cmd = exe + args
for mpmd in self._get_mpmd():
@@ -150,7 +150,7 @@ def _make_mpmd(self) -> t.List[str]:
cmd += mpmd.format_run_args()
cmd += mpmd.format_env_vars()
cmd += mpmd.exe
- cmd += mpmd.exe_args # pylint: disable=protected-access
+ cmd += mpmd.exe_args
cmd = sh_split(" ".join(cmd))
return cmd
diff --git a/smartsim/_core/launcher/step/slurm_step.py b/smartsim/_core/launcher/step/slurm_step.py
index af042dfc18..410d14d269 100644
--- a/smartsim/_core/launcher/step/slurm_step.py
+++ b/smartsim/_core/launcher/step/slurm_step.py
@@ -211,8 +211,9 @@ def _build_exe(self) -> t.List[str]:
return exe + args
# There is an issue here, exe and exe_args are no longer attached to the
- # runsettings. This functions is looping through the list of run_settings.mpmd
- # and build the variable cmd
+ # runsettings.
+ # This function loops through the list of run_settings.mpmd and
+ # builds the variable cmd
def _make_mpmd(self) -> t.List[str]:
"""Build Slurm multi-prog (MPMD) executable"""
exe = self.entity.exe
diff --git a/smartsim/_core/mli/__init__.py b/smartsim/_core/mli/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/client/__init__.py b/smartsim/_core/mli/client/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/client/protoclient.py b/smartsim/_core/mli/client/protoclient.py
new file mode 100644
index 0000000000..46598a8171
--- /dev/null
+++ b/smartsim/_core/mli/client/protoclient.py
@@ -0,0 +1,348 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# isort: off
+# pylint: disable=unused-import,import-error
+import dragon
+import dragon.channels
+from dragon.globalservices.api_setup import connect_to_infrastructure
+
+try:
+ from mpi4py import MPI # type: ignore[import-not-found]
+except Exception:
+ MPI = None
+ print("Unable to import `mpi4py` package")
+
+# isort: on
+# pylint: enable=unused-import,import-error
+
+import numbers
+import os
+import time
+import typing as t
+from collections import OrderedDict
+
+import numpy
+import torch
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim._core.mli.message_handler import MessageHandler
+from smartsim._core.utils.timings import PerfTimer
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+_TimingDict = OrderedDict[str, list[str]]
+
+
+logger = get_logger("App")
+logger.info("Started app")
+CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
+
+
+class ProtoClient:
+ """Proof of concept implementation of a client enabling user applications
+ to interact with MLI resources."""
+
+ _DEFAULT_BACKBONE_TIMEOUT = 1.0
+ """A default timeout period applied to connection attempts with the
+ backbone feature store."""
+
+ _DEFAULT_WORK_QUEUE_SIZE = 500
+ """A default number of events to be buffered in the work queue before
+ triggering QueueFull exceptions."""
+
+ _EVENT_SOURCE = "proto-client"
+ """A user-friendly name for this class instance to identify
+ the client as the publisher of an event."""
+
+ @staticmethod
+ def _attach_to_backbone() -> BackboneFeatureStore:
+ """Use the supplied environment variables to attach
+ to a pre-existing backbone featurestore. Requires the
+ `_SMARTSIM_INFRA_BACKBONE` environment variable to be set.
+
+ :returns: The attached backbone featurestore
+ :raises SmartSimError: If the backbone descriptor is not contained
+ in the appropriate environment variable
+ """
+ descriptor = os.environ.get(BackboneFeatureStore.MLI_BACKBONE, None)
+ if not descriptor:
+ raise SmartSimError(
+ "Missing required backbone configuration in environment: "
+ f"{BackboneFeatureStore.MLI_BACKBONE}"
+ )
+
+ backbone = t.cast(
+ BackboneFeatureStore, BackboneFeatureStore.from_descriptor(descriptor)
+ )
+ return backbone
+
+ def _attach_to_worker_queue(self) -> DragonFLIChannel:
+ """Wait until the backbone contains the worker queue configuration,
+ then attach an FLI to the given worker queue.
+
+ :returns: The attached FLI channel
+ :raises SmartSimError: if the required configuration is not found in the
+ backbone feature store
+ """
+
+ descriptor = ""
+ try:
+ # NOTE: without wait_for, this MUST be in the backbone....
+ config = self._backbone.wait_for(
+ [BackboneFeatureStore.MLI_WORKER_QUEUE], self.backbone_timeout
+ )
+ descriptor = str(config[BackboneFeatureStore.MLI_WORKER_QUEUE])
+ except Exception as ex:
+ logger.info(
+ f"Unable to retrieve {BackboneFeatureStore.MLI_WORKER_QUEUE} "
+ "to attach to the worker queue."
+ )
+ raise SmartSimError("Unable to locate worker queue using backbone") from ex
+
+ return DragonFLIChannel.from_descriptor(descriptor)
+
+ def _create_broadcaster(self) -> EventBroadcaster:
+ """Create an EventBroadcaster that broadcasts events to
+ all MLI components registered to consume them.
+
+ :returns: An EventBroadcaster instance
+ """
+ broadcaster = EventBroadcaster(
+ self._backbone, DragonCommChannel.from_descriptor
+ )
+ return broadcaster
+
+ def __init__(
+ self,
+ timing_on: bool,
+ backbone_timeout: float = _DEFAULT_BACKBONE_TIMEOUT,
+ ) -> None:
+ """Initialize the client instance.
+
+ :param timing_on: Flag indicating if timing information should be
+ written to file
+ :param backbone_timeout: Maximum wait time (in seconds) allowed to attach to the
+ worker queue
+ :raises SmartSimError: If unable to attach to a backbone featurestore
+ :raises ValueError: If an invalid backbone timeout is specified
+ """
+ if MPI is not None:
+ # TODO: determine a way to make MPI work in the test environment
+ # - consider catching the import exception and defaulting rank to 0
+ comm = MPI.COMM_WORLD
+ rank: int = comm.Get_rank()
+ else:
+ rank = 0
+
+ if backbone_timeout <= 0:
+ raise ValueError(
+ f"Invalid backbone timeout provided: {backbone_timeout}. "
+ "The value must be greater than zero."
+ )
+ self._backbone_timeout = max(backbone_timeout, 0.1)
+
+ connect_to_infrastructure()
+
+ self._backbone = self._attach_to_backbone()
+ self._backbone.wait_timeout = self.backbone_timeout
+ self._to_worker_fli = self._attach_to_worker_queue()
+
+ self._from_worker_ch = create_local(self._DEFAULT_WORK_QUEUE_SIZE)
+ self._to_worker_ch = create_local(self._DEFAULT_WORK_QUEUE_SIZE)
+
+ self._publisher = self._create_broadcaster()
+
+ self.perf_timer: PerfTimer = PerfTimer(
+ debug=False, timing_on=timing_on, prefix=f"a{rank}_"
+ )
+ self._start: t.Optional[float] = None
+ self._interm: t.Optional[float] = None
+ self._timings: _TimingDict = OrderedDict()
+ self._timing_on = timing_on
+
+ @property
+ def backbone_timeout(self) -> float:
+ """The timeout (in seconds) applied to retrievals
+ from the backbone feature store.
+
+ :returns: A float indicating the number of seconds to allow"""
+ return self._backbone_timeout
+
+ def _add_label_to_timings(self, label: str) -> None:
+ """Adds a new label into the timing dictionary to prepare for
+ receiving timing events.
+
+ :param label: The label to create storage for
+ """
+ if label not in self._timings:
+ self._timings[label] = []
+
+ @staticmethod
+ def _format_number(number: t.Union[numbers.Number, float]) -> str:
+ """Utility function for formatting numbers consistently for logs.
+
+ :param number: The number to convert to a formatted string
+ :returns: The formatted string containing the number
+ """
+ return f"{number:0.4e}"
+
+ def start_timings(self, batch_size: numbers.Number) -> None:
+ """Configure the client to begin storing timing information.
+
+ :param batch_size: The size of batches to generate as inputs
+ to the model
+ """
+ if self._timing_on:
+ self._add_label_to_timings("batch_size")
+ self._timings["batch_size"].append(self._format_number(batch_size))
+ self._start = time.perf_counter()
+ self._interm = time.perf_counter()
+
+ def end_timings(self) -> None:
+ """Configure the client to stop storing timing information."""
+ if self._timing_on and self._start is not None:
+ self._add_label_to_timings("total_time")
+ self._timings["total_time"].append(
+ self._format_number(time.perf_counter() - self._start)
+ )
+
+ def measure_time(self, label: str) -> None:
+ """Measures elapsed time since the last recorded signal.
+
+ :param label: The label to measure time for
+ """
+ if self._timing_on and self._interm is not None:
+ self._add_label_to_timings(label)
+ self._timings[label].append(
+ self._format_number(time.perf_counter() - self._interm)
+ )
+ self._interm = time.perf_counter()
+
+ def print_timings(self, to_file: bool = False) -> None:
+ """Print timing information to standard output. If `to_file`
+ is `True`, also write results to a file.
+
+ :param to_file: If `True`, also saves timing information
+ to the files `timings.npy` and `timings.txt`
+ """
+ print(" ".join(self._timings.keys()))
+
+ value_array = numpy.array(list(self._timings.values()), dtype=float)
+ value_array = numpy.transpose(value_array)
+ for i in range(value_array.shape[0]):
+ print(" ".join(self._format_number(value) for value in value_array[i]))
+ if to_file:
+ numpy.save("timings.npy", value_array)
+ numpy.savetxt("timings.txt", value_array)
+
+ def run_model(self, model: t.Union[bytes, str], batch: torch.Tensor) -> t.Any:
+ """Execute a batch of inference requests with the supplied ML model.
+
+ :param model: The raw bytes or path to a pytorch model
+ :param batch: The tensor batch to perform inference on
+ :returns: The inference results
+ :raises ValueError: if the worker queue is not configured properly
+ in the environment variables
+ """
+ tensors = [batch.numpy()]
+ self.perf_timer.start_timings("batch_size", batch.shape[0])
+ built_tensor_desc = MessageHandler.build_tensor_descriptor(
+ "c", "float32", list(batch.shape)
+ )
+ self.perf_timer.measure_time("build_tensor_descriptor")
+ if isinstance(model, str):
+ model_arg = MessageHandler.build_model_key(model, self._backbone.descriptor)
+ else:
+ model_arg = MessageHandler.build_model(
+ model, "resnet-50", "1.0"
+ ) # type: ignore
+ request = MessageHandler.build_request(
+ reply_channel=self._from_worker_ch.descriptor,
+ model=model_arg,
+ inputs=[built_tensor_desc],
+ outputs=[],
+ output_descriptors=[],
+ custom_attributes=None,
+ )
+ self.perf_timer.measure_time("build_request")
+ request_bytes = MessageHandler.serialize_request(request)
+ self.perf_timer.measure_time("serialize_request")
+
+ if self._to_worker_fli is None:
+ raise ValueError("No worker queue available.")
+
+ # pylint: disable-next=protected-access
+ with self._to_worker_fli._channel.sendh( # type: ignore
+ timeout=None,
+ stream_channel=self._to_worker_ch.channel,
+ ) as to_sendh:
+ to_sendh.send_bytes(request_bytes)
+ self.perf_timer.measure_time("send_request")
+ for tensor in tensors:
+ to_sendh.send_bytes(tensor.tobytes()) # TODO NOT FAST ENOUGH!!!
+ logger.info(f"Message size: {len(request_bytes)} bytes")
+
+ self.perf_timer.measure_time("send_tensors")
+ with self._from_worker_ch.channel.recvh(timeout=None) as from_recvh:
+ resp = from_recvh.recv_bytes(timeout=None)
+ self.perf_timer.measure_time("receive_response")
+ response = MessageHandler.deserialize_response(resp)
+ self.perf_timer.measure_time("deserialize_response")
+
+ # recv depending on the len(response.result.descriptors)?
+ data_blob: bytes = from_recvh.recv_bytes(timeout=None)
+ self.perf_timer.measure_time("receive_tensor")
+ result = torch.from_numpy(
+ numpy.frombuffer(
+ data_blob,
+ dtype=str(response.result.descriptors[0].dataType),
+ )
+ )
+ self.perf_timer.measure_time("deserialize_tensor")
+
+ self.perf_timer.end_timings()
+ return result
+
+ def set_model(self, key: str, model: bytes) -> None:
+ """Write the supplied model to the feature store.
+
+ :param key: The unique key used to identify the model
+ :param model: The raw bytes of the model to execute
+ """
+ self._backbone[key] = model
+
+ # notify components of a change in the data at this key
+ event = OnWriteFeatureStore(self._EVENT_SOURCE, self._backbone.descriptor, key)
+ self._publisher.send(event)
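+
+ # Illustrative usage sketch (comment only, not executed): assumes a running
+ # Dragon-backed MLI deployment that has exported the backbone descriptor to
+ # the environment before this process starts; `model_bytes` is a placeholder
+ # for a serialized model.
+ #
+ #     client = ProtoClient(timing_on=True)
+ #     client.set_model("my-model", model_bytes)
+ #     output = client.run_model("my-model", torch.rand(1, 3, 224, 224))
+ #     client.print_timings(to_file=True)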
diff --git a/smartsim/_core/mli/comm/channel/__init__.py b/smartsim/_core/mli/comm/channel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/comm/channel/channel.py b/smartsim/_core/mli/comm/channel/channel.py
new file mode 100644
index 0000000000..104333ce7f
--- /dev/null
+++ b/smartsim/_core/mli/comm/channel/channel.py
@@ -0,0 +1,82 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import typing as t
+import uuid
+from abc import ABC, abstractmethod
+
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class CommChannelBase(ABC):
+ """Base class for abstracting a message passing mechanism"""
+
+ def __init__(
+ self,
+ descriptor: str,
+ name: t.Optional[str] = None,
+ ) -> None:
+ """Initialize the CommChannel instance.
+
+ :param descriptor: Channel descriptor
+ """
+ self._descriptor = descriptor
+ """An opaque identifier used to connect to an underlying communication channel"""
+ self._name = name or str(uuid.uuid4())
+ """A user-friendly identifier for channel-related logging"""
+
+ @abstractmethod
+ def send(self, value: bytes, timeout: float = 0.001) -> None:
+ """Send a message through the underlying communication channel.
+
+ :param value: The value to send
+ :param timeout: Maximum time to wait (in seconds) for messages to send
+ :raises SmartSimError: If sending message fails
+ """
+
+ @abstractmethod
+ def recv(self, timeout: float = 0.001) -> t.List[bytes]:
+ """Receives message(s) through the underlying communication channel.
+
+ :param timeout: Maximum time to wait (in seconds) for messages to arrive
+ :returns: The received message
+ """
+
+ @property
+ def descriptor(self) -> str:
+ """Return the channel descriptor for the underlying dragon channel.
+
+ :returns: Byte encoded channel descriptor
+ """
+ return self._descriptor
+
+ def __str__(self) -> str:
+ """Build a string representation of the channel useful for printing."""
+ classname = type(self).__name__
+ return f"{classname}('{self._name}', '{self._descriptor}')"
diff --git a/smartsim/_core/mli/comm/channel/dragon_channel.py b/smartsim/_core/mli/comm/channel/dragon_channel.py
new file mode 100644
index 0000000000..110f19258a
--- /dev/null
+++ b/smartsim/_core/mli/comm/channel/dragon_channel.py
@@ -0,0 +1,127 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+import dragon.channels as dch
+
+import smartsim._core.mli.comm.channel.channel as cch
+import smartsim._core.mli.comm.channel.dragon_util as drg_util
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class DragonCommChannel(cch.CommChannelBase):
+ """Passes messages by writing to a Dragon channel."""
+
+ def __init__(self, channel: "dch.Channel") -> None:
+ """Initialize the DragonCommChannel instance.
+
+ :param channel: A channel to use for communications
+ """
+ descriptor = drg_util.channel_to_descriptor(channel)
+ super().__init__(descriptor)
+ self._channel = channel
+ """The underlying dragon channel used by this CommChannel for communications"""
+
+ @property
+ def channel(self) -> "dch.Channel":
+ """The underlying communication channel.
+
+ :returns: The channel
+ """
+ return self._channel
+
+ def send(self, value: bytes, timeout: float = 0.001) -> None:
+ """Send a message through the underlying communication channel.
+
+ :param value: The value to send
+ :param timeout: Maximum time to wait (in seconds) for messages to send
+ :raises SmartSimError: If sending message fails
+ """
+ try:
+ with self._channel.sendh(timeout=timeout) as sendh:
+ sendh.send_bytes(value, blocking=False)
+ logger.debug(f"DragonCommChannel {self.descriptor} sent message")
+ except Exception as e:
+ raise SmartSimError(
+ f"Error sending via DragonCommChannel {self.descriptor}"
+ ) from e
+
+ def recv(self, timeout: float = 0.001) -> t.List[bytes]:
+ """Receives message(s) through the underlying communication channel.
+
+ :param timeout: Maximum time to wait (in seconds) for messages to arrive
+ :returns: The received message(s)
+ """
+ with self._channel.recvh(timeout=timeout) as recvh:
+ messages: t.List[bytes] = []
+
+ try:
+ message_bytes = recvh.recv_bytes(timeout=timeout)
+ messages.append(message_bytes)
+ logger.debug(f"DragonCommChannel {self.descriptor} received message")
+ except dch.ChannelEmpty:
+ # emptied the queue, ok to swallow this ex
+ logger.debug(f"DragonCommChannel exhausted: {self.descriptor}")
+ except dch.ChannelRecvTimeout:
+ logger.debug(f"Timeout exceeded on channel.recv: {self.descriptor}")
+
+ return messages
+
+ @classmethod
+ def from_descriptor(
+ cls,
+ descriptor: str,
+ ) -> "DragonCommChannel":
+ """A factory method that creates an instance from a descriptor string.
+
+ :param descriptor: The descriptor that uniquely identifies the resource.
+ :returns: An attached DragonCommChannel
+ :raises SmartSimError: If creation of comm channel fails
+ """
+ try:
+ channel = drg_util.descriptor_to_channel(descriptor)
+ return DragonCommChannel(channel)
+ except Exception as ex:
+ raise SmartSimError(
+ f"Failed to create dragon comm channel: {descriptor}"
+ ) from ex
+
+ @classmethod
+ def from_local(cls, _descriptor: t.Optional[str] = None) -> "DragonCommChannel":
+ """A factory method that creates a local channel instance.
+
+ :param _descriptor: Unused placeholder
+ :returns: An attached DragonCommChannel"""
+ try:
+ channel = drg_util.create_local()
+ return DragonCommChannel(channel)
+ except Exception:
+ logger.error("Failed to create local dragon comm channel", exc_info=True)
+ raise
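+
+ # Illustrative round trip (comment only): assumes a running Dragon runtime.
+ # A descriptor produced in one process can be used by a peer to attach to
+ # the same underlying channel.
+ #
+ #     channel = DragonCommChannel.from_local()
+ #     peer = DragonCommChannel.from_descriptor(channel.descriptor)
+ #     peer.send(b"hello", timeout=0.5)
+ #     messages = channel.recv(timeout=0.5)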
diff --git a/smartsim/_core/mli/comm/channel/dragon_fli.py b/smartsim/_core/mli/comm/channel/dragon_fli.py
new file mode 100644
index 0000000000..01849247cd
--- /dev/null
+++ b/smartsim/_core/mli/comm/channel/dragon_fli.py
@@ -0,0 +1,160 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# isort: off
+
+import dragon
+import dragon.fli as fli
+from dragon.channels import Channel
+
+# isort: on
+
+import typing as t
+
+import smartsim._core.mli.comm.channel.channel as cch
+import smartsim._core.mli.comm.channel.dragon_util as drg_util
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class DragonFLIChannel(cch.CommChannelBase):
+ """Passes messages by writing to a Dragon FLI Channel."""
+
+ def __init__(
+ self,
+ fli_: fli.FLInterface,
+ buffer_size: int = drg_util.DEFAULT_CHANNEL_BUFFER_SIZE,
+ ) -> None:
+ """Initialize the DragonFLIChannel instance.
+
+ :param fli_: The FLInterface to use as the underlying communications channel
+ :param buffer_size: Maximum number of sent messages that can be buffered
+ """
+ descriptor = drg_util.channel_to_descriptor(fli_)
+ super().__init__(descriptor)
+
+ self._channel: t.Optional["Channel"] = None
+ """The underlying dragon Channel used by a sender-side DragonFLIChannel
+ to attach to the main FLI channel"""
+
+ self._fli = fli_
+ """The underlying dragon FLInterface used by this CommChannel for communications"""
+ self._buffer_size: int = buffer_size
+ """Maximum number of messages that can be buffered before sending"""
+
+ def send(self, value: bytes, timeout: float = 0.001) -> None:
+ """Send a message through the underlying communication channel.
+
+ :param value: The value to send
+ :param timeout: Maximum time to wait (in seconds) for messages to send
+ :raises SmartSimError: If sending message fails
+ """
+ try:
+ if self._channel is None:
+ self._channel = drg_util.create_local(self._buffer_size)
+
+ with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
+ sendh.send_bytes(value, timeout=timeout)
+ logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
+ except Exception as e:
+ self._channel = None
+ raise SmartSimError(
+ f"Error sending via DragonFLIChannel {self.descriptor}"
+ ) from e
+
+ def send_multiple(
+ self,
+ values: t.Sequence[bytes],
+ timeout: float = 0.001,
+ ) -> None:
+ """Send a message through the underlying communication channel.
+
+ :param values: The values to send
+ :param timeout: Maximum time to wait (in seconds) for messages to send
+ :raises SmartSimError: If sending message fails
+ """
+ try:
+ if self._channel is None:
+ self._channel = drg_util.create_local(self._buffer_size)
+
+ with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
+ for value in values:
+ sendh.send_bytes(value)
+ logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
+ except Exception as e:
+ self._channel = None
+ raise SmartSimError(
+ f"Error sending via DragonFLIChannel {self.descriptor} {e}"
+ ) from e
+
+ def recv(self, timeout: float = 0.001) -> t.List[bytes]:
+ """Receives message(s) through the underlying communication channel.
+
+ :param timeout: Maximum time to wait (in seconds) for messages to arrive
+ :returns: The received message(s)
+ :raises SmartSimError: If receiving message(s) fails
+ """
+ messages = []
+ eot = False
+ with self._fli.recvh(timeout=timeout) as recvh:
+ while not eot:
+ try:
+ message, _ = recvh.recv_bytes(timeout=timeout)
+ messages.append(message)
+ logger.debug(f"DragonFLIChannel {self.descriptor} received message")
+ except fli.FLIEOT:
+ eot = True
+ logger.debug(f"DragonFLIChannel exhausted: {self.descriptor}")
+ except Exception as e:
+ raise SmartSimError(
+ f"Error receiving messages: DragonFLIChannel {self.descriptor}"
+ ) from e
+ return messages
+
+ @classmethod
+ def from_descriptor(
+ cls,
+ descriptor: str,
+ ) -> "DragonFLIChannel":
+ """A factory method that creates an instance from a descriptor string.
+
+ :param descriptor: The descriptor that uniquely identifies the resource
+ :returns: An attached DragonFLIChannel
+ :raises SmartSimError: If creation of DragonFLIChannel fails
+ :raises ValueError: If the descriptor is invalid
+ """
+ if not descriptor:
+ raise ValueError("Invalid descriptor provided")
+
+ try:
+ return DragonFLIChannel(fli_=drg_util.descriptor_to_fli(descriptor))
+ except Exception as e:
+ raise SmartSimError(
+ f"Error while creating DragonFLIChannel: {descriptor}"
+ ) from e
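+
+ # Illustrative usage (comment only): assumes a running Dragon runtime and an
+ # FLI descriptor published elsewhere (e.g. the MLI worker queue); the
+ # `fli_descriptor`, `request_bytes`, and `tensor_bytes` names are placeholders.
+ #
+ #     channel = DragonFLIChannel.from_descriptor(fli_descriptor)
+ #     channel.send_multiple([request_bytes, tensor_bytes], timeout=0.5)
+ #     replies = channel.recv(timeout=1.0)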
diff --git a/smartsim/_core/mli/comm/channel/dragon_util.py b/smartsim/_core/mli/comm/channel/dragon_util.py
new file mode 100644
index 0000000000..8517979ec4
--- /dev/null
+++ b/smartsim/_core/mli/comm/channel/dragon_util.py
@@ -0,0 +1,131 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import binascii
+import typing as t
+
+import dragon.channels as dch
+import dragon.fli as fli
+import dragon.managed_memory as dm
+
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+DEFAULT_CHANNEL_BUFFER_SIZE = 500
+"""Maximum number of messages that can be buffered. DragonCommChannel will
+raise an exception if no clients consume messages before the buffer is filled."""
+
+LAST_OFFSET = 0
+"""The last offset used to create a local channel. This is used to avoid
+unnecessary retries when creating a local channel."""
+
+
+def channel_to_descriptor(channel: t.Union[dch.Channel, fli.FLInterface]) -> str:
+ """Convert a dragon channel to a descriptor string.
+
+ :param channel: The dragon channel to convert
+ :returns: The descriptor string
+ :raises ValueError: If a dragon channel is not provided
+ """
+ if channel is None:
+ raise ValueError("Channel is not available to create a descriptor")
+
+ serialized_ch = channel.serialize()
+ return base64.b64encode(serialized_ch).decode("utf-8")
+
+
+def pool_to_descriptor(pool: dm.MemoryPool) -> str:
+ """Convert a dragon memory pool to a descriptor string.
+
+ :param pool: The memory pool to convert
+ :returns: The descriptor string
+ :raises ValueError: If a memory pool is not provided
+ """
+ if pool is None:
+ raise ValueError("Memory pool is not available to create a descriptor")
+
+ serialized_pool = pool.serialize()
+ return base64.b64encode(serialized_pool).decode("utf-8")
+
+
+def descriptor_to_fli(descriptor: str) -> "fli.FLInterface":
+ """Create and attach a new FLI instance given
+ the string-encoded descriptor.
+
+ :param descriptor: The descriptor of an FLI to attach to
+ :returns: The attached dragon FLI
+ :raises ValueError: If the descriptor is empty or incorrectly formatted
+ :raises SmartSimError: If attachment using the descriptor fails
+ """
+ if len(descriptor) < 1:
+ raise ValueError("Descriptors may not be empty")
+
+ try:
+ encoded = descriptor.encode("utf-8")
+ descriptor_ = base64.b64decode(encoded)
+ return fli.FLInterface.attach(descriptor_)
+ except binascii.Error as ex:
+ raise ValueError("The descriptor was not properly base64 encoded") from ex
+ except fli.DragonFLIError as ex:
+ raise SmartSimError("The descriptor did not address an available FLI") from ex
+
+
+def descriptor_to_channel(descriptor: str) -> dch.Channel:
+ """Create and attach a new Channel instance given
+ the string-encoded descriptor.
+
+ :param descriptor: The descriptor of a channel to attach to
+ :returns: The attached dragon Channel
+ :raises ValueError: If the descriptor is empty or incorrectly formatted
+ :raises SmartSimError: If attachment using the descriptor fails
+ """
+ if len(descriptor) < 1:
+ raise ValueError("Descriptors may not be empty")
+
+ try:
+ encoded = descriptor.encode("utf-8")
+ descriptor_ = base64.b64decode(encoded)
+ return dch.Channel.attach(descriptor_)
+ except binascii.Error as ex:
+ raise ValueError("The descriptor was not properly base64 encoded") from ex
+ except dch.ChannelError as ex:
+ raise SmartSimError("The descriptor did not address an available channel") from ex
+
+
+def create_local(_capacity: int = 0) -> dch.Channel:
+ """Creates a Channel attached to the local memory pool. Replacement for
+ direct calls to `dch.Channel.make_process_local()` to enable
+ supplying a channel capacity.
+
+ :param _capacity: The number of events the channel can buffer; uses the default
+ buffer size `DEFAULT_CHANNEL_BUFFER_SIZE` when not supplied
+ :returns: The instantiated channel
+ """
+ channel = dch.Channel.make_process_local()
+ return channel
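+
+ # Illustrative round trip (comment only): assumes a running Dragon runtime.
+ # Descriptors are base64-encoded serialized channels, so they can be passed
+ # through environment variables or feature store values and re-attached later.
+ #
+ #     channel = create_local()
+ #     descriptor = channel_to_descriptor(channel)
+ #     attached = descriptor_to_channel(descriptor)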
diff --git a/smartsim/_core/mli/infrastructure/__init__.py b/smartsim/_core/mli/infrastructure/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/infrastructure/comm/__init__.py b/smartsim/_core/mli/infrastructure/comm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/infrastructure/comm/broadcaster.py b/smartsim/_core/mli/infrastructure/comm/broadcaster.py
new file mode 100644
index 0000000000..56dcf549f7
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/comm/broadcaster.py
@@ -0,0 +1,239 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+import uuid
+from collections import defaultdict, deque
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim._core.mli.infrastructure.comm.event import EventBase
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class BroadcastResult(t.NamedTuple):
+ """Contains summary details about a broadcast."""
+
+ num_sent: int
+ """The total number of messages delivered across all consumers"""
+ num_failed: int
+ """The total number of messages not delivered across all consumers"""
+
+
+class EventBroadcaster:
+ """Performs fan-out publishing of system events."""
+
+ def __init__(
+ self,
+ backbone: BackboneFeatureStore,
+ channel_factory: t.Optional[t.Callable[[str], CommChannelBase]] = None,
+ name: t.Optional[str] = None,
+ ) -> None:
+ """Initialize the EventPublisher instance.
+
+ :param backbone: The MLI backbone feature store
+ :param channel_factory: Factory method to construct new channel instances
+ :param name: A user-friendly name for logging. If not provided, an
+ auto-generated GUID will be used
+ """
+ self._backbone = backbone
+ """The backbone feature store used to retrieve consumer descriptors"""
+ self._channel_factory = channel_factory
+ """A factory method used to instantiate channels from descriptors"""
+ self._channel_cache: t.Dict[str, t.Optional[CommChannelBase]] = defaultdict(
+ lambda: None
+ )
+ """A mapping of instantiated channels that can be re-used. Automatically
+ calls the channel factory if a descriptor is not already in the collection"""
+ self._event_buffer: t.Deque[EventBase] = deque()
+ """A buffer for storing events when a consumer list is not found"""
+ self._descriptors: t.Set[str]
+ """Stores the most recent list of broadcast consumers. Updated automatically
+ on each broadcast"""
+ self._name = name or str(uuid.uuid4())
+ """A unique identifer assigned to the broadcaster for logging"""
+
+ @property
+ def name(self) -> str:
+ """The friendly name assigned to the broadcaster.
+
+ :returns: The broadcaster name if one is assigned, otherwise a unique
+ id assigned by the system.
+ """
+ return self._name
+
+ @property
+ def num_buffered(self) -> int:
+ """Return the number of events currently buffered to send.
+
+ :returns: Number of buffered events
+ """
+ return len(self._event_buffer)
+
+ def _save_to_buffer(self, event: EventBase) -> None:
+ """Places the event in the buffer to be sent once a consumer
+ list is available.
+
+ :param event: The event to buffer
+ :raises ValueError: If the event cannot be buffered
+ """
+ try:
+ self._event_buffer.append(event)
+ logger.debug(f"Buffered event {event=}")
+ except Exception as ex:
+ raise ValueError(
+ f"Unable to buffer event {event} in broadcaster {self.name}"
+ ) from ex
+
+ def _log_broadcast_start(self) -> None:
+ """Logs broadcast statistics."""
+ num_events = len(self._event_buffer)
+ num_copies = len(self._descriptors)
+ logger.debug(
+ f"Broadcast {num_events} events to {num_copies} consumers from {self.name}"
+ )
+
+ def _prune_unused_consumers(self) -> None:
+ """Performs maintenance on the channel cache by pruning any channel
+ that has been removed from the consumers list."""
+ active_consumers = set(self._descriptors)
+ current_channels = set(self._channel_cache.keys())
+
+ # find any cached channels that are now unused
+ inactive_channels = current_channels.difference(active_consumers)
+ new_channels = active_consumers.difference(current_channels)
+
+ for descriptor in inactive_channels:
+ self._channel_cache.pop(descriptor)
+
+ logger.debug(
+ f"Pruning {len(inactive_channels)} stale consumers and"
+ f" found {len(new_channels)} new channels for {self.name}"
+ )
+
+ def _get_comm_channel(self, descriptor: str) -> CommChannelBase:
+ """Helper method to build and cache a comm channel.
+
+ :param descriptor: The descriptor to pass to the channel factory
+ :returns: The instantiated channel
+ :raises SmartSimError: If the channel fails to attach
+ """
+ comm_channel = self._channel_cache[descriptor]
+ if comm_channel is not None:
+ return comm_channel
+
+ if self._channel_factory is None:
+ raise SmartSimError("No channel factory provided for consumers")
+
+ try:
+ channel = self._channel_factory(descriptor)
+ self._channel_cache[descriptor] = channel
+ return channel
+ except Exception as ex:
+ msg = f"Unable to construct channel with descriptor: {descriptor}"
+ logger.error(msg, exc_info=True)
+ raise SmartSimError(msg) from ex
+
+ def _get_next_event(self) -> t.Optional[EventBase]:
+ """Pop the next event to be sent from the queue.
+
+ :returns: The next event to send if any events are enqueued, otherwise `None`.
+ """
+ try:
+ return self._event_buffer.popleft()
+ except IndexError:
+ logger.debug(f"Broadcast buffer exhausted for {self.name}")
+
+ return None
+
+ def _broadcast(self, timeout: float = 0.001) -> BroadcastResult:
+ """Broadcasts all buffered events to registered event consumers.
+
+ :param timeout: Maximum time to wait (in seconds) for messages to send
+ :returns: BroadcastResult containing the number of messages that were
+ successfully and unsuccessfully sent for all consumers
+ :raises SmartSimError: If the channel fails to attach or broadcasting fails
+ """
+ # allow descriptors to be empty since events are buffered
+ self._descriptors = set(x for x in self._backbone.notification_channels if x)
+ if not self._descriptors:
+ msg = f"No event consumers are registered for {self.name}"
+ logger.warning(msg)
+ return BroadcastResult(0, 0)
+
+ self._prune_unused_consumers()
+ self._log_broadcast_start()
+
+ num_listeners = len(self._descriptors)
+ num_sent = 0
+ num_failures = 0
+
+ # send each event to every consumer
+ while event := self._get_next_event():
+ logger.debug(f"Broadcasting {event=} to {num_listeners} listeners")
+ event_bytes = bytes(event)
+
+ for i, descriptor in enumerate(self._descriptors):
+ comm_channel = self._get_comm_channel(descriptor)
+
+ try:
+ comm_channel.send(event_bytes, timeout)
+ num_sent += 1
+ except Exception:
+ msg = (
+ f"Broadcast {i+1}/{num_listeners} for event {event.uid} to "
+ f"channel {descriptor} from {self.name} failed."
+ )
+ logger.exception(msg)
+ num_failures += 1
+
+ return BroadcastResult(num_sent, num_failures)
+
+ def send(self, event: EventBase, timeout: float = 0.001) -> int:
+ """Implementation of `send` method of the `EventPublisher` protocol. Publishes
+ the supplied event to all registered broadcast consumers.
+
+ :param event: An event to publish
+ :param timeout: Maximum time to wait (in seconds) for messages to send
+ :returns: The total number of events successfully published to consumers
+ :raises ValueError: If event serialization fails
+ :raises AttributeError: If event cannot be serialized
+ :raises KeyError: If channel fails to attach using registered descriptors
+ :raises SmartSimError: If any unexpected error occurs during send
+ """
+ try:
+ self._save_to_buffer(event)
+ result = self._broadcast(timeout)
+ return result.num_sent
+ except (KeyError, ValueError, AttributeError, SmartSimError):
+ raise
+ except Exception as ex:
+ raise SmartSimError("An unexpected failure occurred while sending") from ex
diff --git a/smartsim/_core/mli/infrastructure/comm/consumer.py b/smartsim/_core/mli/infrastructure/comm/consumer.py
new file mode 100644
index 0000000000..08b5c47852
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/comm/consumer.py
@@ -0,0 +1,281 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pickle
+import time
+import typing as t
+import uuid
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.infrastructure.comm.event import (
+ EventBase,
+ OnCreateConsumer,
+ OnRemoveConsumer,
+ OnShutdownRequested,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class EventConsumer:
+ """Reads system events published to a communications channel."""
+
+ _BACKBONE_WAIT_TIMEOUT = 10.0
+ """Maximum time (in seconds) to wait for the backbone to register the consumer"""
+
+ def __init__(
+ self,
+ comm_channel: CommChannelBase,
+ backbone: BackboneFeatureStore,
+ filters: t.Optional[t.List[str]] = None,
+ name: t.Optional[str] = None,
+ event_handler: t.Optional[t.Callable[[EventBase], None]] = None,
+ ) -> None:
+ """Initialize the EventConsumer instance.
+
+ :param comm_channel: Communications channel to listen to for events
+ :param backbone: The MLI backbone feature store
+ :param filters: A list of event types to deliver. When empty, all
+ events will be delivered
+ :param name: A user-friendly name for logging. If not provided, an
+ auto-generated GUID will be used
+ :param event_handler: An optional callable invoked for each event that
+ passes the configured filters
+ """
+ self._comm_channel = comm_channel
+ """The comm channel used by the consumer to receive messages. The channel
+ descriptor will be published for senders to discover."""
+ self._backbone = backbone
+ """The backbone instance used to bootstrap the instance. The EventConsumer
+ uses the backbone to discover where it can publish its descriptor."""
+ self._global_filters = filters or []
+ """A set of global filters to apply to incoming events. Global filters are
+ combined with per-call filters. Filters act as an allow-list."""
+ self._name = name or str(uuid.uuid4())
+ """User-friendly name assigned to a consumer for logging. Automatically
+ assigned if not provided."""
+ self._event_handler = event_handler
+ """The function that should be executed when an event
+ passed by the filters is received."""
+ self.listening = True
+ """Flag indicating that the consumer is currently listening for new
+ events. Setting this flag to `False` will cause any active calls to
+ `listen` to terminate."""
+
+ @property
+ def descriptor(self) -> str:
+ """The descriptor of the underlying comm channel.
+
+ :returns: The comm channel descriptor"""
+ return self._comm_channel.descriptor
+
+ @property
+ def name(self) -> str:
+ """The friendly name assigned to the consumer.
+
+ :returns: The consumer name if one is assigned, otherwise a unique
+ id assigned by the system.
+ """
+ return self._name
+
+ def recv(
+ self,
+ filters: t.Optional[t.List[str]] = None,
+ timeout: float = 0.001,
+ batch_timeout: float = 1.0,
+ ) -> t.List[EventBase]:
+ """Receives available published event(s).
+
+ :param filters: Additional filters to add to the global filters configured
+ on the EventConsumer instance
+ :param timeout: Maximum time to wait for a single message to arrive
+ :param batch_timeout: Maximum time to wait for messages to arrive; allows
+ multiple batches to be retrieved in one call to `recv`
+ :returns: A list of events that pass any configured filters
+ :raises ValueError: If a positive, non-zero value is not provided for the
+ timeout or batch_timeout.
+ """
+ if filters is None:
+ filters = []
+
+ if timeout is not None and timeout <= 0:
+ raise ValueError("request timeout must be a non-zero, positive value")
+
+ if batch_timeout is not None and batch_timeout <= 0:
+ raise ValueError("batch_timeout must be a non-zero, positive value")
+
+ filter_set = {*self._global_filters, *filters}
+ all_message_bytes: t.List[bytes] = []
+
+ # firehose as many messages as possible within the batch_timeout
+ start_at = time.time()
+ remaining = batch_timeout
+
+ batch_message_bytes = self._comm_channel.recv(timeout=timeout)
+ while batch_message_bytes:
+ # accumulate all received bytes; empty messages are filtered out below
+ all_message_bytes.extend(batch_message_bytes)
+ batch_message_bytes = []
+
+ # avoid getting stuck indefinitely waiting for the channel
+ elapsed = time.time() - start_at
+ remaining = batch_timeout - elapsed
+
+ if remaining > 0:
+ batch_message_bytes = self._comm_channel.recv(timeout=timeout)
+
+ events_received: t.List[EventBase] = []
+
+ # Timeout elapsed or no messages received - return the empty list
+ if not all_message_bytes:
+ return events_received
+
+ for message in all_message_bytes:
+ if not message:
+ continue
+
+ event = pickle.loads(message)
+ if not event:
+ logger.warning(f"Consumer {self.name} is unable to unpickle message")
+ continue
+
+ # skip events that don't pass a filter
+ if filter_set and event.category not in filter_set:
+ continue
+
+ events_received.append(event)
+
+ return events_received
+
+ def _send_to_registrar(self, event: EventBase) -> None:
+ """Send an event direct to the registrar listener."""
+ registrar_key = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER
+ config = self._backbone.wait_for([registrar_key], self._BACKBONE_WAIT_TIMEOUT)
+ registrar_descriptor = str(config.get(registrar_key) or "")
+
+ if not registrar_descriptor:
+ logger.warning(
+ f"Unable to send {event.category} from {self.name}. "
+ "No registrar channel found."
+ )
+ return
+
+ logger.debug(f"Sending {event.category} from {self.name}")
+
+ registrar_channel = DragonCommChannel.from_descriptor(registrar_descriptor)
+ registrar_channel.send(bytes(event), timeout=1.0)
+
+ logger.debug(f"{event.category} from {self.name} sent")
+
+ def register(self) -> None:
+ """Send an event to register this consumer as a listener."""
+ descriptor = self._comm_channel.descriptor
+ event = OnCreateConsumer(self.name, descriptor, self._global_filters)
+
+ self._send_to_registrar(event)
+
+ def unregister(self) -> None:
+ """Send an event to un-register this consumer as a listener."""
+ descriptor = self._comm_channel.descriptor
+ event = OnRemoveConsumer(self.name, descriptor)
+
+ self._send_to_registrar(event)
+
+ def _on_handler_missing(self, event: EventBase) -> None:
+ """A "dead letter" event handler that is called to perform
+ processing on events before they're discarded.
+
+ :param event: The event to handle
+ """
+ logger.warning(
+ "No event handler is registered in consumer "
+ f"{self.name}. Discarding {event=}"
+ )
+
+ def listen_once(self, timeout: float = 0.001, batch_timeout: float = 1.0) -> None:
+ """Receives messages for the consumer a single time. Delivers
+ all messages that pass the consumer filters. Shutdown requests
+ are handled by a default event handler.
+
+ NOTE: Executes a single batch-retrieval to receive the maximum
+ number of messages available under batch timeout. To continually
+ listen, use `listen` in a non-blocking thread/process
+
+ :param timeout: Maximum time to wait (in seconds) for a message to arrive
+ :param batch_timeout: Maximum time to wait (in seconds) for a batch to arrive
+ """
+ logger.info(
+ f"Consumer {self.name} listening with {timeout} second timeout"
+ f" on channel {self._comm_channel.descriptor}"
+ )
+
+ if not self._event_handler:
+ logger.info("Unable to handle messages. No event handler is registered.")
+
+ incoming_messages = self.recv(timeout=timeout, batch_timeout=batch_timeout)
+
+ if not incoming_messages:
+ logger.info(f"Consumer {self.name} received empty message list")
+
+ for message in incoming_messages:
+ logger.info(f"Consumer {self.name} is handling event {message=}")
+ self._handle_shutdown(message)
+
+ if self._event_handler:
+ self._event_handler(message)
+ else:
+ self._on_handler_missing(message)
+
+ def _handle_shutdown(self, event: EventBase) -> bool:
+ """Handles shutdown requests sent to the consumer by setting the
+ `self.listening` attribute to `False`.
+
+ :param event: The event to handle
+ :returns: A bool indicating if the event was a shutdown request
+ """
+ if isinstance(event, OnShutdownRequested):
+ logger.debug(f"Shutdown requested from: {event.source}")
+ self.listening = False
+ return True
+ return False
+
+ def listen(self, timeout: float = 0.001, batch_timeout: float = 1.0) -> None:
+ """Receives messages for the consumer until a shutdown request is received.
+
+ :param timeout: Maximum time to wait (in seconds) for a message to arrive
+ :param batch_timeout: Maximum time to wait (in seconds) for a batch to arrive
+ """
+
+ logger.debug(f"Consumer {self.name} is now listening for events.")
+
+ while self.listening:
+ self.listen_once(timeout, batch_timeout)
+
+ logger.debug(f"Consumer {self.name} is no longer listening.")
diff --git a/smartsim/_core/mli/infrastructure/comm/event.py b/smartsim/_core/mli/infrastructure/comm/event.py
new file mode 100644
index 0000000000..ccef9f9b86
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/comm/event.py
@@ -0,0 +1,162 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pickle
+import typing as t
+import uuid
+from dataclasses import dataclass, field
+
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class EventBase:
+ """Core API for an event."""
+
+ category: str
+ """Unique category name for an event class"""
+ source: str
+ """A unique identifier for the publisher of the event"""
+ uid: str = field(default_factory=lambda: str(uuid.uuid4()))
+ """A unique identifier for this event"""
+
+ def __bytes__(self) -> bytes:
+ """Default conversion to bytes for an event required to publish
+ messages using byte-oriented communication channels.
+
+ :returns: This entity encoded as bytes"""
+ return pickle.dumps(self)
+
+ def __str__(self) -> str:
+ """Convert the event to a string.
+
+ :returns: A string representation of this instance"""
+ return f"{self.uid}|{self.category}"
+
+
+class OnShutdownRequested(EventBase):
+ """Publish this event to trigger the listener to shutdown."""
+
+ SHUTDOWN: t.ClassVar[str] = "consumer-unregister"
+ """Unique category name for an event raised when a new consumer is unregistered"""
+
+ def __init__(self, source: str) -> None:
+ """Initialize the event instance.
+
+ :param source: A unique identifier for the publisher of the event
+ creating the event
+ """
+ super().__init__(self.SHUTDOWN, source)
+
+
+class OnCreateConsumer(EventBase):
+ """Publish this event when a new event consumer registration is required."""
+
+ descriptor: str
+ """Descriptor of the comm channel exposed by the consumer"""
+ filters: t.List[str]
+ """The collection of filters indicating messages of interest to this consumer"""
+
+ CONSUMER_CREATED: t.ClassVar[str] = "consumer-created"
+ """Unique category name for an event raised when a new consumer is registered"""
+
+ def __init__(self, source: str, descriptor: str, filters: t.Sequence[str]) -> None:
+ """Initialize the event instance.
+
+ :param source: A unique identifier for the publisher of the event
+ :param descriptor: Descriptor of the comm channel exposed by the consumer
+ :param filters: Collection of filters indicating messages of interest
+ """
+ super().__init__(self.CONSUMER_CREATED, source)
+ self.descriptor = descriptor
+ self.filters = list(filters)
+
+ def __str__(self) -> str:
+ """Convert the event to a string.
+
+ :returns: A string representation of this instance
+ """
+ _filters = ",".join(self.filters)
+ return f"{str(super())}|{self.descriptor}|{_filters}"
+
+
+class OnRemoveConsumer(EventBase):
+ """Publish this event when a consumer is shutting down and
+ should be removed from notification lists."""
+
+ descriptor: str
+ """Descriptor of the comm channel exposed by the consumer"""
+
+ CONSUMER_REMOVED: t.ClassVar[str] = "consumer-removed"
+ """Unique category name for an event raised when a new consumer is unregistered"""
+
+ def __init__(self, source: str, descriptor: str) -> None:
+ """Initialize the OnRemoveConsumer event.
+
+ :param source: A unique identifier for the publisher of the event
+ :param descriptor: Descriptor of the comm channel exposed by the consumer
+ """
+ super().__init__(self.CONSUMER_REMOVED, source)
+ self.descriptor = descriptor
+
+ def __str__(self) -> str:
+ """Convert the event to a string.
+
+ :returns: A string representation of this instance
+ """
+ return f"{str(super())}|{self.descriptor}"
+
+
+class OnWriteFeatureStore(EventBase):
+ """Publish this event when a feature store key is written."""
+
+ descriptor: str
+ """The descriptor of the feature store where the write occurred"""
+ key: str
+ """The key identifying where the write occurred"""
+
+ FEATURE_STORE_WRITTEN: str = "feature-store-written"
+ """Event category for an event raised when a feature store key is written"""
+
+ def __init__(self, source: str, descriptor: str, key: str) -> None:
+ """Initialize the OnWriteFeatureStore event.
+
+ :param source: A unique identifier for the publisher of the event
+ :param descriptor: The descriptor of the feature store where the write occurred
+ :param key: The key identifying where the write occurred
+ """
+ super().__init__(self.FEATURE_STORE_WRITTEN, source)
+ self.descriptor = descriptor
+ self.key = key
+
+ def __str__(self) -> str:
+ """Convert the event to a string.
+
+ :returns: A string representation of this instance
+ """
+ return f"{str(super())}|{self.descriptor}|{self.key}"
diff --git a/smartsim/_core/mli/infrastructure/comm/producer.py b/smartsim/_core/mli/infrastructure/comm/producer.py
new file mode 100644
index 0000000000..2d8a7c14ad
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/comm/producer.py
@@ -0,0 +1,44 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+from smartsim._core.mli.infrastructure.comm.event import EventBase
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class EventProducer(t.Protocol):
+ """Core API of a class that publishes events."""
+
+ def send(self, event: EventBase, timeout: float = 0.001) -> int:
+ """Send an event using the configured comm channel.
+
+ :param event: The event to send
+ :param timeout: Maximum time to wait (in seconds) for messages to send
+ :returns: The number of messages that were sent
+ """
diff --git a/smartsim/_core/mli/infrastructure/control/__init__.py b/smartsim/_core/mli/infrastructure/control/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/infrastructure/control/device_manager.py b/smartsim/_core/mli/infrastructure/control/device_manager.py
new file mode 100644
index 0000000000..9334971f8c
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/control/device_manager.py
@@ -0,0 +1,166 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+from contextlib import _GeneratorContextManager, contextmanager
+
+from .....log import get_logger
+from ..storage.feature_store import FeatureStore
+from ..worker.worker import MachineLearningWorkerBase, RequestBatch
+
+logger = get_logger(__name__)
+
+
+class WorkerDevice:
+ def __init__(self, name: str) -> None:
+ """Wrapper around a device to keep track of loaded Models and availability.
+
+ :param name: Name used by the toolkit to identify this device, e.g. ``cuda:0``
+ """
+ self._name = name
+ """The name used by the toolkit to identify this device"""
+ self._models: dict[str, t.Any] = {}
+ """Dict of keys to models which are loaded on this device"""
+
+ @property
+ def name(self) -> str:
+ """The identifier of the device represented by this object
+
+ :returns: Name used by the toolkit to identify this device
+ """
+ return self._name
+
+ def add_model(self, key: str, model: t.Any) -> None:
+ """Add a reference to a model loaded on this device and assign it a key.
+
+ :param key: The key under which the model is saved
+ :param model: The model which is added
+ """
+ self._models[key] = model
+
+ def remove_model(self, key: str) -> None:
+ """Remove the reference to a model loaded on this device.
+
+ :param key: The key of the model to remove
+ :raises KeyError: If key does not exist for removal
+ """
+ try:
+ self._models.pop(key)
+ except KeyError:
+ logger.warning(f"An unknown key was requested for removal: {key}")
+ raise
+
+ def get_model(self, key: str) -> t.Any:
+ """Get the model corresponding to a given key.
+
+ :param key: The model key
+ :returns: The model for the given key
+ :raises KeyError: If key does not exist
+ """
+ try:
+ return self._models[key]
+ except KeyError:
+ logger.warning(f"An unknown key was requested: {key}")
+ raise
+
+ def __contains__(self, key: str) -> bool:
+ """Check if model with a given key is available on the device.
+
+ :param key: The key of the model to check for existence
+ :returns: Whether the model is available on the device
+ """
+ return key in self._models
+
+ @contextmanager
+ def get(self, key_to_remove: t.Optional[str]) -> t.Iterator["WorkerDevice"]:
+ """Get the WorkerDevice generator and optionally remove a model.
+
+ :param key_to_remove: The key of the model to optionally remove
+ :returns: WorkerDevice generator
+ """
+ yield self
+ if key_to_remove is not None:
+ self.remove_model(key_to_remove)
+
+
+class DeviceManager:
+ def __init__(self, device: WorkerDevice):
+ """An object to manage devices such as GPUs and CPUs.
+
+ The main goal of the ``DeviceManager`` is to ensure that
+ the managed device is ready to be used by a worker to
+ run a given model.
+
+ :param device: The managed device
+ """
+ self._device = device
+ """Device managed by this object"""
+
+ def _load_model_on_device(
+ self,
+ worker: MachineLearningWorkerBase,
+ batch: RequestBatch,
+ feature_stores: dict[str, FeatureStore],
+ ) -> None:
+ """Load the model needed to execute a batch on the managed device.
+
+ The model is loaded by the worker.
+
+ :param worker: The worker that loads the model
+ :param batch: The batch for which the model is needed
+ :param feature_stores: Feature stores where the model could be stored
+ """
+
+ model_bytes = worker.fetch_model(batch, feature_stores)
+ loaded_model = worker.load_model(batch, model_bytes, self._device.name)
+ self._device.add_model(batch.model_id.key, loaded_model.model)
+
+ def get_device(
+ self,
+ worker: MachineLearningWorkerBase,
+ batch: RequestBatch,
+ feature_stores: dict[str, FeatureStore],
+ ) -> _GeneratorContextManager[WorkerDevice]:
+ """Get the device managed by this object.
+
+ The model needed to run the batch of requests is
+ guaranteed to be available on the device.
+
+ :param worker: The worker that wants to access the device
+ :param batch: The batch of requests
+ :param feature_stores: The feature stores where part of the
+ data needed by the requests may be stored
+ :returns: A generator yielding the device
+ """
+ model_in_request = batch.has_raw_model
+
+ # Load the model if it is not already loaded on the device,
+ # or if it was sent directly with the request
+ if model_in_request or batch.model_id.key not in self._device:
+ self._load_model_on_device(worker, batch, feature_stores)
+
+ key_to_remove = batch.model_id.key if model_in_request else None
+ return self._device.get(key_to_remove)
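+
+
+# Usage sketch: guarding model execution with the managed device. The worker,
+# batch, and feature_stores objects are assumed to be created elsewhere.
+#
+#   device_manager = DeviceManager(WorkerDevice("cuda:0"))
+#   with device_manager.get_device(worker, batch, feature_stores) as device:
+#       model = device.get_model(batch.model_id.key)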
diff --git a/smartsim/_core/mli/infrastructure/control/dragon_util.py b/smartsim/_core/mli/infrastructure/control/dragon_util.py
new file mode 100644
index 0000000000..95c3e60524
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/control/dragon_util.py
@@ -0,0 +1,79 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import annotations
+
+import os
+import socket
+import typing as t
+
+from smartsim.log import get_logger
+
+# isort: off
+# pylint: disable=import-error
+# pylint: disable-next=unused-import
+import dragon
+
+import dragon.infrastructure.policy as dragon_policy
+import dragon.infrastructure.process_desc as dragon_process_desc
+import dragon.native.process as dragon_process
+
+# pylint: enable=import-error
+# isort: on
+
+
+logger = get_logger(__name__)
+
+
+def function_as_dragon_proc(
+ entrypoint_fn: t.Callable[..., None],
+ args: t.List[t.Any],
+ cpu_affinity: t.List[int],
+ gpu_affinity: t.List[int],
+) -> dragon_process.Process:
+ """Execute a function as an independent dragon process.
+
+ :param entrypoint_fn: The function to execute
+ :param args: The arguments for the entrypoint function
+ :param cpu_affinity: The cpu affinity for the process
+ :param gpu_affinity: The gpu affinity for the process
+ :returns: The dragon process handle
+ """
+ options = dragon_process_desc.ProcessOptions(make_inf_channels=True)
+ local_policy = dragon_policy.Policy(
+ placement=dragon_policy.Policy.Placement.HOST_NAME,
+ host_name=socket.gethostname(),
+ cpu_affinity=cpu_affinity,
+ gpu_affinity=gpu_affinity,
+ )
+ return dragon_process.Process(
+ target=entrypoint_fn,
+ args=args,
+ cwd=os.getcwd(),
+ policy=local_policy,
+ options=options,
+ stderr=dragon_process.Popen.STDOUT,
+ stdout=dragon_process.Popen.STDOUT,
+ )
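+
+
+# Usage sketch: launching an entrypoint as a standalone dragon process. The
+# entrypoint function, arguments, and affinity lists are illustrative
+# placeholders.
+#
+#   proc = function_as_dragon_proc(
+#       my_entrypoint, args=[config], cpu_affinity=[0, 1], gpu_affinity=[0]
+#   )
+#   proc.start()
+#   proc.join()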
diff --git a/smartsim/_core/mli/infrastructure/control/error_handling.py b/smartsim/_core/mli/infrastructure/control/error_handling.py
new file mode 100644
index 0000000000..a75f533a37
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/control/error_handling.py
@@ -0,0 +1,78 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+from .....log import get_logger
+from ...comm.channel.channel import CommChannelBase
+from ...message_handler import MessageHandler
+from ...mli_schemas.response.response_capnp import ResponseBuilder
+
+if t.TYPE_CHECKING:
+ from smartsim._core.mli.mli_schemas.response.response_capnp import Status
+
+logger = get_logger(__file__)
+
+
+def build_failure_reply(status: "Status", message: str) -> ResponseBuilder:
+ """
+ Builds a failure response message.
+
+ :param status: Status enum
+ :param message: Status message
+ :returns: Failure response
+ """
+ return MessageHandler.build_response(
+ status=status,
+ message=message,
+ result=None,
+ custom_attributes=None,
+ )
+
+
+def exception_handler(
+ exc: Exception,
+ reply_channel: t.Optional[CommChannelBase],
+ failure_message: t.Optional[str],
+) -> None:
+ """
+ Logs exceptions and sends a failure response.
+
+ :param exc: The exception to be logged
+ :param reply_channel: The channel used to send replies
+ :param failure_message: Failure message to log and send back
+ """
+ logger.exception(exc)
+ if reply_channel:
+ if failure_message is None:
+ failure_message = str(exc)
+
+ serialized_resp = MessageHandler.serialize_response(
+ build_failure_reply("fail", failure_message)
+ )
+ reply_channel.send(serialized_resp)
+ else:
+ logger.warning("Unable to notify client of error without a reply channel")
diff --git a/smartsim/_core/mli/infrastructure/control/listener.py b/smartsim/_core/mli/infrastructure/control/listener.py
new file mode 100644
index 0000000000..56a7b12d34
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/control/listener.py
@@ -0,0 +1,352 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# isort: off
+# pylint: disable=import-error
+# pylint: disable=unused-import
+import socket
+import dragon
+
+# pylint: enable=unused-import
+# pylint: enable=import-error
+# isort: on
+
+import argparse
+import multiprocessing as mp
+import os
+import sys
+import typing as t
+
+from smartsim._core.entrypoints.service import Service
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
+from smartsim._core.mli.infrastructure.comm.event import (
+ EventBase,
+ OnCreateConsumer,
+ OnRemoveConsumer,
+ OnShutdownRequested,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class ConsumerRegistrationListener(Service):
+ """A long-running service that manages the list of consumers receiving
+ events that are broadcast. It hosts handlers for adding and removing consumers
+ """
+
+ def __init__(
+ self,
+ backbone: BackboneFeatureStore,
+ timeout: float,
+ batch_timeout: float,
+ as_service: bool = False,
+ cooldown: int = 0,
+ health_check_frequency: float = 60.0,
+ ) -> None:
+ """Initialize the EventListener.
+
+ :param backbone: The backbone feature store
+ :param timeout: Maximum time (in seconds) to allow a single recv request to wait
+ :param batch_timeout: Maximum time (in seconds) to allow a batch of receives to
+ continue to build
+ :param as_service: Specifies run-once or run-until-complete behavior of service
+ :param cooldown: Number of seconds to wait before shutting down after
+ shutdown criteria are met
+ :param health_check_frequency: Time (in seconds) between calls to the
+ health check handler
+ """
+ super().__init__(
+ as_service, cooldown, health_check_frequency=health_check_frequency
+ )
+ self._timeout = timeout
+ """ Maximum time (in seconds) to allow a single recv request to wait"""
+ self._batch_timeout = batch_timeout
+ """Maximum time (in seconds) to allow a batch of receives to
+ continue to build"""
+ self._consumer: t.Optional[EventConsumer] = None
+ """The event consumer that handles receiving events"""
+ self._backbone = backbone
+ """A standalone, system-created feature store used to share internal
+ information among MLI components"""
+
+ def _on_start(self) -> None:
+ """Called on initial entry into Service `execute` event loop before
+ `_on_iteration` is invoked."""
+ super()._on_start()
+ self._create_eventing()
+
+ def _on_shutdown(self) -> None:
+ """Release dragon resources. Called immediately after exiting
+ the main event loop during automatic shutdown."""
+ super()._on_shutdown()
+
+ if not self._consumer:
+ return
+
+ # remove descriptor for this listener from the backbone if it's there
+ if registered_consumer := self._backbone.backend_channel:
+ # if there is a descriptor in the backbone and it's still this listener
+ if registered_consumer == self._consumer.descriptor:
+ logger.info(
+ f"Listener clearing backend consumer {self._consumer.name} "
+ "from backbone"
+ )
+
+ # unregister this listener in the backbone
+ self._backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+ # TODO: need the channel to be cleaned up
+ # self._consumer._comm_channel._channel.destroy()
+
+ def _on_iteration(self) -> None:
+ """Executes calls to the machine learning worker implementation to complete
+ the inference pipeline."""
+
+ if self._consumer is None:
+ logger.info("Unable to listen. No consumer available.")
+ return
+
+ self._consumer.listen_once(self._timeout, self._batch_timeout)
+
+ def _can_shutdown(self) -> bool:
+ """Determines if the event consumer is ready to stop listening.
+
+ :returns: True when criteria to shutdown the service are met, False otherwise
+ """
+
+ if self._backbone is None:
+ logger.info("Listener must shutdown. No backbone attached")
+ return True
+
+ if self._consumer is None:
+ logger.info("Listener must shutdown. No consumer channel created")
+ return True
+
+ if not self._consumer.listening:
+ logger.info(
+ f"Listener can shutdown. Consumer `{self._consumer.name}` "
+ "is not listening"
+ )
+ return True
+
+ return False
+
+ def _on_unregister(self, event: OnRemoveConsumer) -> None:
+ """Event handler for updating the backbone when event consumers
+ are un-registered.
+
+ :param event: The event that was received
+ """
+ notify_list = set(self._backbone.notification_channels)
+
+ # remove the descriptor specified in the event
+ if event.descriptor in notify_list:
+ logger.debug(f"Removing notify consumer: {event.descriptor}")
+ notify_list.remove(event.descriptor)
+
+ # push the updated list back into the backbone
+ self._backbone.notification_channels = list(notify_list)
+
+ def _on_register(self, event: OnCreateConsumer) -> None:
+ """Event handler for updating the backbone when new event consumers
+ are registered.
+
+ :param event: The event that was received
+ """
+ notify_list = set(self._backbone.notification_channels)
+ logger.debug(f"Adding notify consumer: {event.descriptor}")
+ notify_list.add(event.descriptor)
+ self._backbone.notification_channels = list(notify_list)
+
+ def _on_event_received(self, event: EventBase) -> None:
+ """Primary event handler for the listener. Distributes events to
+ type-specific handlers.
+
+ :param event: The event that was received
+ """
+ if self._backbone is None:
+ logger.info("Unable to handle event. Backbone is missing.")
+
+ if isinstance(event, OnCreateConsumer):
+ self._on_register(event)
+ elif isinstance(event, OnRemoveConsumer):
+ self._on_unregister(event)
+ else:
+ logger.info(
+ "Consumer registration listener received an "
+ f"unexpected event: {event=}"
+ )
+
+ def _on_health_check(self) -> None:
+ """Check if this consumer has been replaced by a new listener
+ and automatically trigger a shutdown. Invoked based on the
+ value of `self._health_check_frequency`."""
+ super()._on_health_check()
+
+ try:
+ logger.debug("Retrieving registered listener descriptor")
+ descriptor = self._backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+ except KeyError:
+ descriptor = None
+ if self._consumer:
+ self._consumer.listening = False
+
+ if self._consumer and descriptor != self._consumer.descriptor:
+ logger.warning(
+ f"Consumer `{self._consumer.name}` for `ConsumerRegistrationListener` "
+ "is no longer registered. It will automatically shut down."
+ )
+ self._consumer.listening = False
+
+ def _publish_consumer(self) -> None:
+ """Publish the registrar consumer descriptor to the backbone."""
+ if self._consumer is None:
+ logger.warning("No registrar consumer descriptor available to publisher")
+ return
+
+ logger.debug(f"Publishing {self._consumer.descriptor} to backbone")
+ self._backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = (
+ self._consumer.descriptor
+ )
+
+ def _create_eventing(self) -> EventConsumer:
+ """
+ Create an event consumer for receiving registration events from
+ other MLI resources.
+
+ NOTE: the backbone must be initialized before connecting eventing clients.
+
+ :returns: The newly created EventConsumer instance
+ :raises SmartSimError: If a listener channel cannot be created
+ """
+
+ if self._consumer:
+ return self._consumer
+
+ logger.info("Creating event consumer")
+
+ dragon_channel = create_local(500)
+ event_channel = DragonCommChannel(dragon_channel)
+
+ if not event_channel.descriptor:
+ raise SmartSimError(
+ "Unable to generate the descriptor for the event channel"
+ )
+
+ self._consumer = EventConsumer(
+ event_channel,
+ self._backbone,
+ [
+ OnCreateConsumer.CONSUMER_CREATED,
+ OnRemoveConsumer.CONSUMER_REMOVED,
+ OnShutdownRequested.SHUTDOWN,
+ ],
+ name=f"ConsumerRegistrar.{socket.gethostname()}",
+ event_handler=self._on_event_received,
+ )
+ self._publish_consumer()
+
+ logger.info(
+ f"Backend consumer `{self._consumer.name}` created: "
+ f"{self._consumer.descriptor}"
+ )
+
+ return self._consumer
+
+
+def _create_parser() -> argparse.ArgumentParser:
+ """
+ Create an argument parser that contains the arguments
+ required to start the listener as a new process:
+
+ --timeout
+ --batch_timeout
+
+ :returns: A configured parser
+ """
+ arg_parser = argparse.ArgumentParser(prog="ConsumerRegistrarEventListener")
+
+ arg_parser.add_argument("--timeout", type=float, default=1.0)
+ arg_parser.add_argument("--batch_timeout", type=float, default=1.0)
+
+ return arg_parser
+
+
+def _connect_backbone() -> t.Optional[BackboneFeatureStore]:
+ """
+ Load the backbone by retrieving the descriptor from environment variables.
+
+ :returns: The backbone feature store, or None if no descriptor is set
+ in the environment
+ """
+ descriptor = os.environ.get(BackboneFeatureStore.MLI_BACKBONE, "")
+
+ if not descriptor:
+ return None
+
+ logger.info(f"Listener backbone descriptor: {descriptor}\n")
+
+ # `from_writable_descriptor` ensures we can update the backbone
+ return BackboneFeatureStore.from_writable_descriptor(descriptor)
+
+
+if __name__ == "__main__":
+ mp.set_start_method("dragon")
+
+ parser = _create_parser()
+ args = parser.parse_args()
+
+ backbone_fs = _connect_backbone()
+
+ if backbone_fs is None:
+ logger.error(
+ "Unable to attach to the backbone without the "
+ f"`{BackboneFeatureStore.MLI_BACKBONE}` environment variable."
+ )
+ sys.exit(1)
+
+ logger.debug(f"Listener attached to backbone: {backbone_fs.descriptor}")
+
+ listener = ConsumerRegistrationListener(
+ backbone_fs,
+ float(args.timeout),
+ float(args.batch_timeout),
+ as_service=True,
+ )
+
+ logger.info(f"listener created? {listener}")
+
+ try:
+ listener.execute()
+ sys.exit(0)
+ except Exception:
+ logger.exception("An error occurred in the event listener")
+ sys.exit(1)
diff --git a/smartsim/_core/mli/infrastructure/control/request_dispatcher.py b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py
new file mode 100644
index 0000000000..e22a2c8f62
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py
@@ -0,0 +1,559 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# pylint: disable=import-error
+# pylint: disable-next=unused-import
+import dragon
+import dragon.globalservices.pool as dragon_gs_pool
+from dragon.managed_memory import MemoryPool
+from dragon.mpbridge.queues import DragonQueue
+
+# pylint: enable=import-error
+
+# isort: off
+# isort: on
+
+import multiprocessing as mp
+import time
+import typing as t
+import uuid
+from queue import Empty, Full, Queue
+
+from smartsim._core.entrypoints.service import Service
+
+from .....error import SmartSimError
+from .....log import get_logger
+from ....utils.timings import PerfTimer
+from ..environment_loader import EnvironmentConfigLoader
+from ..storage.feature_store import FeatureStore
+from ..worker.worker import (
+ InferenceRequest,
+ MachineLearningWorkerBase,
+ ModelIdentifier,
+ RequestBatch,
+)
+from .error_handling import exception_handler
+
+if t.TYPE_CHECKING:
+ from smartsim._core.mli.mli_schemas.response.response_capnp import Status
+
+logger = get_logger("Request Dispatcher")
+
+
+class BatchQueue(Queue[InferenceRequest]):
+ def __init__(
+ self, batch_timeout: float, batch_size: int, model_id: ModelIdentifier
+ ) -> None:
+ """Queue used to store inference requests waiting to be batched and
+ sent to Worker Managers.
+
+ :param batch_timeout: Maximum time (in seconds) to wait before flushing a
+ non-full queue, measured from the moment the first item is put on the queue
+ :param batch_size: Total capacity of the queue
+ :param model_id: Key of the model which needs to be executed on the queued
+ requests
+ """
+ super().__init__(maxsize=batch_size)
+ self._batch_timeout = batch_timeout
+ """Time in seconds that has to be waited before flushing a non-full queue.
+ The time of the first item put is 0 seconds."""
+ self._batch_size = batch_size
+ """Total capacity of the queue"""
+ self._first_put: t.Optional[float] = None
+ """Time at which the first item was put on the queue"""
+ self._disposable = False
+ """Whether the queue will not be used again and can be deleted.
+ A disposable queue is always full."""
+ self._model_id: ModelIdentifier = model_id
+ """Key of the model which needs to be executed on the queued requests"""
+ self._uid = str(uuid.uuid4())
+ """Unique ID of queue"""
+
+ @property
+ def uid(self) -> str:
+ """ID of this queue.
+
+ :returns: Queue ID
+ """
+ return self._uid
+
+ @property
+ def model_id(self) -> ModelIdentifier:
+ """Key of the model which needs to be run on the queued requests.
+
+ :returns: Model key
+ """
+ return self._model_id
+
+ def put(
+ self,
+ item: InferenceRequest,
+ block: bool = False,
+ timeout: t.Optional[float] = 0.0,
+ ) -> None:
+ """Put an inference request in the queue.
+
+ :param item: The request
+ :param block: Whether to block when trying to put the item
+ :param timeout: Time (in seconds) to wait if block==True
+ :raises Full: If an item cannot be put on the queue
+ """
+ super().put(item, block=block, timeout=timeout)
+ if self._first_put is None:
+ self._first_put = time.time()
+
+ @property
+ def _elapsed_time(self) -> float:
+ """Time elapsed since the first item was put on this queue.
+
+ :returns: Time elapsed
+ """
+ if self.empty() or self._first_put is None:
+ return 0
+ return time.time() - self._first_put
+
+ @property
+ def ready(self) -> bool:
+ """Check if the queue can be flushed.
+
+ :returns: True if the queue can be flushed, False otherwise
+ """
+ if self.empty():
+ logger.debug("Request dispatcher queue is empty")
+ return False
+
+ timed_out = False
+ if self._batch_timeout >= 0:
+ timed_out = self._elapsed_time >= self._batch_timeout
+
+ if self.full():
+ logger.debug("Request dispatcher ready to deliver full batch")
+ return True
+
+ if timed_out:
+ logger.debug("Request dispatcher delivering partial batch")
+ return True
+
+ return False
+
+ def make_disposable(self) -> None:
+ """Set this queue as disposable, and never use it again after it gets
+ flushed."""
+ self._disposable = True
+
+ @property
+ def can_be_removed(self) -> bool:
+ """Determine whether this queue can be deleted and garbage collected.
+
+ :returns: True if queue can be removed, False otherwise
+ """
+ return self.empty() and self._disposable
+
+ def flush(self) -> list[t.Any]:
+ """Get all requests from queue.
+
+ :returns: Requests waiting to be executed
+ """
+ num_items = self.qsize()
+ self._first_put = None
+ items = []
+ for _ in range(num_items):
+ try:
+ items.append(self.get())
+ except Empty:
+ break
+
+ return items
+
+ def full(self) -> bool:
+ """Check if the queue has reached its maximum capacity.
+
+ :returns: True if the queue has reached its maximum capacity,
+ False otherwise
+ """
+ if self._disposable:
+ return True
+ return self.qsize() >= self._batch_size
+
+ def empty(self) -> bool:
+ """Check if the queue is empty.
+
+ :returns: True if the queue has 0 elements, False otherwise
+ """
+ return self.qsize() == 0
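+
+# Usage sketch for BatchQueue: requests accumulate until the queue is full or
+# the batch timeout elapses, then `flush()` drains them as one batch. The
+# model_id and request objects are assumed to come from the surrounding
+# dispatcher code.
+#
+#   queue = BatchQueue(batch_timeout=0.1, batch_size=8, model_id=model_id)
+#   queue.put(request)
+#   if queue.ready:
+#       pending_requests = queue.flush()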
+
+
+class RequestDispatcher(Service):
+ def __init__(
+ self,
+ batch_timeout: float,
+ batch_size: int,
+ config_loader: EnvironmentConfigLoader,
+ worker_type: t.Type[MachineLearningWorkerBase],
+ mem_pool_size: int = 2 * 1024**3,
+ ) -> None:
+ """The RequestDispatcher intercepts inference requests, stages them in
+ queues and batches them together before making them available to Worker
+ Managers.
+
+ :param batch_timeout: Maximum elapsed time before flushing a complete or
+ incomplete batch
+ :param batch_size: Total capacity of each batch queue
+ :param config_loader: Object to load configuration from environment
+ :param worker_type: Type of worker to instantiate to batch inputs
+ :param mem_pool_size: Size (in bytes) of the memory pool used to share
+ batched input tensors with the worker managers
+ """
+ super().__init__(as_service=True, cooldown=1)
+ self._queues: dict[str, list[BatchQueue]] = {}
+ """Dict of all batch queues available for a given model id"""
+ self._active_queues: dict[str, BatchQueue] = {}
+ """Mapping telling which queue is the recipient of requests for a given model
+ key"""
+ self._batch_timeout = batch_timeout
+ """Time in seconds that has to be waited before flushing a non-full queue"""
+ self._batch_size = batch_size
+ """Total capacity of each batch queue"""
+ incoming_channel = config_loader.get_queue()
+ if incoming_channel is None:
+ raise SmartSimError("No incoming channel for dispatcher")
+ self._incoming_channel = incoming_channel
+ """The channel the dispatcher monitors for new tasks"""
+ self._outgoing_queue: DragonQueue = mp.Queue(maxsize=0)
+ """The queue on which batched inference requests are placed"""
+ self._feature_stores: t.Dict[str, FeatureStore] = {}
+ """A collection of attached feature stores"""
+ self._featurestore_factory = config_loader._featurestore_factory
+ """A factory method to create a desired feature store client type"""
+ self._backbone: t.Optional[FeatureStore] = config_loader.get_backbone()
+ """A standalone, system-created feature store used to share internal
+ information among MLI components"""
+ self._callback_factory = config_loader._callback_factory
+ """The type of communication channel to construct for callbacks"""
+ self._worker = worker_type()
+ """The worker used to batch inputs"""
+ self._mem_pool = MemoryPool.attach(dragon_gs_pool.create(mem_pool_size).sdesc)
+ """Memory pool used to share batched input tensors with the Worker Managers"""
+ self._perf_timer = PerfTimer(prefix="r_", debug=False, timing_on=True)
+ """Performance timer"""
+
+ @property
+ def has_featurestore_factory(self) -> bool:
+ """Check if the RequestDispatcher has a FeatureStore factory.
+
+ :returns: True if there is a FeatureStore factory, False otherwise
+ """
+ return self._featurestore_factory is not None
+
+ def _check_feature_stores(self, request: InferenceRequest) -> bool:
+ """Ensures that all feature stores required by the request are available.
+
+ :param request: The request to validate
+ :returns: False if feature store validation fails for the request, True
+ otherwise
+ """
+ # collect all feature stores required by the request
+ fs_model: t.Set[str] = set()
+ if request.model_key:
+ fs_model = {request.model_key.descriptor}
+ fs_inputs = {key.descriptor for key in request.input_keys}
+ fs_outputs = {key.descriptor for key in request.output_keys}
+
+ # identify which feature stores are requested and unknown
+ fs_desired = fs_model.union(fs_inputs).union(fs_outputs)
+ fs_actual = {item.descriptor for item in self._feature_stores.values()}
+ fs_missing = fs_desired - fs_actual
+
+ if not self.has_featurestore_factory:
+ logger.error("No feature store factory is configured. Unable to dispatch.")
+ return False
+
+ # create the feature stores we need to service request
+ if fs_missing:
+ logger.debug(f"Adding feature store(s): {fs_missing}")
+ for descriptor in fs_missing:
+ feature_store = self._featurestore_factory(descriptor)
+ self._feature_stores[descriptor] = feature_store
+
+ return True
+
+ # pylint: disable-next=no-self-use
+ def _check_model(self, request: InferenceRequest) -> bool:
+ """Ensure that a model is available for the request.
+
+ :param request: The request to validate
+ :returns: False if model validation fails for the request, True otherwise
+ """
+ if request.has_model_key or request.has_raw_model:
+ return True
+
+ logger.error("Unable to continue without model bytes or feature store key")
+ return False
+
+ # pylint: disable-next=no-self-use
+ def _check_inputs(self, request: InferenceRequest) -> bool:
+ """Ensure that inputs are available for the request.
+
+ :param request: The request to validate
+ :returns: False if input validation fails for the request, True otherwise
+ """
+ if request.has_input_keys or request.has_raw_inputs:
+ return True
+
+ logger.error("Unable to continue without input bytes or feature store keys")
+ return False
+
+ # pylint: disable-next=no-self-use
+ def _check_callback(self, request: InferenceRequest) -> bool:
+ """Ensure that a callback channel is available for the request.
+
+ :param request: The request to validate
+ :returns: False if callback validation fails for the request, True otherwise
+ """
+ if request.callback:
+ return True
+
+ logger.error("No callback channel provided in request")
+ return False
+
+ def _validate_request(self, request: InferenceRequest) -> bool:
+ """Ensure the request can be processed.
+
+ :param request: The request to validate
+ :returns: False if the request fails any validation checks, True otherwise
+ """
+ checks = [
+ self._check_feature_stores(request),
+ self._check_model(request),
+ self._check_inputs(request),
+ self._check_callback(request),
+ ]
+
+ return all(checks)
+
+ def _on_iteration(self) -> None:
+ """This method is executed repeatedly until ``Service`` shutdown
+ conditions are satisfied and cooldown is elapsed."""
+ try:
+ self._perf_timer.is_active = True
+ bytes_list: t.List[bytes] = self._incoming_channel.recv()
+ except Exception:
+ self._perf_timer.is_active = False
+ else:
+ if not bytes_list:
+ exception_handler(
+ ValueError("No request data found"),
+ None,
+ None,
+ )
+ return
+
+ logger.debug(f"Dispatcher is processing {len(bytes_list)} messages")
+ request_bytes = bytes_list[0]
+ tensor_bytes_list = bytes_list[1:]
+ self._perf_timer.start_timings()
+
+ request = self._worker.deserialize_message(
+ request_bytes, self._callback_factory
+ )
+ if request.has_input_meta and tensor_bytes_list:
+ request.raw_inputs = tensor_bytes_list
+
+ self._perf_timer.measure_time("deserialize_message")
+
+ if not self._validate_request(request):
+ exception_handler(
+ ValueError("Error validating the request"),
+ request.callback,
+ None,
+ )
+ self._perf_timer.measure_time("validate_request")
+ else:
+ self._perf_timer.measure_time("validate_request")
+ self.dispatch(request)
+ self._perf_timer.measure_time("dispatch")
+ finally:
+ self.flush_requests()
+ self.remove_queues()
+
+ self._perf_timer.end_timings()
+
+ if self._perf_timer.max_length == 801 and self._perf_timer.is_active:
+ self._perf_timer.print_timings(True)
+
+ def remove_queues(self) -> None:
+ """Remove references to queues that can be removed
+ and allow them to be garbage collected."""
+ queue_lists_to_remove = []
+ for key, queues in self._queues.items():
+ queues_to_remove = []
+ for queue in queues:
+ if queue.can_be_removed:
+ queues_to_remove.append(queue)
+
+ for queue_to_remove in queues_to_remove:
+ queues.remove(queue_to_remove)
+ if (
+ key in self._active_queues
+ and self._active_queues[key] == queue_to_remove
+ ):
+ del self._active_queues[key]
+
+ if len(queues) == 0:
+ queue_lists_to_remove.append(key)
+
+ for key in queue_lists_to_remove:
+ del self._queues[key]
+
+ @property
+ def task_queue(self) -> DragonQueue:
+ """The queue on which batched requests are placed.
+
+ :returns: The queue
+ """
+ return self._outgoing_queue
+
+ def _swap_queue(self, model_id: ModelIdentifier) -> None:
+ """Get an empty queue or create a new one
+ and make it the active one for a given model.
+
+ :param model_id: The id of the model for which the
+ queue has to be swapped
+ """
+ if model_id.key in self._queues:
+ for queue in self._queues[model_id.key]:
+ if not queue.full():
+ self._active_queues[model_id.key] = queue
+ return
+
+ new_queue = BatchQueue(self._batch_timeout, self._batch_size, model_id)
+ if model_id.key in self._queues:
+ self._queues[model_id.key].append(new_queue)
+ else:
+ self._queues[model_id.key] = [new_queue]
+ self._active_queues[model_id.key] = new_queue
+ return
+
+ def dispatch(self, request: InferenceRequest) -> None:
+ """Assign a request to a batch queue.
+
+ :param request: The request to place
+ """
+ if request.has_raw_model:
+ logger.debug("Direct inference requested, creating tmp queue")
+ tmp_id = f"_tmp_{str(uuid.uuid4())}"
+ tmp_queue: BatchQueue = BatchQueue(
+ batch_timeout=0,
+ batch_size=1,
+ model_id=ModelIdentifier(key=tmp_id, descriptor="TMP"),
+ )
+ self._active_queues[tmp_id] = tmp_queue
+ self._queues[tmp_id] = [tmp_queue]
+ tmp_queue.put(request)
+ tmp_queue.make_disposable()
+ return
+
+ if request.model_key:
+ success = False
+ while not success:
+ try:
+ self._active_queues[request.model_key.key].put_nowait(request)
+ success = True
+ except (Full, KeyError):
+ self._swap_queue(request.model_key)
+
+ def flush_requests(self) -> None:
+ """Get all requests from queues which are ready to be flushed. Place all
+ available request batches in the outgoing queue."""
+ for queue_list in self._queues.values():
+ for queue in queue_list:
+ if queue.ready:
+ self._perf_timer.measure_time("find_queue")
+ try:
+ batch = RequestBatch(
+ requests=queue.flush(),
+ inputs=None,
+ model_id=queue.model_id,
+ )
+ finally:
+ self._perf_timer.measure_time("flush_requests")
+ try:
+ fetch_results = self._worker.fetch_inputs(
+ batch=batch, feature_stores=self._feature_stores
+ )
+ except Exception as exc:
+ exception_handler(
+ exc,
+ None,
+ "Error fetching input.",
+ )
+ continue
+ self._perf_timer.measure_time("fetch_input")
+ try:
+ transformed_inputs = self._worker.transform_input(
+ batch=batch,
+ fetch_results=fetch_results,
+ mem_pool=self._mem_pool,
+ )
+ except Exception as exc:
+ exception_handler(
+ exc,
+ None,
+ "Error transforming input.",
+ )
+ continue
+
+ self._perf_timer.measure_time("transform_input")
+ batch.inputs = transformed_inputs
+ for request in batch.requests:
+ request.raw_inputs = []
+ request.input_meta = []
+
+ try:
+ self._outgoing_queue.put(batch)
+ except Exception as exc:
+ exception_handler(
+ exc,
+ None,
+ "Error placing batch on task queue.",
+ )
+ continue
+ self._perf_timer.measure_time("put")
+
+ def _can_shutdown(self) -> bool:
+ """Determine whether the Service can be shut down.
+
+ :returns: False
+ """
+ return False
+
+ def __del__(self) -> None:
+ """Destroy allocated memory resources."""
+ # pool may be null if a failure occurs prior to successful attach
+ pool: t.Optional[MemoryPool] = getattr(self, "_mem_pool", None)
+
+ if pool:
+ pool.destroy()
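+
+
+# Usage sketch: the dispatcher is typically constructed from an environment
+# config loader and run as a service; `config_loader` and `MyWorker` are
+# placeholders for the concrete loader and worker implementation.
+#
+#   dispatcher = RequestDispatcher(
+#       batch_timeout=0.01,
+#       batch_size=4,
+#       config_loader=config_loader,
+#       worker_type=MyWorker,
+#   )
+#   dispatcher.execute()  # batches are placed on dispatcher.task_queue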
diff --git a/smartsim/_core/mli/infrastructure/control/worker_manager.py b/smartsim/_core/mli/infrastructure/control/worker_manager.py
new file mode 100644
index 0000000000..bf6fddb81d
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/control/worker_manager.py
@@ -0,0 +1,330 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# pylint: disable=import-error
+# pylint: disable-next=unused-import
+import dragon
+
+# pylint: enable=import-error
+
+# isort: off
+# isort: on
+
+import multiprocessing as mp
+import time
+import typing as t
+from queue import Empty
+
+from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore
+
+from .....log import get_logger
+from ....entrypoints.service import Service
+from ....utils.timings import PerfTimer
+from ...message_handler import MessageHandler
+from ..environment_loader import EnvironmentConfigLoader
+from ..worker.worker import (
+ InferenceReply,
+ LoadModelResult,
+ MachineLearningWorkerBase,
+ RequestBatch,
+)
+from .device_manager import DeviceManager, WorkerDevice
+from .error_handling import build_failure_reply, exception_handler
+
+if t.TYPE_CHECKING:
+ from smartsim._core.mli.mli_schemas.response.response_capnp import Status
+
+logger = get_logger(__name__)
+
+
+class WorkerManager(Service):
+ """An implementation of a service managing distribution of tasks to
+ machine learning workers."""
+
+ def __init__(
+ self,
+ config_loader: EnvironmentConfigLoader,
+ worker_type: t.Type[MachineLearningWorkerBase],
+ dispatcher_queue: "mp.Queue[RequestBatch]",
+ as_service: bool = False,
+ cooldown: int = 0,
+ device: t.Literal["cpu", "gpu"] = "cpu",
+ ) -> None:
+ """Initialize the WorkerManager.
+
+ :param config_loader: Environment config loader for loading queues
+ and feature stores
+ :param worker_type: The type of worker to manage
+ :param dispatcher_queue: Queue from which the batched requests are pulled
+ :param as_service: Specifies run-once or run-until-complete behavior of service
+ :param cooldown: Number of seconds to wait before shutting down after
+ shutdown criteria are met
+ :param device: The device on which the worker should run. Every worker manager
+ is assigned exactly one GPU (if available), so the device should not include an index.
+ """
+ super().__init__(as_service, cooldown)
+
+ self._dispatcher_queue = dispatcher_queue
+ """The Dispatcher queue that the WorkerManager monitors for new batches"""
+ self._worker = worker_type()
+ """The ML Worker implementation"""
+ self._callback_factory = config_loader._callback_factory
+ """The type of communication channel to construct for callbacks"""
+ self._device = device
+ """Device on which workers need to run"""
+ self._cached_models: dict[str, t.Any] = {}
+ """Dictionary of previously loaded models"""
+ self._feature_stores: t.Dict[str, FeatureStore] = {}
+ """A collection of attached feature stores"""
+ self._featurestore_factory = config_loader._featurestore_factory
+ """A factory method to create a desired feature store client type"""
+ self._backbone: t.Optional[FeatureStore] = config_loader.get_backbone()
+ """A standalone, system-created feature store used to share internal
+ information among MLI components"""
+ self._device_manager: t.Optional[DeviceManager] = None
+ """Object responsible for model caching and device access"""
+ self._perf_timer = PerfTimer(prefix="w_", debug=False, timing_on=True)
+ """Performance timer"""
+
+ @property
+ def has_featurestore_factory(self) -> bool:
+ """Check if the WorkerManager has a FeatureStore factory.
+
+ :returns: True if there is a FeatureStore factory, False otherwise
+ """
+ return self._featurestore_factory is not None
+
+ def _on_start(self) -> None:
+ """Called on initial entry into Service `execute` event loop before
+ `_on_iteration` is invoked."""
+ self._device_manager = DeviceManager(WorkerDevice(self._device))
+
+ def _check_feature_stores(self, batch: RequestBatch) -> bool:
+ """Ensures that all feature stores required by the request are available.
+
+ :param batch: The batch of requests to validate
+ :returns: False if feature store validation fails for the batch, True otherwise
+ """
+ # collect all feature stores required by the request
+ fs_model: t.Set[str] = set()
+ if batch.model_id.key:
+ fs_model = {batch.model_id.descriptor}
+ fs_inputs = {key.descriptor for key in batch.input_keys}
+ fs_outputs = {key.descriptor for key in batch.output_keys}
+
+ # identify which feature stores are requested and unknown
+ fs_desired = fs_model.union(fs_inputs).union(fs_outputs)
+ fs_actual = {item.descriptor for item in self._feature_stores.values()}
+ fs_missing = fs_desired - fs_actual
+
+ if not self.has_featurestore_factory:
+ logger.error("No feature store factory configured")
+ return False
+
+ # create the feature stores we need to service request
+ if fs_missing:
+ logger.debug(f"Adding feature store(s): {fs_missing}")
+ for descriptor in fs_missing:
+ feature_store = self._featurestore_factory(descriptor)
+ self._feature_stores[descriptor] = feature_store
+
+ return True
+
+ def _validate_batch(self, batch: RequestBatch) -> bool:
+ """Ensure the request can be processed.
+
+ :param batch: The batch of requests to validate
+ :returns: False if the request fails any validation checks, True otherwise
+ """
+ if batch is None or not batch.has_valid_requests:
+ return False
+
+ return self._check_feature_stores(batch)
+
+ # remove this when we are done with time measurements
+ # pylint: disable-next=too-many-statements
+ def _on_iteration(self) -> None:
+ """Executes calls to the machine learning worker implementation to complete
+ the inference pipeline."""
+ pre_batch_time = time.perf_counter()
+ try:
+ batch: RequestBatch = self._dispatcher_queue.get(timeout=0.0001)
+ except Empty:
+ return
+
+ self._perf_timer.start_timings(
+ "flush_requests", time.perf_counter() - pre_batch_time
+ )
+
+ if not self._validate_batch(batch):
+ exception_handler(
+ ValueError("An invalid batch was received"),
+ None,
+ None,
+ )
+ return
+
+ if not self._device_manager:
+ for request in batch.requests:
+ msg = "No Device Manager found. WorkerManager._on_start() "
+ "must be called after initialization. If possible, "
+ "you should use `WorkerManager.execute()` instead of "
+ "directly calling `_on_iteration()`."
+ try:
+ self._dispatcher_queue.put(batch)
+ except Exception:
+ msg += "\nThe batch could not be put back in the queue "
+ "and will not be processed."
+ exception_handler(
+ RuntimeError(msg),
+ request.callback,
+ "Error acquiring device manager",
+ )
+ return
+
+ try:
+ device_cm = self._device_manager.get_device(
+ worker=self._worker,
+ batch=batch,
+ feature_stores=self._feature_stores,
+ )
+ except Exception as exc:
+ for request in batch.requests:
+ exception_handler(
+ exc,
+ request.callback,
+ "Error loading model on device or getting device.",
+ )
+ return
+ self._perf_timer.measure_time("fetch_model")
+
+ with device_cm as device:
+
+ try:
+ model_result = LoadModelResult(device.get_model(batch.model_id.key))
+ except Exception as exc:
+ for request in batch.requests:
+ exception_handler(
+ exc, request.callback, "Error getting model from device."
+ )
+ return
+ self._perf_timer.measure_time("load_model")
+
+ if not batch.inputs:
+ for request in batch.requests:
+ exception_handler(
+ ValueError("Error batching inputs"),
+ request.callback,
+ None,
+ )
+ return
+ transformed_input = batch.inputs
+
+ try:
+ execute_result = self._worker.execute(
+ batch, model_result, transformed_input, device.name
+ )
+ except Exception as e:
+ for request in batch.requests:
+ exception_handler(e, request.callback, "Error while executing.")
+ return
+ self._perf_timer.measure_time("execute")
+
+ try:
+ transformed_outputs = self._worker.transform_output(
+ batch, execute_result
+ )
+ except Exception as e:
+ for request in batch.requests:
+ exception_handler(
+ e, request.callback, "Error while transforming the output."
+ )
+ return
+
+ for request, transformed_output in zip(batch.requests, transformed_outputs):
+ reply = InferenceReply()
+ if request.has_output_keys:
+ try:
+ reply.output_keys = self._worker.place_output(
+ request,
+ transformed_output,
+ self._feature_stores,
+ )
+ except Exception as e:
+ exception_handler(
+ e, request.callback, "Error while placing the output."
+ )
+ continue
+ else:
+ reply.outputs = transformed_output.outputs
+ self._perf_timer.measure_time("assign_output")
+
+ if not reply.has_outputs:
+ response = build_failure_reply("fail", "Outputs not found.")
+ else:
+ reply.status_enum = "complete"
+ reply.message = "Success"
+
+ results = self._worker.prepare_outputs(reply)
+ response = MessageHandler.build_response(
+ status=reply.status_enum,
+ message=reply.message,
+ result=results,
+ custom_attributes=None,
+ )
+
+ self._perf_timer.measure_time("build_reply")
+
+ serialized_resp = MessageHandler.serialize_response(response)
+
+ self._perf_timer.measure_time("serialize_resp")
+
+ if request.callback:
+ request.callback.send(serialized_resp)
+ if reply.has_outputs:
+ # send tensor data after response
+ for output in reply.outputs:
+ request.callback.send(output)
+ self._perf_timer.measure_time("send")
+
+ self._perf_timer.end_timings()
+
+ if self._perf_timer.max_length == 801:
+ self._perf_timer.print_timings(True)
+
+ def _can_shutdown(self) -> bool:
+ """Determine if the service can be shutdown.
+
+ :returns: True when criteria to shutdown the service are met, False otherwise
+ """
+ # todo: determine shutdown criteria
+ # will we receive a completion message?
+ # will we let MLI mgr just kill this?
+ # time_diff = self._last_event - datetime.datetime.now()
+ # if time_diff.total_seconds() > self._cooldown:
+ # return True
+ # return False
+ return self._worker is None
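+
+
+# Usage sketch: the WorkerManager consumes batches produced by a
+# RequestDispatcher through a shared queue; `config_loader`, `MyWorker`, and
+# `dispatcher` are placeholders for objects constructed elsewhere.
+#
+#   worker_manager = WorkerManager(
+#       config_loader=config_loader,
+#       worker_type=MyWorker,
+#       dispatcher_queue=dispatcher.task_queue,
+#       as_service=True,
+#       device="gpu",
+#   )
+#   worker_manager.execute()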
diff --git a/smartsim/_core/mli/infrastructure/environment_loader.py b/smartsim/_core/mli/infrastructure/environment_loader.py
new file mode 100644
index 0000000000..5ba0fccc27
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/environment_loader.py
@@ -0,0 +1,116 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import typing as t
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class EnvironmentConfigLoader:
+ """
+ Facilitates the loading of a FeatureStore and Queue into the WorkerManager.
+ """
+
+ REQUEST_QUEUE_ENV_VAR = "_SMARTSIM_REQUEST_QUEUE"
+ """The environment variable that holds the request queue descriptor"""
+ BACKBONE_ENV_VAR = "_SMARTSIM_INFRA_BACKBONE"
+ """The environment variable that holds the backbone descriptor"""
+
+ def __init__(
+ self,
+ featurestore_factory: t.Callable[[str], FeatureStore],
+ callback_factory: t.Callable[[str], CommChannelBase],
+ queue_factory: t.Callable[[str], CommChannelBase],
+ ) -> None:
+ """Initialize the config loader instance with the factories necessary for
+ creating additional objects.
+
+ :param featurestore_factory: A factory method that produces a feature store
+ given a descriptor
+ :param callback_factory: A factory method that produces a callback
+ channel given a descriptor
+ :param queue_factory: A factory method that produces a queue
+ channel given a descriptor
+ """
+ self.queue: t.Optional[CommChannelBase] = None
+ """The attached incoming event queue channel"""
+ self.backbone: t.Optional[FeatureStore] = None
+ """The attached backbone feature store"""
+ self._featurestore_factory = featurestore_factory
+ """A factory method to instantiate a FeatureStore"""
+ self._callback_factory = callback_factory
+ """A factory method to instantiate a concrete CommChannelBase
+ for inference callbacks"""
+ self._queue_factory = queue_factory
+ """A factory method to instantiate a concrete CommChannelBase
+ for inference requests"""
+
+ def get_backbone(self) -> t.Optional[FeatureStore]:
+ """Attach to the backbone feature store using the descriptor found in
+ the environment variable `_SMARTSIM_INFRA_BACKBONE`. The backbone is
+ a standalone, system-created feature store used to share internal
+ information among MLI components.
+
+ :returns: The attached feature store via `_SMARTSIM_INFRA_BACKBONE`
+ """
+ descriptor = os.getenv(self.BACKBONE_ENV_VAR, "")
+
+ if not descriptor:
+ logger.warning("No backbone descriptor is configured")
+ return None
+
+ if self._featurestore_factory is None:
+ logger.warning(
+ "No feature store factory is configured. Backbone not created."
+ )
+ return None
+
+ self.backbone = self._featurestore_factory(descriptor)
+ return self.backbone
+
+ def get_queue(self) -> t.Optional[CommChannelBase]:
+ """Attach to a queue-like communication channel using the descriptor
+ found in the environment variable `_SMARTSIM_REQUEST_QUEUE`.
+
+ :returns: The attached queue specified via `_SMARTSIM_REQUEST_QUEUE`
+ """
+ descriptor = os.getenv(self.REQUEST_QUEUE_ENV_VAR, "")
+
+ if not descriptor:
+ logger.warning("No queue descriptor is configured")
+ return None
+
+ if self._queue_factory is None:
+ logger.warning("No queue factory is configured")
+ return None
+
+ self.queue = self._queue_factory(descriptor)
+ return self.queue
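
For orientation, a minimal usage sketch of the loader above. It assumes the launching process has already exported the `_SMARTSIM_INFRA_BACKBONE` and `_SMARTSIM_REQUEST_QUEUE` descriptors, and it stubs out the channel factory because the concrete `CommChannelBase` implementation is deployment-specific (the `channel_factory` name below is a placeholder, not part of this changeset).

```python
from smartsim._core.mli.comm.channel.channel import CommChannelBase
from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader
from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
    DragonFeatureStore,
)


def channel_factory(descriptor: str) -> CommChannelBase:
    """Hypothetical placeholder; swap in the concrete channel attach call."""
    raise NotImplementedError


config_loader = EnvironmentConfigLoader(
    featurestore_factory=DragonFeatureStore.from_descriptor,
    callback_factory=channel_factory,
    queue_factory=channel_factory,
)

backbone = config_loader.get_backbone()  # None when _SMARTSIM_INFRA_BACKBONE is unset
queue = config_loader.get_queue()        # None when _SMARTSIM_REQUEST_QUEUE is unset
```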
diff --git a/smartsim/_core/mli/infrastructure/storage/__init__.py b/smartsim/_core/mli/infrastructure/storage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py b/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py
new file mode 100644
index 0000000000..b12d7b11b4
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py
@@ -0,0 +1,259 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import itertools
+import os
+import time
+import typing as t
+
+# pylint: disable=import-error
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+
+# isort: on
+
+from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
+ DragonFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class BackboneFeatureStore(DragonFeatureStore):
+ """A DragonFeatureStore wrapper with utility methods for accessing shared
+ information stored in the MLI backbone feature store."""
+
+ MLI_NOTIFY_CONSUMERS = "_SMARTSIM_MLI_NOTIFY_CONSUMERS"
+ """Unique key used in the backbone to locate the consumer list"""
+ MLI_REGISTRAR_CONSUMER = "_SMARTIM_MLI_REGISTRAR_CONSUMER"
+ """Unique key used in the backbone to locate the registration consumer"""
+ MLI_WORKER_QUEUE = "_SMARTSIM_REQUEST_QUEUE"
+ """Unique key used in the backbone to locate MLI work queue"""
+ MLI_BACKBONE = "_SMARTSIM_INFRA_BACKBONE"
+ """Unique key used in the backbone to locate the backbone feature store"""
+ _CREATED_ON = "creation"
+ """Unique key used in the backbone to locate the creation date of the
+ feature store"""
+ _DEFAULT_WAIT_TIMEOUT = 1.0
+ """The default wait time (in seconds) for blocking requests to
+ the feature store"""
+
+ def __init__(
+ self,
+ storage: dragon_ddict.DDict,
+ allow_reserved_writes: bool = False,
+ ) -> None:
+ """Initialize the DragonFeatureStore instance.
+
+ :param storage: A distributed dictionary to be used as the underlying
+ storage mechanism of the feature store
+ :param allow_reserved_writes: Whether reserved writes are allowed
+ """
+        super().__init__(storage)
+        self._enable_reserved_writes = allow_reserved_writes
+        self._wait_timeout = self._DEFAULT_WAIT_TIMEOUT
+        """The timeout (in seconds) exposed via the `wait_timeout` property"""
+
+        self._record_creation_data()
+
+ @property
+ def wait_timeout(self) -> float:
+ """Retrieve the wait timeout for this feature store. The wait timeout is
+ applied to all calls to `wait_for`.
+
+ :returns: The wait timeout (in seconds).
+ """
+ return self._wait_timeout
+
+ @wait_timeout.setter
+ def wait_timeout(self, value: float) -> None:
+ """Set the wait timeout (in seconds) for this feature store. The wait
+ timeout is applied to all calls to `wait_for`.
+
+ :param value: The new value to set
+ """
+ self._wait_timeout = value
+
+ @property
+ def notification_channels(self) -> t.Sequence[str]:
+ """Retrieve descriptors for all registered MLI notification channels.
+
+ :returns: The list of channel descriptors
+ """
+ if self.MLI_NOTIFY_CONSUMERS in self:
+ stored_consumers = self[self.MLI_NOTIFY_CONSUMERS]
+ return str(stored_consumers).split(",")
+ return []
+
+ @notification_channels.setter
+ def notification_channels(self, values: t.Sequence[str]) -> None:
+ """Set the notification channels to be sent events.
+
+ :param values: The list of channel descriptors to save
+ """
+ self[self.MLI_NOTIFY_CONSUMERS] = ",".join(
+ [str(value) for value in values if value]
+ )
+
+ @property
+ def backend_channel(self) -> t.Optional[str]:
+ """Retrieve the channel descriptor used to register event consumers.
+
+ :returns: The channel descriptor"""
+ if self.MLI_REGISTRAR_CONSUMER in self:
+ return str(self[self.MLI_REGISTRAR_CONSUMER])
+ return None
+
+ @backend_channel.setter
+ def backend_channel(self, value: str) -> None:
+ """Set the channel used to register event consumers.
+
+ :param value: The stringified channel descriptor"""
+ self[self.MLI_REGISTRAR_CONSUMER] = value
+
+ @property
+ def worker_queue(self) -> t.Optional[str]:
+ """Retrieve the channel descriptor used to send work to MLI worker managers.
+
+ :returns: The channel descriptor, if found. Otherwise, `None`"""
+ if self.MLI_WORKER_QUEUE in self:
+ return str(self[self.MLI_WORKER_QUEUE])
+ return None
+
+ @worker_queue.setter
+ def worker_queue(self, value: str) -> None:
+ """Set the channel descriptor used to send work to MLI worker managers.
+
+ :param value: The channel descriptor"""
+ self[self.MLI_WORKER_QUEUE] = value
+
+ @property
+ def creation_date(self) -> str:
+ """Return the creation date for the backbone feature store.
+
+        :returns: The string-formatted date when the feature store was created"""
+ return str(self[self._CREATED_ON])
+
+ def _record_creation_data(self) -> None:
+ """Write the creation timestamp to the feature store."""
+ if self._CREATED_ON not in self:
+ if not self._allow_reserved_writes:
+ logger.warning(
+ "Recorded creation from a write-protected backbone instance"
+ )
+ self[self._CREATED_ON] = str(time.time())
+
+ os.environ[self.MLI_BACKBONE] = self.descriptor
+
+ @classmethod
+ def from_writable_descriptor(
+ cls,
+ descriptor: str,
+ ) -> "BackboneFeatureStore":
+ """A factory method that creates an instance from a descriptor string.
+
+ :param descriptor: The descriptor that uniquely identifies the resource
+ :returns: An attached DragonFeatureStore
+ :raises SmartSimError: if attachment to DragonFeatureStore fails
+ """
+ try:
+ return BackboneFeatureStore(dragon_ddict.DDict.attach(descriptor), True)
+ except Exception as ex:
+ raise SmartSimError(
+ f"Error creating backbone feature store: {descriptor}"
+ ) from ex
+
+ def _check_wait_timeout(
+ self, start_time: float, timeout: float, indicators: t.Dict[str, bool]
+ ) -> None:
+ """Perform timeout verification.
+
+ :param start_time: the start time to use for elapsed calculation
+ :param timeout: the timeout (in seconds)
+ :param indicators: latest retrieval status for requested keys
+ :raises SmartSimError: If the timeout elapses before all values are
+ retrieved
+ """
+ elapsed = time.time() - start_time
+ if timeout and elapsed > timeout:
+ raise SmartSimError(
+ f"Backbone {self.descriptor=} timeout after {elapsed} "
+ f"seconds retrieving keys: {indicators}"
+ )
+
+ def wait_for(
+ self, keys: t.List[str], timeout: float = _DEFAULT_WAIT_TIMEOUT
+ ) -> t.Dict[str, t.Union[str, bytes, None]]:
+ """Perform a blocking wait until all specified keys have been found
+ in the backbone.
+
+ :param keys: The required collection of keys to retrieve
+ :param timeout: The maximum wait time in seconds
+ :returns: Dictionary containing the keys and values requested
+ :raises SmartSimError: If the timeout elapses without retrieving
+ all requested keys
+ """
+ if timeout < 0:
+ timeout = self._DEFAULT_WAIT_TIMEOUT
+ logger.info(f"Using default wait_for timeout: {timeout}s")
+
+ if not keys:
+ return {}
+
+ values: t.Dict[str, t.Union[str, bytes, None]] = {k: None for k in set(keys)}
+ is_found = {k: False for k in values.keys()}
+
+ backoff = (0.1, 0.2, 0.4, 0.8)
+ backoff_iter = itertools.cycle(backoff)
+ start_time = time.time()
+
+ while not all(is_found.values()):
+ delay = next(backoff_iter)
+
+ for key in [k for k, v in is_found.items() if not v]:
+ try:
+ values[key] = self[key]
+ is_found[key] = True
+ except Exception:
+ if delay == backoff[-1]:
+ logger.debug(f"Re-attempting `{key}` retrieval in {delay}s")
+
+ if all(is_found.values()):
+ logger.debug(f"wait_for({keys}) retrieved all keys")
+ continue
+
+ self._check_wait_timeout(start_time, timeout, is_found)
+ time.sleep(delay)
+
+ return values
+
+ def get_env(self) -> t.Dict[str, str]:
+ """Returns a dictionary populated with environment variables necessary to
+ connect a process to the existing backbone instance.
+
+ :returns: The dictionary populated with env vars
+ """
+ return {self.MLI_BACKBONE: self.descriptor}
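
As a hedged sketch of how a consumer might use the helpers above: attach to the backbone from its environment descriptor, block on the worker-queue key with `wait_for`, and forward connection information to a child process with `get_env()`. The environment variable is assumed to have been populated by the process that created the backbone.

```python
import os

from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
    BackboneFeatureStore,
)

# descriptor assumed to be exported by the process that created the backbone
descriptor = os.environ[BackboneFeatureStore.MLI_BACKBONE]
backbone = BackboneFeatureStore.from_writable_descriptor(descriptor)

# block (up to 10 seconds) until the worker queue descriptor is published
values = backbone.wait_for([BackboneFeatureStore.MLI_WORKER_QUEUE], timeout=10.0)
queue_descriptor = values[BackboneFeatureStore.MLI_WORKER_QUEUE]

# environment variables that let a child process attach to the same backbone
child_env = {**os.environ, **backbone.get_env()}
```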
diff --git a/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py b/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py
new file mode 100644
index 0000000000..24f2221c87
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py
@@ -0,0 +1,126 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+# pylint: disable=import-error
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+
+# isort: on
+
+from smartsim._core.mli.infrastructure.storage.dragon_util import (
+ ddict_to_descriptor,
+ descriptor_to_ddict,
+)
+from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore
+from smartsim.error import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class DragonFeatureStore(FeatureStore):
+ """A feature store backed by a dragon distributed dictionary."""
+
+ def __init__(self, storage: "dragon_ddict.DDict") -> None:
+ """Initialize the DragonFeatureStore instance.
+
+ :param storage: A distributed dictionary to be used as the underlying
+ storage mechanism of the feature store"""
+ if storage is None:
+ raise ValueError(
+ "Storage is required when instantiating a DragonFeatureStore."
+ )
+
+ descriptor = ""
+ if isinstance(storage, dragon_ddict.DDict):
+ descriptor = ddict_to_descriptor(storage)
+
+ super().__init__(descriptor)
+ self._storage: t.Dict[str, t.Union[str, bytes]] = storage
+ """The underlying storage mechanism of the DragonFeatureStore; a
+ distributed, in-memory key-value store"""
+
+ def _get(self, key: str) -> t.Union[str, bytes]:
+ """Retrieve a value from the underlying storage mechanism.
+
+ :param key: The unique key that identifies the resource
+ :returns: The value identified by the key
+ :raises KeyError: If the key has not been used to store a value
+ """
+ try:
+ return self._storage[key]
+ except dragon_ddict.DDictError as e:
+ raise KeyError(f"Key not found in FeatureStore: {key}") from e
+
+ def _set(self, key: str, value: t.Union[str, bytes]) -> None:
+ """Store a value into the underlying storage mechanism.
+
+ :param key: The unique key that identifies the resource
+ :param value: The value to store
+ :returns: The value identified by the key
+ """
+ self._storage[key] = value
+
+ def _contains(self, key: str) -> bool:
+ """Determine if the storage mechanism contains a given key.
+
+ :param key: The unique key that identifies the resource
+ :returns: True if the key is defined, False otherwise
+ """
+ return key in self._storage
+
+ def pop(self, key: str) -> t.Union[str, bytes, None]:
+ """Remove the value from the dictionary and return the value.
+
+ :param key: Dictionary key to retrieve
+        :returns: The value held at the key if it exists, otherwise `None`
+        """
+ try:
+ return self._storage.pop(key)
+ except dragon_ddict.DDictError:
+ return None
+
+ @classmethod
+ def from_descriptor(
+ cls,
+ descriptor: str,
+ ) -> "DragonFeatureStore":
+ """A factory method that creates an instance from a descriptor string.
+
+ :param descriptor: The descriptor that uniquely identifies the resource
+ :returns: An attached DragonFeatureStore
+ :raises SmartSimError: If attachment to DragonFeatureStore fails
+ """
+ try:
+ logger.debug(f"Attaching to FeatureStore with descriptor: {descriptor}")
+ storage = descriptor_to_ddict(descriptor)
+ return cls(storage)
+ except Exception as ex:
+ raise SmartSimError(
+ f"Error creating dragon feature store from descriptor: {descriptor}"
+ ) from ex
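
To make the mapping-style behaviour above concrete, a small sketch (the descriptor is assumed to be supplied by whichever process created the underlying DDict) showing set, membership, lookup, and the `pop` fallback:

```python
import os

from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
    DragonFeatureStore,
)

# descriptor is assumed to have been shared by the DDict's creator
descriptor = os.environ["_SMARTSIM_INFRA_BACKBONE"]
feature_store = DragonFeatureStore.from_descriptor(descriptor)

feature_store["my-tensor"] = b"\x00\x01\x02\x03"          # routed through _set
assert "my-tensor" in feature_store                       # routed through _contains
assert feature_store["my-tensor"] == b"\x00\x01\x02\x03"  # routed through _get

# missing keys raise KeyError on lookup, while pop returns None instead
assert feature_store.pop("not-a-key") is None
```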
diff --git a/smartsim/_core/mli/infrastructure/storage/dragon_util.py b/smartsim/_core/mli/infrastructure/storage/dragon_util.py
new file mode 100644
index 0000000000..50d15664c0
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/storage/dragon_util.py
@@ -0,0 +1,101 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# pylint: disable=import-error
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+
+# isort: on
+
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+def ddict_to_descriptor(ddict: dragon_ddict.DDict) -> str:
+ """Convert a DDict to a descriptor string.
+
+ :param ddict: The dragon dictionary to convert
+ :returns: The descriptor string
+ :raises ValueError: If a ddict is not provided
+ """
+ if ddict is None:
+ raise ValueError("DDict is not available to create a descriptor")
+
+ # unlike other dragon objects, the dictionary serializes to a string
+ # instead of bytes
+ return str(ddict.serialize())
+
+
+def descriptor_to_ddict(descriptor: str) -> dragon_ddict.DDict:
+ """Create and attach a new DDict instance given
+ the string-encoded descriptor.
+
+ :param descriptor: The descriptor of a dictionary to attach to
+ :returns: The attached dragon dictionary"""
+ return dragon_ddict.DDict.attach(descriptor)
+
+
+def create_ddict(
+ num_nodes: int, mgr_per_node: int, mem_per_node: int
+) -> dragon_ddict.DDict:
+ """Create a distributed dragon dictionary.
+
+ :param num_nodes: The number of distributed nodes to distribute the dictionary to.
+ At least one node is required.
+ :param mgr_per_node: The number of manager processes per node
+    :param mem_per_node: The amount of memory (in megabytes) to allocate per node. Total
+        memory available will be calculated as `num_nodes * mem_per_node`
+
+ :returns: The instantiated dragon dictionary
+ :raises ValueError: If invalid num_nodes is supplied
+ :raises ValueError: If invalid mem_per_node is supplied
+ :raises ValueError: If invalid mgr_per_node is supplied
+ """
+ if num_nodes < 1:
+ raise ValueError("A dragon dictionary must have at least 1 node")
+
+ if mgr_per_node < 1:
+ raise ValueError("A dragon dict requires at least 2 managers per ndode")
+
+ if mem_per_node < dragon_ddict.DDICT_MIN_SIZE:
+ raise ValueError(
+ "A dragon dictionary requires at least "
+ f"{dragon_ddict.DDICT_MIN_SIZE / 1024} MB"
+ )
+
+ mem_total = num_nodes * mem_per_node
+
+ logger.debug(
+ f"Creating dragon dictionary with {num_nodes} nodes, {mem_total} MB memory"
+ )
+
+ distributed_dict = dragon_ddict.DDict(num_nodes, mgr_per_node, total_mem=mem_total)
+ logger.debug(
+ "Successfully created dragon dictionary with "
+ f"{num_nodes} nodes, {mem_total} MB total memory"
+ )
+ return distributed_dict
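
A brief usage sketch for the helpers above, assuming an active Dragon runtime; the sizing values are illustrative only and must satisfy `DDICT_MIN_SIZE`:

```python
import dragon.data.ddict.ddict as dragon_ddict

from smartsim._core.mli.infrastructure.storage.dragon_util import (
    create_ddict,
    ddict_to_descriptor,
    descriptor_to_ddict,
)

# single node, single manager; memory sized to the library minimum
ddict = create_ddict(
    num_nodes=1, mgr_per_node=1, mem_per_node=dragon_ddict.DDICT_MIN_SIZE
)

# the descriptor is a plain string, safe to pass through the environment
descriptor = ddict_to_descriptor(ddict)
attached = descriptor_to_ddict(descriptor)
attached["hello"] = "world"
```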
diff --git a/smartsim/_core/mli/infrastructure/storage/feature_store.py b/smartsim/_core/mli/infrastructure/storage/feature_store.py
new file mode 100644
index 0000000000..ebca07ed4e
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/storage/feature_store.py
@@ -0,0 +1,224 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import enum
+import typing as t
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class ReservedKeys(str, enum.Enum):
+ """Contains constants used to identify all featurestore keys that
+    may not be used by user code. Avoids overwriting system data."""
+
+ MLI_NOTIFY_CONSUMERS = "_SMARTSIM_MLI_NOTIFY_CONSUMERS"
+ """Storage location for the list of registered consumers that will receive
+ events from an EventBroadcaster"""
+
+ MLI_REGISTRAR_CONSUMER = "_SMARTIM_MLI_REGISTRAR_CONSUMER"
+ """Storage location for the channel used to send messages directly to
+ the MLI backend"""
+
+ MLI_WORKER_QUEUE = "_SMARTSIM_REQUEST_QUEUE"
+ """Storage location for the channel used to send work requests
+ to the available worker managers"""
+
+ @classmethod
+ def contains(cls, value: str) -> bool:
+ """Convert a string representation into an enumeration member.
+
+ :param value: The string to convert
+ :returns: The enumeration member if the conversion succeeded, otherwise None
+ """
+ try:
+ cls(value)
+ except ValueError:
+ return False
+
+ return True
+
+
+@dataclass(frozen=True)
+class TensorKey:
+ """A key,descriptor pair enabling retrieval of an item from a feature store."""
+
+ key: str
+ """The unique key of an item in a feature store"""
+ descriptor: str
+ """The unique identifier of the feature store containing the key"""
+
+ def __post_init__(self) -> None:
+ """Ensure the key and descriptor have at least one character.
+
+ :raises ValueError: If key or descriptor are empty strings
+ """
+ if len(self.key) < 1:
+ raise ValueError("Key must have at least one character.")
+ if len(self.descriptor) < 1:
+ raise ValueError("Descriptor must have at least one character.")
+
+
+@dataclass(frozen=True)
+class ModelKey:
+ """A key,descriptor pair enabling retrieval of an item from a feature store."""
+
+ key: str
+ """The unique key of an item in a feature store"""
+ descriptor: str
+ """The unique identifier of the feature store containing the key"""
+
+ def __post_init__(self) -> None:
+ """Ensure the key and descriptor have at least one character.
+
+ :raises ValueError: If key or descriptor are empty strings
+ """
+ if len(self.key) < 1:
+ raise ValueError("Key must have at least one character.")
+ if len(self.descriptor) < 1:
+ raise ValueError("Descriptor must have at least one character.")
+
+
+class FeatureStore(ABC):
+ """Abstract base class providing the common interface for retrieving
+ values from a feature store implementation."""
+
+ def __init__(self, descriptor: str, allow_reserved_writes: bool = False) -> None:
+ """Initialize the feature store.
+
+ :param descriptor: The stringified version of a storage descriptor
+ :param allow_reserved_writes: Override the default behavior of blocking
+ writes to reserved keys
+ """
+ self._enable_reserved_writes = allow_reserved_writes
+ """Flag used to ensure that any keys written by the system to a feature store
+ are not overwritten by user code. Disabled by default. Subclasses must set the
+ value intentionally."""
+ self._descriptor = descriptor
+ """Stringified version of the unique ID enabling a client to connect
+ to the feature store"""
+
+ def _check_reserved(self, key: str) -> None:
+ """A utility method used to verify access to write to a reserved key
+        in the FeatureStore. Used by subclasses in __setitem__ implementations.
+
+ :param key: A key to compare to the reserved keys
+ :raises SmartSimError: If the key is reserved
+ """
+ if not self._enable_reserved_writes and ReservedKeys.contains(key):
+ raise SmartSimError(
+ "Use of reserved key denied. "
+ "Unable to overwrite system configuration"
+ )
+
+ def __getitem__(self, key: str) -> t.Union[str, bytes]:
+ """Retrieve an item using key.
+
+ :param key: Unique key of an item to retrieve from the feature store
+ :returns: An item in the FeatureStore
+ :raises SmartSimError: If retrieving fails
+ """
+ try:
+ return self._get(key)
+ except KeyError:
+ raise
+ except Exception as ex:
+ # note: explicitly avoid round-trip to check for key existence
+ raise SmartSimError(
+ f"Could not get value for existing key {key}, error:\n{ex}"
+ ) from ex
+
+ def __setitem__(self, key: str, value: t.Union[str, bytes]) -> None:
+ """Assign a value using key.
+
+ :param key: Unique key of an item to set in the feature store
+ :param value: Value to persist in the feature store
+ """
+ self._check_reserved(key)
+ self._set(key, value)
+
+ def __contains__(self, key: str) -> bool:
+ """Membership operator to test for a key existing within the feature store.
+
+ :param key: Unique key of an item to retrieve from the feature store
+ :returns: `True` if the key is found, `False` otherwise
+ """
+ return self._contains(key)
+
+ @abstractmethod
+ def _get(self, key: str) -> t.Union[str, bytes]:
+ """Retrieve a value from the underlying storage mechanism.
+
+ :param key: The unique key that identifies the resource
+ :returns: The value identified by the key
+ :raises KeyError: If the key has not been used to store a value
+ """
+
+ @abstractmethod
+ def _set(self, key: str, value: t.Union[str, bytes]) -> None:
+ """Store a value into the underlying storage mechanism.
+
+ :param key: The unique key that identifies the resource
+ :param value: The value to store
+ """
+
+ @abstractmethod
+ def _contains(self, key: str) -> bool:
+ """Determine if the storage mechanism contains a given key.
+
+ :param key: The unique key that identifies the resource
+ :returns: `True` if the key is defined, `False` otherwise
+ """
+
+ @property
+ def _allow_reserved_writes(self) -> bool:
+ """Return the boolean flag indicating if writing to reserved keys is
+ enabled for this feature store.
+
+ :returns: `True` if enabled, `False` otherwise
+ """
+ return self._enable_reserved_writes
+
+ @_allow_reserved_writes.setter
+ def _allow_reserved_writes(self, value: bool) -> None:
+ """Modify the boolean flag indicating if writing to reserved keys is
+ enabled for this feature store.
+
+ :param value: The new value to set for the flag
+ """
+ self._enable_reserved_writes = value
+
+ @property
+ def descriptor(self) -> str:
+ """Unique identifier enabling a client to connect to the feature store.
+
+ :returns: A descriptor encoded as a string
+ """
+ return self._descriptor
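
For illustration only (not part of this changeset), a minimal in-memory subclass showing how `_get`, `_set`, and `_contains` plug into the base-class operators and how the reserved-key guard behaves:

```python
import typing as t

from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore


class DictFeatureStore(FeatureStore):
    """Toy feature store backed by a plain Python dict."""

    def __init__(self, descriptor: str = "local-dict") -> None:
        super().__init__(descriptor)
        self._data: t.Dict[str, t.Union[str, bytes]] = {}

    def _get(self, key: str) -> t.Union[str, bytes]:
        return self._data[key]  # KeyError on a missing key, as the base class expects

    def _set(self, key: str, value: t.Union[str, bytes]) -> None:
        self._data[key] = value

    def _contains(self, key: str) -> bool:
        return key in self._data


store = DictFeatureStore()
store["weights"] = b"\x00\x01"
assert "weights" in store

# writing to a reserved key raises SmartSimError unless _allow_reserved_writes is set
# store["_SMARTSIM_REQUEST_QUEUE"] = "denied"
```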
diff --git a/smartsim/_core/mli/infrastructure/worker/__init__.py b/smartsim/_core/mli/infrastructure/worker/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/infrastructure/worker/torch_worker.py b/smartsim/_core/mli/infrastructure/worker/torch_worker.py
new file mode 100644
index 0000000000..64e94e5eb6
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/worker/torch_worker.py
@@ -0,0 +1,276 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import io
+
+import numpy as np
+import torch
+
+# pylint: disable=import-error
+from dragon.managed_memory import MemoryAlloc, MemoryPool
+
+from .....error import SmartSimError
+from .....log import get_logger
+from ...mli_schemas.tensor import tensor_capnp
+from .worker import (
+ ExecuteResult,
+ FetchInputResult,
+ FetchModelResult,
+ LoadModelResult,
+ MachineLearningWorkerBase,
+ RequestBatch,
+ TransformInputResult,
+ TransformOutputResult,
+)
+
+# pylint: enable=import-error
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(4)
+logger = get_logger(__name__)
+
+
+class TorchWorker(MachineLearningWorkerBase):
+ """A worker that executes a PyTorch model."""
+
+ @staticmethod
+ def load_model(
+ batch: RequestBatch, fetch_result: FetchModelResult, device: str
+ ) -> LoadModelResult:
+ """Given a loaded MachineLearningModel, ensure it is loaded into
+ device memory.
+
+ :param request: The request that triggered the pipeline
+ :param device: The device on which the model must be placed
+ :returns: LoadModelResult wrapping the model loaded for the request
+ :raises ValueError: If model reference object is not found
+ :raises RuntimeError: If loading and evaluating the model failed
+ """
+ if fetch_result.model_bytes:
+ model_bytes = fetch_result.model_bytes
+ elif batch.raw_model and batch.raw_model.data:
+ model_bytes = batch.raw_model.data
+ else:
+ raise ValueError("Unable to load model without reference object")
+
+ device_to_torch = {"cpu": "cpu", "gpu": "cuda"}
+ for old, new in device_to_torch.items():
+ device = device.replace(old, new)
+
+ buffer = io.BytesIO(initial_bytes=model_bytes)
+ try:
+ with torch.no_grad():
+ model = torch.jit.load(buffer, map_location=device) # type: ignore
+ model.eval()
+ except Exception as e:
+ raise RuntimeError(
+ "Failed to load and evaluate the model: "
+ f"Model key {batch.model_id.key}, Device {device}"
+ ) from e
+ result = LoadModelResult(model)
+ return result
+
+ @staticmethod
+ def transform_input(
+ batch: RequestBatch,
+ fetch_results: list[FetchInputResult],
+ mem_pool: MemoryPool,
+ ) -> TransformInputResult:
+ """Given a collection of data, perform a transformation on the data and put
+ the raw tensor data on a MemoryPool allocation.
+
+        :param batch: The batch of requests that triggered the pipeline
+        :param fetch_results: Raw outputs from fetching inputs out of a feature store
+ :param mem_pool: The memory pool used to access batched input tensors
+ :returns: The transformed inputs wrapped in a TransformInputResult
+ :raises ValueError: If tensors cannot be reconstructed
+ :raises IndexError: If index out of range
+ """
+        results: list[bytes] = []
+ total_samples = 0
+ slices: list[slice] = []
+
+ all_dims: list[list[int]] = []
+ all_dtypes: list[str] = []
+ if fetch_results[0].meta is None:
+ raise ValueError("Cannot reconstruct tensor without meta information")
+ # Traverse inputs to get total number of samples and compute slices
+ # Assumption: first dimension is samples, all tensors in the same input
+ # have same number of samples
+ # thus we only look at the first tensor for each input
+ for res_idx, fetch_result in enumerate(fetch_results):
+ if fetch_result.meta is None or any(
+ item_meta is None for item_meta in fetch_result.meta
+ ):
+ raise ValueError("Cannot reconstruct tensor without meta information")
+ first_tensor_desc: tensor_capnp.TensorDescriptor = fetch_result.meta[0]
+ num_samples = first_tensor_desc.dimensions[0]
+ slices.append(slice(total_samples, total_samples + num_samples))
+ total_samples = total_samples + num_samples
+
+ if res_idx == len(fetch_results) - 1:
+ # For each tensor in the last input, get remaining dimensions
+ # Assumptions: all inputs have the same number of tensors and
+ # last N-1 dimensions match across inputs for corresponding tensors
+ # thus: resulting array will be of size (num_samples, all_other_dims)
+ for item_meta in fetch_result.meta:
+ tensor_desc: tensor_capnp.TensorDescriptor = item_meta
+ tensor_dims = list(tensor_desc.dimensions)
+ all_dims.append([total_samples, *tensor_dims[1:]])
+ all_dtypes.append(str(tensor_desc.dataType))
+
+ for result_tensor_idx, (dims, dtype) in enumerate(zip(all_dims, all_dtypes)):
+ itemsize = np.empty((1), dtype=dtype).itemsize
+ alloc_size = int(np.prod(dims) * itemsize)
+ mem_alloc = mem_pool.alloc(alloc_size)
+ mem_view = mem_alloc.get_memview()
+ try:
+ mem_view[:alloc_size] = b"".join(
+ [
+ fetch_result.inputs[result_tensor_idx]
+ for fetch_result in fetch_results
+ ]
+ )
+ except IndexError as e:
+ raise IndexError(
+ "Error accessing elements in fetch_result.inputs "
+ f"with index {result_tensor_idx}"
+ ) from e
+
+ results.append(mem_alloc.serialize())
+
+ return TransformInputResult(results, slices, all_dims, all_dtypes)
+
+ # pylint: disable-next=unused-argument
+ @staticmethod
+ def execute(
+ batch: RequestBatch,
+ load_result: LoadModelResult,
+ transform_result: TransformInputResult,
+ device: str,
+ ) -> ExecuteResult:
+ """Execute an ML model on inputs transformed for use by the model.
+
+ :param batch: The batch of requests that triggered the pipeline
+ :param load_result: The result of loading the model onto device memory
+ :param transform_result: The result of transforming inputs for model consumption
+ :param device: The device on which the model will be executed
+ :returns: The result of inference wrapped in an ExecuteResult
+ :raises SmartSimError: If model is not loaded
+ :raises IndexError: If memory slicing is out of range
+ :raises ValueError: If tensor creation fails or is unable to evaluate the model
+ """
+ if not load_result.model:
+ raise SmartSimError("Model must be loaded to execute")
+ device_to_torch = {"cpu": "cpu", "gpu": "cuda"}
+ for old, new in device_to_torch.items():
+ device = device.replace(old, new)
+
+ tensors = []
+ mem_allocs = []
+ for transformed, dims, dtype in zip(
+ transform_result.transformed, transform_result.dims, transform_result.dtypes
+ ):
+ mem_alloc = MemoryAlloc.attach(transformed)
+ mem_allocs.append(mem_alloc)
+ itemsize = np.empty((1), dtype=dtype).itemsize
+ try:
+ tensors.append(
+ torch.from_numpy(
+ np.frombuffer(
+ mem_alloc.get_memview()[0 : np.prod(dims) * itemsize],
+ dtype=dtype,
+ ).reshape(dims)
+ )
+ )
+ except IndexError as e:
+ raise IndexError("Error during memory slicing") from e
+ except Exception as e:
+ raise ValueError("Error during tensor creation") from e
+
+ model: torch.nn.Module = load_result.model
+ try:
+ with torch.no_grad():
+ model.eval()
+ results = [
+ model(
+ *[
+ tensor.to(device, non_blocking=True).detach()
+ for tensor in tensors
+ ]
+ )
+ ]
+ except Exception as e:
+ raise ValueError(
+ f"Error while evaluating the model: Model {batch.model_id.key}"
+ ) from e
+
+ transform_result.transformed = []
+
+ execute_result = ExecuteResult(results, transform_result.slices)
+ for mem_alloc in mem_allocs:
+ mem_alloc.free()
+ return execute_result
+
+ @staticmethod
+ def transform_output(
+ batch: RequestBatch,
+ execute_result: ExecuteResult,
+ ) -> list[TransformOutputResult]:
+ """Given inference results, perform transformations required to
+ transmit results to the requestor.
+
+ :param batch: The batch of requests that triggered the pipeline
+ :param execute_result: The result of inference wrapped in an ExecuteResult
+ :returns: A list of transformed outputs
+ :raises IndexError: If indexing is out of range
+ :raises ValueError: If transforming output fails
+ """
+ transformed_list: list[TransformOutputResult] = []
+ cpu_predictions = [
+ prediction.cpu() for prediction in execute_result.predictions
+ ]
+ for result_slice in execute_result.slices:
+ transformed = []
+ for cpu_item in cpu_predictions:
+ try:
+ transformed.append(cpu_item[result_slice].numpy().tobytes())
+
+ # todo: need the shape from latest schemas added here.
+ transformed_list.append(
+ TransformOutputResult(transformed, None, "c", "float32")
+ ) # fixme
+ except IndexError as e:
+ raise IndexError(
+ f"Error accessing elements: result_slice {result_slice}"
+ ) from e
+ except Exception as e:
+ raise ValueError("Error transforming output") from e
+
+ execute_result.predictions = []
+
+ return transformed_list
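
For orientation, a hedged sketch of the `load_model` stage in isolation (assuming a Dragon-enabled environment so the module imports cleanly): a TorchScript model is serialized to bytes, standing in for the payload of a `FetchModelResult`, and loaded on CPU. The key and descriptor strings below are placeholders.

```python
import io

import torch

from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey
from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
from smartsim._core.mli.infrastructure.worker.worker import (
    FetchModelResult,
    RequestBatch,
)

# serialize a trivial TorchScript model, standing in for bytes from a feature store
scripted = torch.jit.script(torch.nn.Linear(8, 2))
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)

batch = RequestBatch(
    requests=[],
    inputs=None,
    model_id=ModelKey(key="toy-model", descriptor="example-fs"),
)
fetch_result = FetchModelResult(buffer.getvalue())

load_result = TorchWorker.load_model(batch, fetch_result, device="cpu")
assert load_result.model is not None
```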
diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py
new file mode 100644
index 0000000000..9556b8e438
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/worker/worker.py
@@ -0,0 +1,646 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# pylint: disable=import-error
+from dragon.managed_memory import MemoryPool
+
+# isort: off
+# isort: on
+
+import typing as t
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+from .....error import SmartSimError
+from .....log import get_logger
+from ...comm.channel.channel import CommChannelBase
+from ...message_handler import MessageHandler
+from ...mli_schemas.model.model_capnp import Model
+from ..storage.feature_store import FeatureStore, ModelKey, TensorKey
+
+if t.TYPE_CHECKING:
+ from smartsim._core.mli.mli_schemas.response.response_capnp import Status
+ from smartsim._core.mli.mli_schemas.tensor.tensor_capnp import TensorDescriptor
+
+logger = get_logger(__name__)
+
+# Placeholder
+ModelIdentifier = ModelKey
+
+
+class InferenceRequest:
+ """Internal representation of an inference request from a client."""
+
+ def __init__(
+ self,
+ model_key: t.Optional[ModelKey] = None,
+ callback: t.Optional[CommChannelBase] = None,
+ raw_inputs: t.Optional[t.List[bytes]] = None,
+ input_keys: t.Optional[t.List[TensorKey]] = None,
+ input_meta: t.Optional[t.List[t.Any]] = None,
+ output_keys: t.Optional[t.List[TensorKey]] = None,
+ raw_model: t.Optional[Model] = None,
+ batch_size: int = 0,
+ ):
+ """Initialize the InferenceRequest.
+
+ :param model_key: A tuple containing a (key, descriptor) pair
+ :param callback: The channel used for notification of inference completion
+ :param raw_inputs: Raw bytes of tensor inputs
+ :param input_keys: A list of tuples containing a (key, descriptor) pair
+ :param input_meta: Metadata about the input data
+ :param output_keys: A list of tuples containing a (key, descriptor) pair
+ :param raw_model: Raw bytes of an ML model
+ :param batch_size: The batch size to apply when batching
+ """
+ self.model_key = model_key
+ """A tuple containing a (key, descriptor) pair"""
+ self.raw_model = raw_model
+ """Raw bytes of an ML model"""
+ self.callback = callback
+ """The channel used for notification of inference completion"""
+ self.raw_inputs = raw_inputs or []
+ """Raw bytes of tensor inputs"""
+ self.input_keys = input_keys or []
+ """A list of tuples containing a (key, descriptor) pair"""
+ self.input_meta = input_meta or []
+ """Metadata about the input data"""
+ self.output_keys = output_keys or []
+ """A list of tuples containing a (key, descriptor) pair"""
+ self.batch_size = batch_size
+ """The batch size to apply when batching"""
+
+ @property
+ def has_raw_model(self) -> bool:
+ """Check if the InferenceRequest contains a raw_model.
+
+ :returns: True if raw_model is not None, False otherwise
+ """
+ return self.raw_model is not None
+
+ @property
+ def has_model_key(self) -> bool:
+ """Check if the InferenceRequest contains a model_key.
+
+ :returns: True if model_key is not None, False otherwise
+ """
+ return self.model_key is not None
+
+ @property
+ def has_raw_inputs(self) -> bool:
+ """Check if the InferenceRequest contains raw_inputs.
+
+        :returns: True if raw_inputs is not None and is not an empty list,
+ False otherwise
+ """
+ return self.raw_inputs is not None and bool(self.raw_inputs)
+
+ @property
+ def has_input_keys(self) -> bool:
+ """Check if the InferenceRequest contains input_keys.
+
+ :returns: True if input_keys is not None and is not an empty list,
+ False otherwise
+ """
+ return self.input_keys is not None and bool(self.input_keys)
+
+ @property
+ def has_output_keys(self) -> bool:
+ """Check if the InferenceRequest contains output_keys.
+
+ :returns: True if output_keys is not None and is not an empty list,
+ False otherwise
+ """
+ return self.output_keys is not None and bool(self.output_keys)
+
+ @property
+ def has_input_meta(self) -> bool:
+ """Check if the InferenceRequest contains input_meta.
+
+ :returns: True if input_meta is not None and is not an empty list,
+ False otherwise
+ """
+ return self.input_meta is not None and bool(self.input_meta)
+
+
+class InferenceReply:
+ """Internal representation of the reply to a client request for inference."""
+
+ def __init__(
+ self,
+ outputs: t.Optional[t.Collection[t.Any]] = None,
+ output_keys: t.Optional[t.Collection[TensorKey]] = None,
+ status_enum: "Status" = "running",
+ message: str = "In progress",
+ ) -> None:
+ """Initialize the InferenceReply.
+
+ :param outputs: List of output data
+ :param output_keys: List of keys used for output data
+ :param status_enum: Status of the reply
+ :param message: Status message that corresponds with the status enum
+ """
+ self.outputs: t.Collection[t.Any] = outputs or []
+ """List of output data"""
+ self.output_keys: t.Collection[t.Optional[TensorKey]] = output_keys or []
+ """List of keys used for output data"""
+ self.status_enum = status_enum
+ """Status of the reply"""
+ self.message = message
+ """Status message that corresponds with the status enum"""
+
+ @property
+ def has_outputs(self) -> bool:
+ """Check if the InferenceReply contains outputs.
+
+ :returns: True if outputs is not None and is not an empty list,
+ False otherwise
+ """
+ return self.outputs is not None and bool(self.outputs)
+
+ @property
+ def has_output_keys(self) -> bool:
+ """Check if the InferenceReply contains output_keys.
+
+ :returns: True if output_keys is not None and is not an empty list,
+ False otherwise
+ """
+ return self.output_keys is not None and bool(self.output_keys)
+
+
+class LoadModelResult:
+ """A wrapper around a loaded model."""
+
+ def __init__(self, model: t.Any) -> None:
+ """Initialize the LoadModelResult.
+
+ :param model: The loaded model
+ """
+ self.model = model
+ """The loaded model (e.g. a TensorFlow, PyTorch, ONNX, etc. model)"""
+
+
+class TransformInputResult:
+ """A wrapper around a transformed batch of input tensors"""
+
+ def __init__(
+ self,
+ result: t.Any,
+ slices: list[slice],
+ dims: list[list[int]],
+ dtypes: list[str],
+ ) -> None:
+ """Initialize the TransformInputResult.
+
+ :param result: List of Dragon MemoryAlloc objects on which
+ the tensors are stored
+ :param slices: The slices that represent which portion of the
+ input tensors belongs to which request
+ :param dims: Dimension of the transformed tensors
+ :param dtypes: Data type of transformed tensors
+ """
+ self.transformed = result
+ """List of Dragon MemoryAlloc objects on which the tensors are stored"""
+ self.slices = slices
+ """Each slice represents which portion of the input tensors belongs to
+ which request"""
+ self.dims = dims
+ """Dimension of the transformed tensors"""
+ self.dtypes = dtypes
+ """Data type of transformed tensors"""
+
+
+class ExecuteResult:
+ """A wrapper around inference results."""
+
+ def __init__(self, result: t.Any, slices: list[slice]) -> None:
+ """Initialize the ExecuteResult.
+
+ :param result: Result of the execution
+ :param slices: The slices that represent which portion of the input
+ tensors belongs to which request
+ """
+ self.predictions = result
+ """Result of the execution"""
+ self.slices = slices
+ """The slices that represent which portion of the input
+ tensors belongs to which request"""
+
+
+class FetchInputResult:
+ """A wrapper around fetched inputs."""
+
+ def __init__(self, result: t.List[bytes], meta: t.Optional[t.List[t.Any]]) -> None:
+ """Initialize the FetchInputResult.
+
+ :param result: List of input tensor bytes
+ :param meta: List of metadata that corresponds with the inputs
+ """
+ self.inputs = result
+ """List of input tensor bytes"""
+ self.meta = meta
+ """List of metadata that corresponds with the inputs"""
+
+
+class TransformOutputResult:
+ """A wrapper around inference results transformed for transmission."""
+
+ def __init__(
+ self, result: t.Any, shape: t.Optional[t.List[int]], order: str, dtype: str
+ ) -> None:
+ """Initialize the TransformOutputResult.
+
+ :param result: Transformed output results
+ :param shape: Shape of output results
+ :param order: Order of output results
+ :param dtype: Datatype of output results
+ """
+ self.outputs = result
+ """Transformed output results"""
+ self.shape = shape
+ """Shape of output results"""
+ self.order = order
+ """Order of output results"""
+ self.dtype = dtype
+ """Datatype of output results"""
+
+
+class CreateInputBatchResult:
+ """A wrapper around inputs batched into a single request."""
+
+ def __init__(self, result: t.Any) -> None:
+ """Initialize the CreateInputBatchResult.
+
+ :param result: Inputs batched into a single request
+ """
+ self.batch = result
+ """Inputs batched into a single request"""
+
+
+class FetchModelResult:
+ """A wrapper around raw fetched models."""
+
+ def __init__(self, result: bytes) -> None:
+ """Initialize the FetchModelResult.
+
+ :param result: The raw fetched model
+ """
+ self.model_bytes: bytes = result
+ """The raw fetched model"""
+
+
+@dataclass
+class RequestBatch:
+ """A batch of aggregated inference requests."""
+
+ requests: list[InferenceRequest]
+ """List of InferenceRequests in the batch"""
+ inputs: t.Optional[TransformInputResult]
+ """Transformed batch of input tensors"""
+ model_id: "ModelIdentifier"
+ """Model (key, descriptor) tuple"""
+
+ @property
+ def has_valid_requests(self) -> bool:
+ """Returns whether the batch contains at least one request.
+
+ :returns: True if at least one request is available
+ """
+ return len(self.requests) > 0
+
+ @property
+ def has_raw_model(self) -> bool:
+ """Returns whether the batch has a raw model.
+
+ :returns: True if the batch has a raw model
+ """
+ return self.raw_model is not None
+
+ @property
+ def raw_model(self) -> t.Optional[t.Any]:
+ """Returns the raw model to use to execute for this batch
+ if it is available.
+
+ :returns: A model if available, otherwise None"""
+ if self.has_valid_requests:
+ return self.requests[0].raw_model
+ return None
+
+ @property
+ def input_keys(self) -> t.List[TensorKey]:
+ """All input keys available in this batch's requests.
+
+ :returns: All input keys belonging to requests in this batch"""
+ keys = []
+ for request in self.requests:
+ keys.extend(request.input_keys)
+
+ return keys
+
+ @property
+ def output_keys(self) -> t.List[TensorKey]:
+ """All output keys available in this batch's requests.
+
+ :returns: All output keys belonging to requests in this batch"""
+ keys = []
+ for request in self.requests:
+ keys.extend(request.output_keys)
+
+ return keys
+
+
+class MachineLearningWorkerCore:
+ """Basic functionality of ML worker that is shared across all worker types."""
+
+ @staticmethod
+ def deserialize_message(
+ data_blob: bytes,
+ callback_factory: t.Callable[[str], CommChannelBase],
+ ) -> InferenceRequest:
+ """Deserialize a message from a byte stream into an InferenceRequest.
+
+ :param data_blob: The byte stream to deserialize
+ :param callback_factory: A factory method that can create an instance
+ of the desired concrete comm channel type
+ :returns: The raw input message deserialized into an InferenceRequest
+ """
+ request = MessageHandler.deserialize_request(data_blob)
+ model_key: t.Optional[ModelKey] = None
+ model_bytes: t.Optional[Model] = None
+
+ if request.model.which() == "key":
+ model_key = ModelKey(
+ key=request.model.key.key,
+ descriptor=request.model.key.descriptor,
+ )
+ elif request.model.which() == "data":
+ model_bytes = request.model.data
+
+ callback_key = request.replyChannel.descriptor
+ comm_channel = callback_factory(callback_key)
+ input_keys: t.Optional[t.List[TensorKey]] = None
+ input_bytes: t.Optional[t.List[bytes]] = None
+ output_keys: t.Optional[t.List[TensorKey]] = None
+ input_meta: t.Optional[t.List[TensorDescriptor]] = None
+
+ if request.input.which() == "keys":
+ input_keys = [
+ TensorKey(key=value.key, descriptor=value.descriptor)
+ for value in request.input.keys
+ ]
+ elif request.input.which() == "descriptors":
+ input_meta = request.input.descriptors # type: ignore
+
+ if request.output:
+ output_keys = [
+ TensorKey(key=value.key, descriptor=value.descriptor)
+ for value in request.output
+ ]
+
+ inference_request = InferenceRequest(
+ model_key=model_key,
+ callback=comm_channel,
+ raw_inputs=input_bytes,
+ input_meta=input_meta,
+ input_keys=input_keys,
+ output_keys=output_keys,
+ raw_model=model_bytes,
+ batch_size=0,
+ )
+ return inference_request
+
+ @staticmethod
+ def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]:
+ """Assemble the output information based on whether the output
+ information will be in the form of TensorKeys or TensorDescriptors.
+
+ :param reply: The reply that the output belongs to
+ :returns: The list of prepared outputs, depending on the output
+ information needed in the reply
+ """
+ prepared_outputs: t.List[t.Any] = []
+ if reply.has_output_keys:
+ for value in reply.output_keys:
+ if not value:
+ continue
+ msg_key = MessageHandler.build_tensor_key(value.key, value.descriptor)
+ prepared_outputs.append(msg_key)
+ elif reply.has_outputs:
+ for _ in reply.outputs:
+ msg_tensor_desc = MessageHandler.build_tensor_descriptor(
+ "c",
+ "float32",
+ [1],
+ )
+ prepared_outputs.append(msg_tensor_desc)
+ return prepared_outputs
+
+ @staticmethod
+ def fetch_model(
+ batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore]
+ ) -> FetchModelResult:
+ """Given a resource key, retrieve the raw model from a feature store.
+
+ :param batch: The batch of requests that triggered the pipeline
+ :param feature_stores: Available feature stores used for persistence
+ :returns: Raw bytes of the model
+        :raises SmartSimError: If neither a key nor a model is provided, or the
+ model cannot be retrieved from the feature store
+ :raises ValueError: If a feature store is not available and a raw
+ model is not provided
+ """
+ # All requests in the same batch share the model
+ if batch.raw_model:
+ return FetchModelResult(batch.raw_model.data)
+
+ if not feature_stores:
+ raise ValueError("Feature store is required for model retrieval")
+
+ if batch.model_id is None:
+ raise SmartSimError(
+ "Key must be provided to retrieve model from feature store"
+ )
+
+ key, fsd = batch.model_id.key, batch.model_id.descriptor
+
+ try:
+ feature_store = feature_stores[fsd]
+ raw_bytes: bytes = t.cast(bytes, feature_store[key])
+ return FetchModelResult(raw_bytes)
+ except (FileNotFoundError, KeyError) as ex:
+ logger.exception(ex)
+ raise SmartSimError(f"Model could not be retrieved with key {key}") from ex
+
+ @staticmethod
+ def fetch_inputs(
+ batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore]
+ ) -> t.List[FetchInputResult]:
+ """Given a collection of ResourceKeys, identify the physical location
+ and input metadata.
+
+ :param batch: The batch of requests that triggered the pipeline
+ :param feature_stores: Available feature stores used for persistence
+ :returns: The fetched input
+ :raises ValueError: If neither an input key or an input tensor are provided
+ :raises SmartSimError: If a tensor for a given key cannot be retrieved
+ """
+ fetch_results = []
+ for request in batch.requests:
+ if request.raw_inputs:
+ fetch_results.append(
+ FetchInputResult(request.raw_inputs, request.input_meta)
+ )
+ continue
+
+ if not feature_stores:
+ raise ValueError("No input and no feature store provided")
+
+ if request.has_input_keys:
+ data: t.List[bytes] = []
+
+ for fs_key in request.input_keys:
+ try:
+ feature_store = feature_stores[fs_key.descriptor]
+ tensor_bytes = t.cast(bytes, feature_store[fs_key.key])
+ data.append(tensor_bytes)
+ except KeyError as ex:
+ logger.exception(ex)
+ raise SmartSimError(
+ f"Tensor could not be retrieved with key {fs_key.key}"
+ ) from ex
+ fetch_results.append(
+ FetchInputResult(data, meta=None)
+ ) # fixme: need to get both tensor and descriptor
+ continue
+
+ raise ValueError("No input source")
+
+ return fetch_results
+
+ @staticmethod
+ def place_output(
+ request: InferenceRequest,
+ transform_result: TransformOutputResult,
+ feature_stores: t.Dict[str, FeatureStore],
+ ) -> t.Collection[t.Optional[TensorKey]]:
+ """Given a collection of data, make it available as a shared resource in the
+ feature store.
+
+ :param request: The request that triggered the pipeline
+ :param transform_result: Transformed version of the inference result
+ :param feature_stores: Available feature stores used for persistence
+ :returns: A collection of keys that were placed in the feature store
+ :raises ValueError: If a feature store is not provided
+ """
+ if not feature_stores:
+ raise ValueError("Feature store is required for output persistence")
+
+ keys: t.List[t.Optional[TensorKey]] = []
+ # need to decide how to get back to original sub-batch inputs so they can be
+ # accurately placed, datum might need to include this.
+
+ # Consider parallelizing all PUT feature_store operations
+ for fs_key, v in zip(request.output_keys, transform_result.outputs):
+ feature_store = feature_stores[fs_key.descriptor]
+ feature_store[fs_key.key] = v
+ keys.append(fs_key)
+
+ return keys
+
+
+class MachineLearningWorkerBase(MachineLearningWorkerCore, ABC):
+ """Abstract base class providing contract for a machine learning
+ worker implementation."""
+
+ @staticmethod
+ @abstractmethod
+ def load_model(
+ batch: RequestBatch, fetch_result: FetchModelResult, device: str
+ ) -> LoadModelResult:
+ """Given the raw bytes of an ML model that were fetched, ensure
+ it is loaded into device memory.
+
+ :param batch: The batch of requests that triggered the pipeline
+ :param fetch_result: The result of a fetch-model operation; contains
+ the raw bytes of the ML model.
+ :param device: The device on which the model must be placed
+ :returns: LoadModelResult wrapping the model loaded for the request
+ :raises ValueError: If model reference object is not found
+ :raises RuntimeError: If loading and evaluating the model failed
+ """
+
+ @staticmethod
+ @abstractmethod
+ def transform_input(
+ batch: RequestBatch,
+ fetch_results: list[FetchInputResult],
+ mem_pool: MemoryPool,
+ ) -> TransformInputResult:
+ """Given a collection of data, perform a transformation on the data and put
+ the raw tensor data on a MemoryPool allocation.
+
+ :param batch: The batch of requests that triggered the pipeline
+ :param fetch_results: Raw outputs from fetching inputs out of a feature store
+ :param mem_pool: The memory pool used to access batched input tensors
+ :returns: The transformed inputs wrapped in a TransformInputResult
+ :raises ValueError: If tensors cannot be reconstructed
+ :raises IndexError: If index out of range
+ """
+
+ @staticmethod
+ @abstractmethod
+ def execute(
+ batch: RequestBatch,
+ load_result: LoadModelResult,
+ transform_result: TransformInputResult,
+ device: str,
+ ) -> ExecuteResult:
+ """Execute an ML model on inputs transformed for use by the model.
+
+ :param batch: The batch of requests that triggered the pipeline
+ :param load_result: The result of loading the model onto device memory
+ :param transform_result: The result of transforming inputs for model consumption
+ :param device: The device on which the model will be executed
+ :returns: The result of inference wrapped in an ExecuteResult
+ :raises SmartSimError: If model is not loaded
+ :raises IndexError: If memory slicing is out of range
+ :raises ValueError: If tensor creation fails or the model cannot be evaluated
+ """
+
+ @staticmethod
+ @abstractmethod
+ def transform_output(
+ batch: RequestBatch, execute_result: ExecuteResult
+ ) -> t.List[TransformOutputResult]:
+ """Given inference results, perform transformations required to
+ transmit results to the requestor.
+
+ :param batch: The batch of requests that triggered the pipeline
+ :param execute_result: The result of inference wrapped in an ExecuteResult
+ :returns: A list of transformed outputs
+ :raises IndexError: If indexing is out of range
+ :raises ValueError: If transforming output fails
+ """
diff --git a/smartsim/_core/mli/message_handler.py b/smartsim/_core/mli/message_handler.py
new file mode 100644
index 0000000000..e3d46a7ab3
--- /dev/null
+++ b/smartsim/_core/mli/message_handler.py
@@ -0,0 +1,602 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import typing as t
+
+from .mli_schemas.data import data_references_capnp
+from .mli_schemas.model import model_capnp
+from .mli_schemas.request import request_capnp
+from .mli_schemas.request.request_attributes import request_attributes_capnp
+from .mli_schemas.response import response_capnp
+from .mli_schemas.response.response_attributes import response_attributes_capnp
+from .mli_schemas.tensor import tensor_capnp
+
+
+class MessageHandler:
+ """Utility methods for transforming capnproto messages to and from
+ internal representations.
+ """
+
+ @staticmethod
+ def build_tensor_descriptor(
+ order: "tensor_capnp.Order",
+ data_type: "tensor_capnp.NumericalType",
+ dimensions: t.List[int],
+ ) -> tensor_capnp.TensorDescriptor:
+ """
+ Builds a TensorDescriptor message using the provided
+ order, data type, and dimensions.
+
+ :param order: Order of the tensor, such as row-major (c) or column-major (f)
+ :param data_type: Data type of the tensor
+ :param dimensions: Dimensions of the tensor
+ :returns: The TensorDescriptor
+ :raises ValueError: If building fails
+ """
+ try:
+ description = tensor_capnp.TensorDescriptor.new_message()
+ description.order = order
+ description.dataType = data_type
+ description.dimensions = dimensions
+ except Exception as e:
+ raise ValueError("Error building tensor descriptor.") from e
+
+ return description
+
+ @staticmethod
+ def build_output_tensor_descriptor(
+ order: "tensor_capnp.Order",
+ keys: t.List["data_references_capnp.TensorKey"],
+ data_type: "tensor_capnp.ReturnNumericalType",
+ dimensions: t.List[int],
+ ) -> tensor_capnp.OutputDescriptor:
+ """
+ Builds an OutputDescriptor message using the provided
+ order, keys, data type, and dimensions.
+
+ :param order: Order of the tensor, such as row-major (c) or column-major (f)
+ :param keys: List of TensorKeys to apply the transform descriptor to
+ :param data_type: Transform data type of the tensor
+ :param dimensions: Transform dimensions of the tensor
+ :returns: The OutputDescriptor
+ :raises ValueError: If building fails
+ """
+ try:
+ description = tensor_capnp.OutputDescriptor.new_message()
+ description.order = order
+ description.optionalKeys = keys
+ description.optionalDatatype = data_type
+ description.optionalDimension = dimensions
+
+ except Exception as e:
+ raise ValueError("Error building output tensor descriptor.") from e
+
+ return description
+
+ @staticmethod
+ def build_tensor_key(key: str, descriptor: str) -> data_references_capnp.TensorKey:
+ """
+ Builds a new TensorKey message with the provided key and descriptor.
+
+ :param key: String to set the TensorKey
+ :param descriptor: A descriptor identifying the feature store
+ containing the key
+ :returns: The TensorKey
+ :raises ValueError: If building fails
+ """
+ try:
+ tensor_key = data_references_capnp.TensorKey.new_message()
+ tensor_key.key = key
+ tensor_key.descriptor = descriptor
+ except Exception as e:
+ raise ValueError("Error building tensor key.") from e
+ return tensor_key
+
+ @staticmethod
+ def build_model(data: bytes, name: str, version: str) -> model_capnp.Model:
+ """
+ Builds a new Model message with the provided data, name, and version.
+
+ :param data: Model data
+ :param name: Model name
+ :param version: Model version
+ :returns: The Model
+ :raises ValueError: If building fails
+ """
+ try:
+ model = model_capnp.Model.new_message()
+ model.data = data
+ model.name = name
+ model.version = version
+ except Exception as e:
+ raise ValueError("Error building model.") from e
+ return model
+
+ @staticmethod
+ def build_model_key(key: str, descriptor: str) -> data_references_capnp.ModelKey:
+ """
+ Builds a new ModelKey message with the provided key and descriptor.
+
+ :param key: String to set the ModelKey
+ :param descriptor: A descriptor identifying the feature store
+ containing the key
+ :returns: The ModelKey
+ :raises ValueError: If building fails
+ """
+ try:
+ model_key = data_references_capnp.ModelKey.new_message()
+ model_key.key = key
+ model_key.descriptor = descriptor
+ except Exception as e:
+ raise ValueError("Error building tensor key.") from e
+ return model_key
+
+ @staticmethod
+ def build_torch_request_attributes(
+ tensor_type: "request_attributes_capnp.TorchTensorType",
+ ) -> request_attributes_capnp.TorchRequestAttributes:
+ """
+ Builds a new TorchRequestAttributes message with the provided tensor type.
+
+ :param tensor_type: Type of the tensor passed in
+ :returns: The TorchRequestAttributes
+ :raises ValueError: If building fails
+ """
+ try:
+ attributes = request_attributes_capnp.TorchRequestAttributes.new_message()
+ attributes.tensorType = tensor_type
+ except Exception as e:
+ raise ValueError("Error building Torch request attributes.") from e
+ return attributes
+
+ @staticmethod
+ def build_tf_request_attributes(
+ name: str, tensor_type: "request_attributes_capnp.TFTensorType"
+ ) -> request_attributes_capnp.TensorFlowRequestAttributes:
+ """
+ Builds a new TensorFlowRequestAttributes message with
+ the provided name and tensor type.
+
+ :param name: Name of the tensor
+ :param tensor_type: Type of the tensor passed in
+ :returns: The TensorFlowRequestAttributes
+ :raises ValueError: If building fails
+ """
+ try:
+ attributes = (
+ request_attributes_capnp.TensorFlowRequestAttributes.new_message()
+ )
+ attributes.name = name
+ attributes.tensorType = tensor_type
+ except Exception as e:
+ raise ValueError("Error building TensorFlow request attributes.") from e
+ return attributes
+
+ @staticmethod
+ def build_torch_response_attributes() -> (
+ response_attributes_capnp.TorchResponseAttributes
+ ):
+ """
+ Builds a new TorchResponseAttributes message.
+
+ :returns: The TorchResponseAttributes
+ """
+ return response_attributes_capnp.TorchResponseAttributes.new_message()
+
+ @staticmethod
+ def build_tf_response_attributes() -> (
+ response_attributes_capnp.TensorFlowResponseAttributes
+ ):
+ """
+ Builds a new TensorFlowResponseAttributes message.
+
+ :returns: The TensorFlowResponseAttributes
+ """
+ return response_attributes_capnp.TensorFlowResponseAttributes.new_message()
+
+ @staticmethod
+ def _assign_model(
+ request: request_capnp.Request,
+ model: t.Union[data_references_capnp.ModelKey, model_capnp.Model],
+ ) -> None:
+ """
+ Assigns a model to the supplied request.
+
+ :param request: Request being built
+ :param model: Model to be assigned
+ :raises ValueError: If building fails
+ """
+ try:
+ class_name = model.schema.node.displayName.split(":")[-1] # type: ignore
+ if class_name == "Model":
+ request.model.data = model # type: ignore
+ elif class_name == "ModelKey":
+ request.model.key = model # type: ignore
+ else:
+ raise ValueError("""Invalid custom attribute class name.
+ Expected 'Model' or 'ModelKey'.""")
+ except Exception as e:
+ raise ValueError("Error building model portion of request.") from e
+
+ @staticmethod
+ def _assign_reply_channel(
+ request: request_capnp.Request, reply_channel: str
+ ) -> None:
+ """
+ Assigns a reply channel to the supplied request.
+
+ :param request: Request being built
+ :param reply_channel: Reply channel to be assigned
+ :raises ValueError: If building fails
+ """
+ try:
+ request.replyChannel.descriptor = reply_channel
+ except Exception as e:
+ raise ValueError("Error building reply channel portion of request.") from e
+
+ @staticmethod
+ def _assign_inputs(
+ request: request_capnp.Request,
+ inputs: t.Union[
+ t.List[data_references_capnp.TensorKey],
+ t.List[tensor_capnp.TensorDescriptor],
+ ],
+ ) -> None:
+ """
+ Assigns inputs to the supplied request.
+
+ :param request: Request being built
+ :param inputs: Inputs to be assigned
+ :raises ValueError: If building fails
+ """
+ try:
+ if inputs:
+ display_name = inputs[0].schema.node.displayName # type: ignore
+ input_class_name = display_name.split(":")[-1]
+ if input_class_name == "TensorDescriptor":
+ request.input.descriptors = inputs # type: ignore
+ elif input_class_name == "TensorKey":
+ request.input.keys = inputs # type: ignore
+ else:
+ raise ValueError("""Invalid input class name. Expected
+ 'TensorDescriptor' or 'TensorKey'.""")
+ except Exception as e:
+ raise ValueError("Error building inputs portion of request.") from e
+
+ @staticmethod
+ def _assign_outputs(
+ request: request_capnp.Request,
+ outputs: t.List[data_references_capnp.TensorKey],
+ ) -> None:
+ """
+ Assigns outputs to the supplied request.
+
+ :param request: Request being built
+ :param outputs: Outputs to be assigned
+ :raises ValueError: If building fails
+ """
+ try:
+ request.output = outputs
+
+ except Exception as e:
+ raise ValueError("Error building outputs portion of request.") from e
+
+ @staticmethod
+ def _assign_output_descriptors(
+ request: request_capnp.Request,
+ output_descriptors: t.List[tensor_capnp.OutputDescriptor],
+ ) -> None:
+ """
+ Assigns a list of output tensor descriptors to the supplied request.
+
+ :param request: Request being built
+ :param output_descriptors: Output descriptors to be assigned
+ :raises ValueError: If building fails
+ """
+ try:
+ request.outputDescriptors = output_descriptors
+ except Exception as e:
+ raise ValueError(
+ "Error building the output descriptors portion of request."
+ ) from e
+
+ @staticmethod
+ def _assign_custom_request_attributes(
+ request: request_capnp.Request,
+ custom_attrs: t.Union[
+ request_attributes_capnp.TorchRequestAttributes,
+ request_attributes_capnp.TensorFlowRequestAttributes,
+ None,
+ ],
+ ) -> None:
+ """
+ Assigns request attributes to the supplied request.
+
+ :param request: Request being built
+ :param custom_attrs: Custom attributes to be assigned
+ :raises ValueError: If building fails
+ """
+ try:
+ if custom_attrs is None:
+ request.customAttributes.none = custom_attrs
+ else:
+ custom_attribute_class_name = (
+ custom_attrs.schema.node.displayName.split(":")[-1] # type: ignore
+ )
+ if custom_attribute_class_name == "TorchRequestAttributes":
+ request.customAttributes.torch = custom_attrs # type: ignore
+ elif custom_attribute_class_name == "TensorFlowRequestAttributes":
+ request.customAttributes.tf = custom_attrs # type: ignore
+ else:
+ raise ValueError("""Invalid custom attribute class name.
+ Expected 'TensorFlowRequestAttributes' or
+ 'TorchRequestAttributes'.""")
+ except Exception as e:
+ raise ValueError(
+ "Error building custom attributes portion of request."
+ ) from e
+
+ @staticmethod
+ def build_request(
+ reply_channel: str,
+ model: t.Union[data_references_capnp.ModelKey, model_capnp.Model],
+ inputs: t.Union[
+ t.List[data_references_capnp.TensorKey],
+ t.List[tensor_capnp.TensorDescriptor],
+ ],
+ outputs: t.List[data_references_capnp.TensorKey],
+ output_descriptors: t.List[tensor_capnp.OutputDescriptor],
+ custom_attributes: t.Union[
+ request_attributes_capnp.TorchRequestAttributes,
+ request_attributes_capnp.TensorFlowRequestAttributes,
+ None,
+ ],
+ ) -> request_capnp.RequestBuilder:
+ """
+ Builds the request message.
+
+ :param reply_channel: Reply channel to be assigned to request
+ :param model: Model to be assigned to request
+ :param inputs: Inputs to be assigned to request
+ :param outputs: Outputs to be assigned to request
+ :param output_descriptors: Output descriptors to be assigned to request
+ :param custom_attributes: Custom attributes to be assigned to request
+ :returns: The Request
+ """
+ request = request_capnp.Request.new_message()
+ MessageHandler._assign_reply_channel(request, reply_channel)
+ MessageHandler._assign_model(request, model)
+ MessageHandler._assign_inputs(request, inputs)
+ MessageHandler._assign_outputs(request, outputs)
+ MessageHandler._assign_output_descriptors(request, output_descriptors)
+ MessageHandler._assign_custom_request_attributes(request, custom_attributes)
+ return request
+
+ @staticmethod
+ def serialize_request(request: request_capnp.RequestBuilder) -> bytes:
+ """
+ Serializes a built request message.
+
+ :param request: Request to be serialized
+ :returns: Serialized request bytes
+ :raises ValueError: If serialization fails
+ """
+ display_name = request.schema.node.displayName # type: ignore
+ class_name = display_name.split(":")[-1]
+ if class_name != "Request":
+ raise ValueError(
+ "Error serializing the request. Value passed in is not "
+ f"a request: {class_name}"
+ )
+ try:
+ return request.to_bytes()
+ except Exception as e:
+ raise ValueError("Error serializing the request") from e
+
+ @staticmethod
+ def deserialize_request(request_bytes: bytes) -> request_capnp.Request:
+ """
+ Deserializes a serialized request message.
+
+ :param request_bytes: Bytes to be deserialized into a request
+ :returns: Deserialized request
+ :raises ValueError: If deserialization fails
+ """
+ try:
+ bytes_message = request_capnp.Request.from_bytes(
+ request_bytes, traversal_limit_in_words=2**63
+ )
+
+ with bytes_message as message:
+ return message
+ except Exception as e:
+ raise ValueError("Error deserializing the request") from e
+
+ @staticmethod
+ def _assign_status(
+ response: response_capnp.Response, status: "response_capnp.Status"
+ ) -> None:
+ """
+ Assigns a status to the supplied response.
+
+ :param response: Response being built
+ :param status: Status to be assigned
+ :raises ValueError: If building fails
+ """
+ try:
+ response.status = status
+ except Exception as e:
+ raise ValueError("Error assigning status to response.") from e
+
+ @staticmethod
+ def _assign_message(response: response_capnp.Response, message: str) -> None:
+ """
+ Assigns a message to the supplied response.
+
+ :param response: Response being built
+ :param message: Message to be assigned
+ :raises ValueError: If building fails
+ """
+ try:
+ response.message = message
+ except Exception as e:
+ raise ValueError("Error assigning message to response.") from e
+
+ @staticmethod
+ def _assign_result(
+ response: response_capnp.Response,
+ result: t.Union[
+ t.List[tensor_capnp.TensorDescriptor],
+ t.List[data_references_capnp.TensorKey],
+ None,
+ ],
+ ) -> None:
+ """
+ Assigns a result to the supplied response.
+
+ :param response: Response being built
+ :param result: Result to be assigned
+ :raises ValueError: If building fails
+ """
+ try:
+ if result:
+ first_result = result[0]
+ display_name = first_result.schema.node.displayName # type: ignore
+ result_class_name = display_name.split(":")[-1]
+ if result_class_name == "TensorDescriptor":
+ response.result.descriptors = result # type: ignore
+ elif result_class_name == "TensorKey":
+ response.result.keys = result # type: ignore
+ else:
+ raise ValueError("""Invalid custom attribute class name.
+ Expected 'TensorDescriptor' or 'TensorKey'.""")
+ except Exception as e:
+ raise ValueError("Error assigning result to response.") from e
+
+ @staticmethod
+ def _assign_custom_response_attributes(
+ response: response_capnp.Response,
+ custom_attrs: t.Union[
+ response_attributes_capnp.TorchResponseAttributes,
+ response_attributes_capnp.TensorFlowResponseAttributes,
+ None,
+ ],
+ ) -> None:
+ """
+ Assigns custom attributes to the supplied response.
+
+ :param response: Response being built
+ :param custom_attrs: Custom attributes to be assigned
+ :raises ValueError: If building fails
+ """
+ try:
+ if custom_attrs is None:
+ response.customAttributes.none = custom_attrs
+ else:
+ custom_attribute_class_name = (
+ custom_attrs.schema.node.displayName.split(":")[-1] # type: ignore
+ )
+ if custom_attribute_class_name == "TorchResponseAttributes":
+ response.customAttributes.torch = custom_attrs # type: ignore
+ elif custom_attribute_class_name == "TensorFlowResponseAttributes":
+ response.customAttributes.tf = custom_attrs # type: ignore
+ else:
+ raise ValueError("""Invalid custom attribute class name.
+ Expected 'TensorFlowResponseAttributes' or
+ 'TorchResponseAttributes'.""")
+ except Exception as e:
+ raise ValueError("Error assigning custom attributes to response.") from e
+
+ @staticmethod
+ def build_response(
+ status: "response_capnp.Status",
+ message: str,
+ result: t.Union[
+ t.List[tensor_capnp.TensorDescriptor],
+ t.List[data_references_capnp.TensorKey],
+ None,
+ ],
+ custom_attributes: t.Union[
+ response_attributes_capnp.TorchResponseAttributes,
+ response_attributes_capnp.TensorFlowResponseAttributes,
+ None,
+ ],
+ ) -> response_capnp.ResponseBuilder:
+ """
+ Builds the response message.
+
+ :param status: Status to be assigned to response
+ :param message: Message to be assigned to response
+ :param result: Result to be assigned to response
+ :param custom_attributes: Custom attributes to be assigned to response
+ :returns: The Response
+ """
+ response = response_capnp.Response.new_message()
+ MessageHandler._assign_status(response, status)
+ MessageHandler._assign_message(response, message)
+ MessageHandler._assign_result(response, result)
+ MessageHandler._assign_custom_response_attributes(response, custom_attributes)
+ return response
+
+ @staticmethod
+ def serialize_response(response: response_capnp.ResponseBuilder) -> bytes:
+ """
+ Serializes a built response message.
+
+ :param response: Response to be serialized
+ :returns: Serialized response bytes
+ :raises ValueError: If serialization fails
+ """
+ display_name = response.schema.node.displayName # type: ignore
+ class_name = display_name.split(":")[-1]
+ if class_name != "Response":
+ raise ValueError(
+ "Error serializing the response. Value passed in is not "
+ f"a response: {class_name}"
+ )
+ try:
+ return response.to_bytes()
+ except Exception as e:
+ raise ValueError("Error serializing the response") from e
+
+ @staticmethod
+ def deserialize_response(response_bytes: bytes) -> response_capnp.Response:
+ """
+ Deserializes a serialized response message.
+
+ :param response_bytes: Bytes to be deserialized into a response
+ :returns: Deserialized response
+ :raises ValueError: If deserialization fails
+ """
+ try:
+ bytes_message = response_capnp.Response.from_bytes(
+ response_bytes, traversal_limit_in_words=2**63
+ )
+
+ with bytes_message as message:
+ return message
+
+ except Exception as e:
+ raise ValueError("Error deserializing the response") from e
diff --git a/smartsim/_core/mli/mli_schemas/data/data_references.capnp b/smartsim/_core/mli/mli_schemas/data/data_references.capnp
new file mode 100644
index 0000000000..65293be7b2
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/data/data_references.capnp
@@ -0,0 +1,37 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+@0x8ca69fd1aacb6668;
+
+struct ModelKey {
+ key @0 :Text;
+ descriptor @1 :Text;
+}
+
+struct TensorKey {
+ key @0 :Text;
+ descriptor @1 :Text;
+}
diff --git a/smartsim/_core/mli/mli_schemas/data/data_references_capnp.py b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.py
new file mode 100644
index 0000000000..099d10c438
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.py
@@ -0,0 +1,41 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `data_references.capnp`."""
+
+import os
+
+import capnp # type: ignore
+
+capnp.remove_import_hook()
+here = os.path.dirname(os.path.abspath(__file__))
+module_file = os.path.abspath(os.path.join(here, "data_references.capnp"))
+ModelKey = capnp.load(module_file).ModelKey
+ModelKeyBuilder = ModelKey
+ModelKeyReader = ModelKey
+TensorKey = capnp.load(module_file).TensorKey
+TensorKeyBuilder = TensorKey
+TensorKeyReader = TensorKey
diff --git a/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi
new file mode 100644
index 0000000000..a5e318a556
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi
@@ -0,0 +1,107 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `data_references.capnp`."""
+
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from io import BufferedWriter
+from typing import Iterator
+
+class ModelKey:
+ key: str
+ descriptor: str
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[ModelKeyReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> ModelKeyReader: ...
+ @staticmethod
+ def new_message() -> ModelKeyBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class ModelKeyReader(ModelKey):
+ def as_builder(self) -> ModelKeyBuilder: ...
+
+class ModelKeyBuilder(ModelKey):
+ @staticmethod
+ def from_dict(dictionary: dict) -> ModelKeyBuilder: ...
+ def copy(self) -> ModelKeyBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> ModelKeyReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
+
+class TensorKey:
+ key: str
+ descriptor: str
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[TensorKeyReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> TensorKeyReader: ...
+ @staticmethod
+ def new_message() -> TensorKeyBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class TensorKeyReader(TensorKey):
+ def as_builder(self) -> TensorKeyBuilder: ...
+
+class TensorKeyBuilder(TensorKey):
+ @staticmethod
+ def from_dict(dictionary: dict) -> TensorKeyBuilder: ...
+ def copy(self) -> TensorKeyBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> TensorKeyReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
diff --git a/smartsim/_core/mli/mli_schemas/model/__init__.py b/smartsim/_core/mli/mli_schemas/model/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/mli_schemas/model/model.capnp b/smartsim/_core/mli/mli_schemas/model/model.capnp
new file mode 100644
index 0000000000..fc9ed73663
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/model/model.capnp
@@ -0,0 +1,33 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+@0xaefb9301e14ba4bd;
+
+struct Model {
+ data @0 :Data;
+ name @1 :Text;
+ version @2 :Text;
+}
diff --git a/smartsim/_core/mli/mli_schemas/model/model_capnp.py b/smartsim/_core/mli/mli_schemas/model/model_capnp.py
new file mode 100644
index 0000000000..be2c276c23
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/model/model_capnp.py
@@ -0,0 +1,38 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `model.capnp`."""
+
+import os
+
+import capnp # type: ignore
+
+capnp.remove_import_hook()
+here = os.path.dirname(os.path.abspath(__file__))
+module_file = os.path.abspath(os.path.join(here, "model.capnp"))
+Model = capnp.load(module_file).Model
+ModelBuilder = Model
+ModelReader = Model
diff --git a/smartsim/_core/mli/mli_schemas/model/model_capnp.pyi b/smartsim/_core/mli/mli_schemas/model/model_capnp.pyi
new file mode 100644
index 0000000000..6ca53a3579
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/model/model_capnp.pyi
@@ -0,0 +1,72 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `model.capnp`."""
+
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from io import BufferedWriter
+from typing import Iterator
+
+class Model:
+ data: bytes
+ name: str
+ version: str
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[ModelReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> ModelReader: ...
+ @staticmethod
+ def new_message() -> ModelBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class ModelReader(Model):
+ def as_builder(self) -> ModelBuilder: ...
+
+class ModelBuilder(Model):
+ @staticmethod
+ def from_dict(dictionary: dict) -> ModelBuilder: ...
+ def copy(self) -> ModelBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> ModelReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
diff --git a/smartsim/_core/mli/mli_schemas/request/request.capnp b/smartsim/_core/mli/mli_schemas/request/request.capnp
new file mode 100644
index 0000000000..26d9542d9f
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/request/request.capnp
@@ -0,0 +1,55 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+@0xa27f0152c7bb299e;
+
+using Tensors = import "../tensor/tensor.capnp";
+using RequestAttributes = import "request_attributes/request_attributes.capnp";
+using DataRef = import "../data/data_references.capnp";
+using Models = import "../model/model.capnp";
+
+struct ChannelDescriptor {
+ descriptor @0 :Text;
+}
+
+struct Request {
+ replyChannel @0 :ChannelDescriptor;
+ model :union {
+ key @1 :DataRef.ModelKey;
+ data @2 :Models.Model;
+ }
+ input :union {
+ keys @3 :List(DataRef.TensorKey);
+ descriptors @4 :List(Tensors.TensorDescriptor);
+ }
+ output @5 :List(DataRef.TensorKey);
+ outputDescriptors @6 :List(Tensors.OutputDescriptor);
+ customAttributes :union {
+ torch @7 :RequestAttributes.TorchRequestAttributes;
+ tf @8 :RequestAttributes.TensorFlowRequestAttributes;
+ none @9 :Void;
+ }
+}
diff --git a/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp
new file mode 100644
index 0000000000..f0a319f0a3
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes.capnp
@@ -0,0 +1,49 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+@0xdd14d8ba5c06743f;
+
+enum TorchTensorType {
+ nested @0; # ragged
+ sparse @1;
+ tensor @2; # "normal" tensor
+}
+
+enum TFTensorType {
+ ragged @0;
+ sparse @1;
+ variable @2;
+ constant @3;
+}
+
+struct TorchRequestAttributes {
+ tensorType @0 :TorchTensorType;
+}
+
+struct TensorFlowRequestAttributes {
+ name @0 :Text;
+ tensorType @1 :TFTensorType;
+}
diff --git a/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py
new file mode 100644
index 0000000000..8969f38457
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.py
@@ -0,0 +1,41 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `request_attributes.capnp`."""
+
+import os
+
+import capnp # type: ignore
+
+capnp.remove_import_hook()
+here = os.path.dirname(os.path.abspath(__file__))
+module_file = os.path.abspath(os.path.join(here, "request_attributes.capnp"))
+TorchRequestAttributes = capnp.load(module_file).TorchRequestAttributes
+TorchRequestAttributesBuilder = TorchRequestAttributes
+TorchRequestAttributesReader = TorchRequestAttributes
+TensorFlowRequestAttributes = capnp.load(module_file).TensorFlowRequestAttributes
+TensorFlowRequestAttributesBuilder = TensorFlowRequestAttributes
+TensorFlowRequestAttributesReader = TensorFlowRequestAttributes
diff --git a/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi
new file mode 100644
index 0000000000..c474de4b4f
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/request/request_attributes/request_attributes_capnp.pyi
@@ -0,0 +1,109 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `request_attributes.capnp`."""
+
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from io import BufferedWriter
+from typing import Iterator, Literal
+
+TorchTensorType = Literal["nested", "sparse", "tensor"]
+TFTensorType = Literal["ragged", "sparse", "variable", "constant"]
+
+class TorchRequestAttributes:
+ tensorType: TorchTensorType
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[TorchRequestAttributesReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> TorchRequestAttributesReader: ...
+ @staticmethod
+ def new_message() -> TorchRequestAttributesBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class TorchRequestAttributesReader(TorchRequestAttributes):
+ def as_builder(self) -> TorchRequestAttributesBuilder: ...
+
+class TorchRequestAttributesBuilder(TorchRequestAttributes):
+ @staticmethod
+ def from_dict(dictionary: dict) -> TorchRequestAttributesBuilder: ...
+ def copy(self) -> TorchRequestAttributesBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> TorchRequestAttributesReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
+
+class TensorFlowRequestAttributes:
+ name: str
+ tensorType: TFTensorType
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[TensorFlowRequestAttributesReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> TensorFlowRequestAttributesReader: ...
+ @staticmethod
+ def new_message() -> TensorFlowRequestAttributesBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class TensorFlowRequestAttributesReader(TensorFlowRequestAttributes):
+ def as_builder(self) -> TensorFlowRequestAttributesBuilder: ...
+
+class TensorFlowRequestAttributesBuilder(TensorFlowRequestAttributes):
+ @staticmethod
+ def from_dict(dictionary: dict) -> TensorFlowRequestAttributesBuilder: ...
+ def copy(self) -> TensorFlowRequestAttributesBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> TensorFlowRequestAttributesReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
diff --git a/smartsim/_core/mli/mli_schemas/request/request_capnp.py b/smartsim/_core/mli/mli_schemas/request/request_capnp.py
new file mode 100644
index 0000000000..90b8ce194e
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/request/request_capnp.py
@@ -0,0 +1,41 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `request.capnp`."""
+
+import os
+
+import capnp # type: ignore
+
+capnp.remove_import_hook()
+here = os.path.dirname(os.path.abspath(__file__))
+module_file = os.path.abspath(os.path.join(here, "request.capnp"))
+ChannelDescriptor = capnp.load(module_file).ChannelDescriptor
+ChannelDescriptorBuilder = ChannelDescriptor
+ChannelDescriptorReader = ChannelDescriptor
+Request = capnp.load(module_file).Request
+RequestBuilder = Request
+RequestReader = Request
diff --git a/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi b/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi
new file mode 100644
index 0000000000..2aab80b1d0
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi
@@ -0,0 +1,319 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `request.capnp`."""
+
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from io import BufferedWriter
+from typing import Iterator, Literal, Sequence, overload
+
+from ..data.data_references_capnp import (
+ ModelKey,
+ ModelKeyBuilder,
+ ModelKeyReader,
+ TensorKey,
+ TensorKeyBuilder,
+ TensorKeyReader,
+)
+from ..model.model_capnp import Model, ModelBuilder, ModelReader
+from ..tensor.tensor_capnp import (
+ OutputDescriptor,
+ OutputDescriptorBuilder,
+ OutputDescriptorReader,
+ TensorDescriptor,
+ TensorDescriptorBuilder,
+ TensorDescriptorReader,
+)
+from .request_attributes.request_attributes_capnp import (
+ TensorFlowRequestAttributes,
+ TensorFlowRequestAttributesBuilder,
+ TensorFlowRequestAttributesReader,
+ TorchRequestAttributes,
+ TorchRequestAttributesBuilder,
+ TorchRequestAttributesReader,
+)
+
+class ChannelDescriptor:
+ descriptor: str
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[ChannelDescriptorReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> ChannelDescriptorReader: ...
+ @staticmethod
+ def new_message() -> ChannelDescriptorBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class ChannelDescriptorReader(ChannelDescriptor):
+ def as_builder(self) -> ChannelDescriptorBuilder: ...
+
+class ChannelDescriptorBuilder(ChannelDescriptor):
+ @staticmethod
+ def from_dict(dictionary: dict) -> ChannelDescriptorBuilder: ...
+ def copy(self) -> ChannelDescriptorBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> ChannelDescriptorReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
+
+class Request:
+ class Model:
+ key: ModelKey | ModelKeyBuilder | ModelKeyReader
+ data: Model | ModelBuilder | ModelReader
+ def which(self) -> Literal["key", "data"]: ...
+ @overload
+ def init(self, name: Literal["key"]) -> ModelKey: ...
+ @overload
+ def init(self, name: Literal["data"]) -> Model: ...
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[Request.ModelReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Request.ModelReader: ...
+ @staticmethod
+ def new_message() -> Request.ModelBuilder: ...
+ def to_dict(self) -> dict: ...
+
+ class ModelReader(Request.Model):
+ key: ModelKeyReader
+ data: ModelReader
+ def as_builder(self) -> Request.ModelBuilder: ...
+
+ class ModelBuilder(Request.Model):
+ key: ModelKey | ModelKeyBuilder | ModelKeyReader
+ data: Model | ModelBuilder | ModelReader
+ @staticmethod
+ def from_dict(dictionary: dict) -> Request.ModelBuilder: ...
+ def copy(self) -> Request.ModelBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> Request.ModelReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
+
+ class Input:
+ keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader]
+ descriptors: Sequence[
+ TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader
+ ]
+ def which(self) -> Literal["keys", "descriptors"]: ...
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[Request.InputReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Request.InputReader: ...
+ @staticmethod
+ def new_message() -> Request.InputBuilder: ...
+ def to_dict(self) -> dict: ...
+
+ class InputReader(Request.Input):
+ keys: Sequence[TensorKeyReader]
+ descriptors: Sequence[TensorDescriptorReader]
+ def as_builder(self) -> Request.InputBuilder: ...
+
+ class InputBuilder(Request.Input):
+ keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader]
+ descriptors: Sequence[
+ TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader
+ ]
+ @staticmethod
+ def from_dict(dictionary: dict) -> Request.InputBuilder: ...
+ def copy(self) -> Request.InputBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> Request.InputReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
+
+ class CustomAttributes:
+ torch: (
+ TorchRequestAttributes
+ | TorchRequestAttributesBuilder
+ | TorchRequestAttributesReader
+ )
+ tf: (
+ TensorFlowRequestAttributes
+ | TensorFlowRequestAttributesBuilder
+ | TensorFlowRequestAttributesReader
+ )
+ none: None
+ def which(self) -> Literal["torch", "tf", "none"]: ...
+ @overload
+ def init(self, name: Literal["torch"]) -> TorchRequestAttributes: ...
+ @overload
+ def init(self, name: Literal["tf"]) -> TensorFlowRequestAttributes: ...
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[Request.CustomAttributesReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Request.CustomAttributesReader: ...
+ @staticmethod
+ def new_message() -> Request.CustomAttributesBuilder: ...
+ def to_dict(self) -> dict: ...
+
+ class CustomAttributesReader(Request.CustomAttributes):
+ torch: TorchRequestAttributesReader
+ tf: TensorFlowRequestAttributesReader
+ def as_builder(self) -> Request.CustomAttributesBuilder: ...
+
+ class CustomAttributesBuilder(Request.CustomAttributes):
+ torch: (
+ TorchRequestAttributes
+ | TorchRequestAttributesBuilder
+ | TorchRequestAttributesReader
+ )
+ tf: (
+ TensorFlowRequestAttributes
+ | TensorFlowRequestAttributesBuilder
+ | TensorFlowRequestAttributesReader
+ )
+ @staticmethod
+ def from_dict(dictionary: dict) -> Request.CustomAttributesBuilder: ...
+ def copy(self) -> Request.CustomAttributesBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> Request.CustomAttributesReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
+ replyChannel: ChannelDescriptor | ChannelDescriptorBuilder | ChannelDescriptorReader
+ model: Request.Model | Request.ModelBuilder | Request.ModelReader
+ input: Request.Input | Request.InputBuilder | Request.InputReader
+ output: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader]
+ outputDescriptors: Sequence[
+ OutputDescriptor | OutputDescriptorBuilder | OutputDescriptorReader
+ ]
+ customAttributes: (
+ Request.CustomAttributes
+ | Request.CustomAttributesBuilder
+ | Request.CustomAttributesReader
+ )
+ @overload
+ def init(self, name: Literal["replyChannel"]) -> ChannelDescriptor: ...
+ @overload
+ def init(self, name: Literal["model"]) -> Model: ...
+ @overload
+ def init(self, name: Literal["input"]) -> Input: ...
+ @overload
+ def init(self, name: Literal["customAttributes"]) -> CustomAttributes: ...
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[RequestReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> RequestReader: ...
+ @staticmethod
+ def new_message() -> RequestBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class RequestReader(Request):
+ replyChannel: ChannelDescriptorReader
+ model: Request.ModelReader
+ input: Request.InputReader
+ output: Sequence[TensorKeyReader]
+ outputDescriptors: Sequence[OutputDescriptorReader]
+ customAttributes: Request.CustomAttributesReader
+ def as_builder(self) -> RequestBuilder: ...
+
+class RequestBuilder(Request):
+ replyChannel: ChannelDescriptor | ChannelDescriptorBuilder | ChannelDescriptorReader
+ model: Request.Model | Request.ModelBuilder | Request.ModelReader
+ input: Request.Input | Request.InputBuilder | Request.InputReader
+ output: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader]
+ outputDescriptors: Sequence[
+ OutputDescriptor | OutputDescriptorBuilder | OutputDescriptorReader
+ ]
+ customAttributes: (
+ Request.CustomAttributes
+ | Request.CustomAttributesBuilder
+ | Request.CustomAttributesReader
+ )
+ @staticmethod
+ def from_dict(dictionary: dict) -> RequestBuilder: ...
+ def copy(self) -> RequestBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> RequestReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
diff --git a/smartsim/_core/mli/mli_schemas/response/response.capnp b/smartsim/_core/mli/mli_schemas/response/response.capnp
new file mode 100644
index 0000000000..7194524cd0
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/response/response.capnp
@@ -0,0 +1,52 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+@0xa05dcb4444780705;
+
+using Tensors = import "../tensor/tensor.capnp";
+using ResponseAttributes = import "response_attributes/response_attributes.capnp";
+using DataRef = import "../data/data_references.capnp";
+
+enum Status {
+ complete @0;
+ fail @1;
+ timeout @2;
+ running @3;
+}
+
+struct Response {
+ status @0 :Status;
+ message @1 :Text;
+ result :union {
+ keys @2 :List(DataRef.TensorKey);
+ descriptors @3 :List(Tensors.TensorDescriptor);
+ }
+ customAttributes :union {
+ torch @4 :ResponseAttributes.TorchResponseAttributes;
+ tf @5 :ResponseAttributes.TensorFlowResponseAttributes;
+ none @6 :Void;
+ }
+}
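
As a rough illustration only (not part of this diff), the Response schema above can be
exercised through pycapnp once the generated loader added later in this changeset is
importable; field names follow the schema, but the snippet is a sketch, not project code.

    # Sketch: assumes pycapnp is installed and this changeset's response_capnp loader is importable
    from smartsim._core.mli.mli_schemas.response import response_capnp

    response = response_capnp.Response.new_message()
    response.status = "complete"             # enum members are assigned by name
    response.message = "inference finished"
    descriptors = response.result.init("descriptors", 1)  # pick the `descriptors` union branch
    descriptors[0].dimensions = [2, 3]
    descriptors[0].order = "c"
    descriptors[0].dataType = "float32"
    response.customAttributes.none = None    # explicitly select the empty branch
    payload = response.to_bytes()            # serialize for transport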
diff --git a/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp
new file mode 100644
index 0000000000..b4dcf18e88
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes.capnp
@@ -0,0 +1,33 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+@0xee59c60fccbb1bf9;
+
+struct TorchResponseAttributes {
+}
+
+struct TensorFlowResponseAttributes {
+}
diff --git a/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py
new file mode 100644
index 0000000000..4839334d52
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.py
@@ -0,0 +1,41 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `response_attributes.capnp`."""
+
+import os
+
+import capnp # type: ignore
+
+capnp.remove_import_hook()
+here = os.path.dirname(os.path.abspath(__file__))
+module_file = os.path.abspath(os.path.join(here, "response_attributes.capnp"))
+TorchResponseAttributes = capnp.load(module_file).TorchResponseAttributes
+TorchResponseAttributesBuilder = TorchResponseAttributes
+TorchResponseAttributesReader = TorchResponseAttributes
+TensorFlowResponseAttributes = capnp.load(module_file).TensorFlowResponseAttributes
+TensorFlowResponseAttributesBuilder = TensorFlowResponseAttributes
+TensorFlowResponseAttributesReader = TensorFlowResponseAttributes
diff --git a/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi
new file mode 100644
index 0000000000..f40688d74a
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/response/response_attributes/response_attributes_capnp.pyi
@@ -0,0 +1,103 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `response_attributes.capnp`."""
+
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from io import BufferedWriter
+from typing import Iterator
+
+class TorchResponseAttributes:
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[TorchResponseAttributesReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> TorchResponseAttributesReader: ...
+ @staticmethod
+ def new_message() -> TorchResponseAttributesBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class TorchResponseAttributesReader(TorchResponseAttributes):
+ def as_builder(self) -> TorchResponseAttributesBuilder: ...
+
+class TorchResponseAttributesBuilder(TorchResponseAttributes):
+ @staticmethod
+ def from_dict(dictionary: dict) -> TorchResponseAttributesBuilder: ...
+ def copy(self) -> TorchResponseAttributesBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> TorchResponseAttributesReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
+
+class TensorFlowResponseAttributes:
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[TensorFlowResponseAttributesReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> TensorFlowResponseAttributesReader: ...
+ @staticmethod
+ def new_message() -> TensorFlowResponseAttributesBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class TensorFlowResponseAttributesReader(TensorFlowResponseAttributes):
+ def as_builder(self) -> TensorFlowResponseAttributesBuilder: ...
+
+class TensorFlowResponseAttributesBuilder(TensorFlowResponseAttributes):
+ @staticmethod
+ def from_dict(dictionary: dict) -> TensorFlowResponseAttributesBuilder: ...
+ def copy(self) -> TensorFlowResponseAttributesBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> TensorFlowResponseAttributesReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
diff --git a/smartsim/_core/mli/mli_schemas/response/response_capnp.py b/smartsim/_core/mli/mli_schemas/response/response_capnp.py
new file mode 100644
index 0000000000..eaa3451045
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/response/response_capnp.py
@@ -0,0 +1,38 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `response.capnp`."""
+
+import os
+
+import capnp # type: ignore
+
+capnp.remove_import_hook()
+here = os.path.dirname(os.path.abspath(__file__))
+module_file = os.path.abspath(os.path.join(here, "response.capnp"))
+Response = capnp.load(module_file).Response
+ResponseBuilder = Response
+ResponseReader = Response
diff --git a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi
new file mode 100644
index 0000000000..6b4c50fd05
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi
@@ -0,0 +1,212 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `response.capnp`."""
+
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from io import BufferedWriter
+from typing import Iterator, Literal, Sequence, overload
+
+from ..data.data_references_capnp import TensorKey, TensorKeyBuilder, TensorKeyReader
+from ..tensor.tensor_capnp import (
+ TensorDescriptor,
+ TensorDescriptorBuilder,
+ TensorDescriptorReader,
+)
+from .response_attributes.response_attributes_capnp import (
+ TensorFlowResponseAttributes,
+ TensorFlowResponseAttributesBuilder,
+ TensorFlowResponseAttributesReader,
+ TorchResponseAttributes,
+ TorchResponseAttributesBuilder,
+ TorchResponseAttributesReader,
+)
+
+Status = Literal["complete", "fail", "timeout", "running"]
+
+class Response:
+ class Result:
+ keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader]
+ descriptors: Sequence[
+ TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader
+ ]
+ def which(self) -> Literal["keys", "descriptors"]: ...
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[Response.ResultReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Response.ResultReader: ...
+ @staticmethod
+ def new_message() -> Response.ResultBuilder: ...
+ def to_dict(self) -> dict: ...
+
+ class ResultReader(Response.Result):
+ keys: Sequence[TensorKeyReader]
+ descriptors: Sequence[TensorDescriptorReader]
+ def as_builder(self) -> Response.ResultBuilder: ...
+
+ class ResultBuilder(Response.Result):
+ keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader]
+ descriptors: Sequence[
+ TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader
+ ]
+ @staticmethod
+ def from_dict(dictionary: dict) -> Response.ResultBuilder: ...
+ def copy(self) -> Response.ResultBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> Response.ResultReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
+
+ class CustomAttributes:
+ torch: (
+ TorchResponseAttributes
+ | TorchResponseAttributesBuilder
+ | TorchResponseAttributesReader
+ )
+ tf: (
+ TensorFlowResponseAttributes
+ | TensorFlowResponseAttributesBuilder
+ | TensorFlowResponseAttributesReader
+ )
+ none: None
+ def which(self) -> Literal["torch", "tf", "none"]: ...
+ @overload
+ def init(self, name: Literal["torch"]) -> TorchResponseAttributes: ...
+ @overload
+ def init(self, name: Literal["tf"]) -> TensorFlowResponseAttributes: ...
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[Response.CustomAttributesReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Response.CustomAttributesReader: ...
+ @staticmethod
+ def new_message() -> Response.CustomAttributesBuilder: ...
+ def to_dict(self) -> dict: ...
+
+ class CustomAttributesReader(Response.CustomAttributes):
+ torch: TorchResponseAttributesReader
+ tf: TensorFlowResponseAttributesReader
+ def as_builder(self) -> Response.CustomAttributesBuilder: ...
+
+ class CustomAttributesBuilder(Response.CustomAttributes):
+ torch: (
+ TorchResponseAttributes
+ | TorchResponseAttributesBuilder
+ | TorchResponseAttributesReader
+ )
+ tf: (
+ TensorFlowResponseAttributes
+ | TensorFlowResponseAttributesBuilder
+ | TensorFlowResponseAttributesReader
+ )
+ @staticmethod
+ def from_dict(dictionary: dict) -> Response.CustomAttributesBuilder: ...
+ def copy(self) -> Response.CustomAttributesBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> Response.CustomAttributesReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
+ status: Status
+ message: str
+ result: Response.Result | Response.ResultBuilder | Response.ResultReader
+ customAttributes: (
+ Response.CustomAttributes
+ | Response.CustomAttributesBuilder
+ | Response.CustomAttributesReader
+ )
+ @overload
+ def init(self, name: Literal["result"]) -> Result: ...
+ @overload
+ def init(self, name: Literal["customAttributes"]) -> CustomAttributes: ...
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[ResponseReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> ResponseReader: ...
+ @staticmethod
+ def new_message() -> ResponseBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class ResponseReader(Response):
+ result: Response.ResultReader
+ customAttributes: Response.CustomAttributesReader
+ def as_builder(self) -> ResponseBuilder: ...
+
+class ResponseBuilder(Response):
+ result: Response.Result | Response.ResultBuilder | Response.ResultReader
+ customAttributes: (
+ Response.CustomAttributes
+ | Response.CustomAttributesBuilder
+ | Response.CustomAttributesReader
+ )
+ @staticmethod
+ def from_dict(dictionary: dict) -> ResponseBuilder: ...
+ def copy(self) -> ResponseBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> ResponseReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp b/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp
new file mode 100644
index 0000000000..4b2218b166
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp
@@ -0,0 +1,75 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+@0x9a0aeb2e04838fb1;
+
+using DataRef = import "../data/data_references.capnp";
+
+enum Order {
+  c @0; # row major (C-contiguous layout)
+  f @1; # column major (Fortran-contiguous layout)
+}
+
+enum NumericalType {
+ int8 @0;
+ int16 @1;
+ int32 @2;
+ int64 @3;
+ uInt8 @4;
+ uInt16 @5;
+ uInt32 @6;
+ uInt64 @7;
+ float32 @8;
+ float64 @9;
+}
+
+enum ReturnNumericalType {
+ int8 @0;
+ int16 @1;
+ int32 @2;
+ int64 @3;
+ uInt8 @4;
+ uInt16 @5;
+ uInt32 @6;
+ uInt64 @7;
+ float32 @8;
+ float64 @9;
+ none @10;
+ auto @11;
+}
+
+struct TensorDescriptor {
+ dimensions @0 :List(Int32);
+ order @1 :Order;
+ dataType @2 :NumericalType;
+}
+
+struct OutputDescriptor {
+ order @0 :Order;
+ optionalKeys @1 :List(DataRef.TensorKey);
+ optionalDimension @2 :List(Int32);
+ optionalDatatype @3 :ReturnNumericalType;
+}
diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py
new file mode 100644
index 0000000000..8c9d6c9029
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py
@@ -0,0 +1,41 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `tensor.capnp`."""
+
+import os
+
+import capnp # type: ignore
+
+capnp.remove_import_hook()
+here = os.path.dirname(os.path.abspath(__file__))
+module_file = os.path.abspath(os.path.join(here, "tensor.capnp"))
+TensorDescriptor = capnp.load(module_file).TensorDescriptor
+TensorDescriptorBuilder = TensorDescriptor
+TensorDescriptorReader = TensorDescriptor
+OutputDescriptor = capnp.load(module_file).OutputDescriptor
+OutputDescriptorBuilder = OutputDescriptor
+OutputDescriptorReader = OutputDescriptor
diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi
new file mode 100644
index 0000000000..b55f26b452
--- /dev/null
+++ b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi
@@ -0,0 +1,142 @@
+# BSD 2-Clause License
+
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""This is an automatically generated stub for `tensor.capnp`."""
+
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from io import BufferedWriter
+from typing import Iterator, Literal, Sequence
+
+from ..data.data_references_capnp import TensorKey, TensorKeyBuilder, TensorKeyReader
+
+Order = Literal["c", "f"]
+NumericalType = Literal[
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uInt8",
+ "uInt16",
+ "uInt32",
+ "uInt64",
+ "float32",
+ "float64",
+]
+ReturnNumericalType = Literal[
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uInt8",
+ "uInt16",
+ "uInt32",
+ "uInt64",
+ "float32",
+ "float64",
+ "none",
+ "auto",
+]
+
+class TensorDescriptor:
+ dimensions: Sequence[int]
+ order: Order
+ dataType: NumericalType
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[TensorDescriptorReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> TensorDescriptorReader: ...
+ @staticmethod
+ def new_message() -> TensorDescriptorBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class TensorDescriptorReader(TensorDescriptor):
+ def as_builder(self) -> TensorDescriptorBuilder: ...
+
+class TensorDescriptorBuilder(TensorDescriptor):
+ @staticmethod
+ def from_dict(dictionary: dict) -> TensorDescriptorBuilder: ...
+ def copy(self) -> TensorDescriptorBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> TensorDescriptorReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
+
+class OutputDescriptor:
+ order: Order
+ optionalKeys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader]
+ optionalDimension: Sequence[int]
+ optionalDatatype: ReturnNumericalType
+ @staticmethod
+ @contextmanager
+ def from_bytes(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> Iterator[OutputDescriptorReader]: ...
+ @staticmethod
+ def from_bytes_packed(
+ data: bytes,
+ traversal_limit_in_words: int | None = ...,
+ nesting_limit: int | None = ...,
+ ) -> OutputDescriptorReader: ...
+ @staticmethod
+ def new_message() -> OutputDescriptorBuilder: ...
+ def to_dict(self) -> dict: ...
+
+class OutputDescriptorReader(OutputDescriptor):
+ optionalKeys: Sequence[TensorKeyReader]
+ def as_builder(self) -> OutputDescriptorBuilder: ...
+
+class OutputDescriptorBuilder(OutputDescriptor):
+ optionalKeys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader]
+ @staticmethod
+ def from_dict(dictionary: dict) -> OutputDescriptorBuilder: ...
+ def copy(self) -> OutputDescriptorBuilder: ...
+ def to_bytes(self) -> bytes: ...
+ def to_bytes_packed(self) -> bytes: ...
+ def to_segments(self) -> list[bytes]: ...
+ def as_reader(self) -> OutputDescriptorReader: ...
+ @staticmethod
+ def write(file: BufferedWriter) -> None: ...
+ @staticmethod
+ def write_packed(file: BufferedWriter) -> None: ...
diff --git a/smartsim/_core/schemas/utils.py b/smartsim/_core/schemas/utils.py
index 9cb36bcf57..905fe8955c 100644
--- a/smartsim/_core/schemas/utils.py
+++ b/smartsim/_core/schemas/utils.py
@@ -48,7 +48,7 @@ class _Message(t.Generic[_SchemaT]):
delimiter: str = pydantic.Field(min_length=1, default=_DEFAULT_MSG_DELIM)
def __str__(self) -> str:
- return self.delimiter.join((self.header, self.payload.json()))
+ return self.delimiter.join((self.header, self.payload.model_dump_json()))
@classmethod
def from_str(
@@ -58,7 +58,7 @@ def from_str(
delimiter: str = _DEFAULT_MSG_DELIM,
) -> "_Message[_SchemaT]":
header, payload = str_.split(delimiter, 1)
- return cls(payload_type.parse_raw(payload), header, delimiter)
+ return cls(payload_type.model_validate_json(payload), header, delimiter)
class SchemaRegistry(t.Generic[_SchemaT]):
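
The two replacements above track the pydantic v2 API, where `.json()` becomes
`.model_dump_json()` and `.parse_raw()` becomes `.model_validate_json()`. A minimal
standalone illustration (the `Ping` model is hypothetical, not from this diff):

    import pydantic

    class Ping(pydantic.BaseModel):
        seq: int

    raw = Ping(seq=1).model_dump_json()       # '{"seq":1}'
    restored = Ping.model_validate_json(raw)  # Ping(seq=1)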
diff --git a/smartsim/_core/shell/shell_launcher.py b/smartsim/_core/shell/shell_launcher.py
index 9c05f38f6a..530ac8a641 100644
--- a/smartsim/_core/shell/shell_launcher.py
+++ b/smartsim/_core/shell/shell_launcher.py
@@ -50,6 +50,8 @@
logger = get_logger(__name__)
+# pylint: disable=unspecified-encoding
+
class ShellLauncherCommand(t.NamedTuple):
env: EnvironMappingType
@@ -110,14 +112,13 @@ def impl(
else exe
)
# pylint: disable-next=consider-using-with
- return ShellLauncherCommand( # pylint: disable-next=unspecified-encoding
+ return ShellLauncherCommand(
env, pathlib.Path(path), open(stdout_path), open(stderr_path), command_tuple
)
return impl
-# pylint: disable=no-self-use
class ShellLauncher:
"""A launcher for launching/tracking local shell commands"""
diff --git a/smartsim/_core/types.py b/smartsim/_core/types.py
new file mode 100644
index 0000000000..d3dc029eaa
--- /dev/null
+++ b/smartsim/_core/types.py
@@ -0,0 +1,32 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import enum
+
+
+class Device(enum.Enum):
+ CPU = "cpu"
+ GPU = "gpu"
diff --git a/smartsim/_core/utils/__init__.py b/smartsim/_core/utils/__init__.py
index 30256034cb..4159c90424 100644
--- a/smartsim/_core/utils/__init__.py
+++ b/smartsim/_core/utils/__init__.py
@@ -29,5 +29,6 @@
colorize,
delete_elements,
execute_platform_cmd,
+ expand_exe_path,
is_crayex_platform,
)
diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py
index 04c17d04c8..e498c26209 100644
--- a/smartsim/_core/utils/helpers.py
+++ b/smartsim/_core/utils/helpers.py
@@ -32,11 +32,14 @@
import base64
import collections.abc
import functools
+import itertools
import os
import signal
import subprocess
+import sys
import typing as t
import uuid
+import warnings
from datetime import datetime
from shutil import which
@@ -52,6 +55,7 @@
_Ts = TypeVarTuple("_Ts")
+_TRedisAIBackendStr = t.Literal["tensorflow", "torch", "onnxruntime"]
_T = t.TypeVar("_T")
_HashableT = t.TypeVar("_HashableT", bound=t.Hashable)
_TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object]
@@ -66,7 +70,6 @@ def unpack(value: _NestedJobSequenceType) -> t.Generator[Job, None, None]:
:param value: Sequence containing elements of type Job or other
sequences that are also of type _NestedJobSequenceType
:return: flattened list of Jobs"""
-
from smartsim.launchable.job import Job # pylint: disable=import-outside-toplevel
for item in value:
@@ -602,3 +605,47 @@ def push_unique(self, fn: _TSignalHandlerFn) -> bool:
if did_push := fn not in self:
self.push(fn)
return did_push
+
+
+def _create_pinning_string(
+ pin_ids: t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]], cpus: int
+) -> t.Optional[str]:
+ """Create a comma-separated string of CPU ids. By default, ``None``
+ returns 0,1,...,cpus-1; an empty iterable will disable pinning
+ altogether, and an iterable constructs a comma separated string of
+ integers (e.g. ``[0, 2, 5]`` -> ``"0,2,5"``)
+
+    :param pin_ids: CPU ids
+    :param cpus: number of CPUs
+ :raises TypeError: if pin id is not an iterable of ints
+ :returns: a comma separated string of CPU ids
+ """
+
+ try:
+ pin_ids = tuple(pin_ids) if pin_ids is not None else None
+ except TypeError:
+ raise TypeError(
+ "Expected a cpu pinning specification of type iterable of ints or "
+ f"iterables of ints. Instead got type `{type(pin_ids)}`"
+ ) from None
+
+ # Deal with MacOSX limitations first. The "None" (default) disables pinning
+ # and is equivalent to []. The only invalid option is a non-empty pinning
+ if sys.platform == "darwin":
+ if pin_ids:
+ warnings.warn(
+ "CPU pinning is not supported on MacOSX. Ignoring pinning "
+ "specification.",
+ RuntimeWarning,
+ )
+ return None
+
+ # Flatten the iterable into a list and check to make sure that the resulting
+ # elements are all ints
+ if pin_ids is None:
+ return ",".join(_stringify_id(i) for i in range(cpus))
+ if not pin_ids:
+ return None
+ pin_ids = ((x,) if isinstance(x, int) else x for x in pin_ids)
+ to_fmt = itertools.chain.from_iterable(pin_ids)
+ return ",".join(sorted({_stringify_id(x) for x in to_fmt}))
diff --git a/smartsim/_core/utils/serialize.py b/smartsim/_core/utils/serialize.py
index 4725c95654..c2f70a25ab 100644
--- a/smartsim/_core/utils/serialize.py
+++ b/smartsim/_core/utils/serialize.py
@@ -33,6 +33,8 @@
import smartsim._core._cli.utils as _utils
import smartsim.log
+from smartsim.settings.batch_settings import BatchSettings
+from smartsim.settings.launch_settings import LaunchSettings
if t.TYPE_CHECKING:
from smartsim._core.control.manifest import LaunchedManifest as _Manifest
@@ -40,8 +42,6 @@
from smartsim.database.feature_store import FeatureStore
from smartsim.entity import Application, FSNode
from smartsim.entity.dbobject import FSModel, FSScript
- from smartsim.settings.batch_settings import BatchSettings
- from smartsim.settings.launch_settings import LaunchSettings
TStepLaunchMetaData = t.Tuple[
@@ -235,7 +235,7 @@ def _dictify_fs(
fs_type = "Unknown"
return {
- "name": feature_store.name,
+ "name": feature_store.fs_identifier,
"type": fs_type,
"interface": feature_store._interfaces, # pylint: disable=protected-access
"shards": [
diff --git a/smartsim/_core/utils/timings.py b/smartsim/_core/utils/timings.py
new file mode 100644
index 0000000000..f99950739e
--- /dev/null
+++ b/smartsim/_core/utils/timings.py
@@ -0,0 +1,175 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import time
+import typing as t
+from collections import OrderedDict
+
+import numpy as np
+
+from ...log import get_logger
+
+logger = get_logger("PerfTimer")
+
+
+class PerfTimer:
+ def __init__(
+ self,
+ filename: str = "timings",
+ prefix: str = "",
+ timing_on: bool = True,
+ debug: bool = False,
+ ):
+ self._start: t.Optional[float] = None
+ self._interm: t.Optional[float] = None
+ self._timings: OrderedDict[str, list[t.Union[float, int, str]]] = OrderedDict()
+ self._timing_on = timing_on
+ self._filename = filename
+ self._prefix = prefix
+ self._debug = debug
+
+ def _add_label_to_timings(self, label: str) -> None:
+ if label not in self._timings:
+ self._timings[label] = []
+
+ @staticmethod
+ def _format_number(number: t.Union[float, int]) -> str:
+ """Formats the input value with a fixed precision appropriate for logging"""
+ return f"{number:0.4e}"
+
+ def start_timings(
+ self,
+ first_label: t.Optional[str] = None,
+ first_value: t.Optional[t.Union[float, int]] = None,
+ ) -> None:
+ """Start a recording session by recording
+
+ :param first_label: a label for an event that will be manually prepended
+ to the timing information before starting timers
+ :param first_label: a value for an event that will be manually prepended
+ to the timing information before starting timers"""
+ if self._timing_on:
+ if first_label is not None and first_value is not None:
+ mod_label = self._make_label(first_label)
+ value = self._format_number(first_value)
+ self._log(f"Started timing: {first_label}: {value}")
+ self._add_label_to_timings(mod_label)
+ self._timings[mod_label].append(value)
+ self._start = time.perf_counter()
+ self._interm = time.perf_counter()
+
+ def end_timings(self) -> None:
+ """Record a timing event and clear the last checkpoint"""
+ if self._timing_on and self._start is not None:
+ mod_label = self._make_label("total_time")
+ self._add_label_to_timings(mod_label)
+ delta = self._format_number(time.perf_counter() - self._start)
+            self._timings[mod_label].append(delta)
+ self._log(f"Finished timing: {mod_label}: {delta}")
+ self._interm = None
+
+ def _make_label(self, label: str) -> str:
+ """Return a label formatted with the current label prefix
+
+ :param label: the original label
+ :returns: the adjusted label value"""
+ return self._prefix + label
+
+ def _get_delta(self) -> float:
+ """Calculates the offset from the last intermediate checkpoint time
+
+ :returns: the number of seconds elapsed"""
+ if self._interm is None:
+ return 0
+ return time.perf_counter() - self._interm
+
+ def get_last(self, label: str) -> str:
+ """Return the last timing value collected for the given label in
+ the format `{label}: {value}`. If no timing value has been collected
+ with the label, returns `Not measured yet`"""
+ mod_label = self._make_label(label)
+ if mod_label in self._timings:
+ value = self._timings[mod_label][-1]
+ if value:
+ return f"{label}: {value}"
+
+ return "Not measured yet"
+
+ def measure_time(self, label: str) -> None:
+ """Record a new time event if timing is enabled
+
+ :param label: the label to record a timing event for"""
+ if self._timing_on and self._interm is not None:
+ mod_label = self._make_label(label)
+ self._add_label_to_timings(mod_label)
+ delta = self._format_number(self._get_delta())
+ self._timings[mod_label].append(delta)
+ self._log(f"{mod_label}: {delta}")
+ self._interm = time.perf_counter()
+
+ def _log(self, msg: str) -> None:
+ """Conditionally logs a message when the debug flag is enabled
+
+ :param msg: the message to be logged"""
+ if self._debug:
+ logger.info(msg)
+
+ @property
+ def max_length(self) -> int:
+ """Returns the number of records contained in the largest timing set"""
+ if len(self._timings) == 0:
+ return 0
+ return max(len(value) for value in self._timings.values())
+
+ def print_timings(self, to_file: bool = False) -> None:
+ """Print timing information to standard output. If `to_file`
+ is `True`, also write results to a file.
+
+        :param to_file: If `True`, also saves timing information
+            to the file `<prefix><filename>.npy`
+ """
+ print(" ".join(self._timings.keys()))
+ try:
+ value_array = np.array(list(self._timings.values()), dtype=float)
+ except Exception as e:
+ logger.exception(e)
+ return
+ value_array = np.transpose(value_array)
+ if self._debug:
+ for i in range(value_array.shape[0]):
+ print(" ".join(self._format_number(value) for value in value_array[i]))
+ if to_file:
+ np.save(self._prefix + self._filename + ".npy", value_array)
+
+ @property
+ def is_active(self) -> bool:
+ """Return `True` if timer is recording, `False` otherwise"""
+ return self._timing_on
+
+ @is_active.setter
+ def is_active(self, active: bool) -> None:
+ """Set to `True` to record timing information, `False` otherwise"""
+ self._timing_on = active
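
A minimal usage sketch for the `PerfTimer` added above (the prefix and label names are
illustrative only):

    import time

    from smartsim._core.utils.timings import PerfTimer

    timer = PerfTimer(prefix="worker_", timing_on=True, debug=True)
    timer.start_timings("batch_size", 16)  # optionally record a leading label/value pair
    time.sleep(0.01)                       # stand-in for real work
    timer.measure_time("fetch_model")      # delta since the last checkpoint
    time.sleep(0.01)
    timer.measure_time("execute")
    timer.end_timings()                    # records "total_time" and closes the session
    timer.print_timings(to_file=True)      # also writes worker_timings.npy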
diff --git a/smartsim/entity/_mock.py b/smartsim/entity/_mock.py
index 8f1043ed3c..7b9c43c5c8 100644
--- a/smartsim/entity/_mock.py
+++ b/smartsim/entity/_mock.py
@@ -34,6 +34,18 @@
import typing as t
+import pytest
+
+from smartsim._core.mli.infrastructure.control.worker_manager import build_failure_reply
+
+dragon = pytest.importorskip("dragon")
+
+if t.TYPE_CHECKING:
+ from smartsim._core.mli.mli_schemas.response.response_capnp import Status
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
class Mock:
"""Base mock class"""
@@ -44,3 +56,28 @@ def __getattr__(self, _: str) -> Mock:
def __deepcopy__(self, _: dict[t.Any, t.Any]) -> Mock:
return type(self)()
+
+
+@pytest.mark.parametrize(
+ "status, message",
+ [
+ pytest.param("timeout", "Worker timed out", id="timeout"),
+ pytest.param("fail", "Failed while executing", id="fail"),
+ ],
+)
+def test_build_failure_reply(status: "Status", message: str):
+ "Ensures failure replies can be built successfully"
+ response = build_failure_reply(status, message)
+ display_name = response.schema.node.displayName # type: ignore
+ class_name = display_name.split(":")[-1]
+ assert class_name == "Response"
+ assert response.status == status
+ assert response.message == message
+
+
+def test_build_failure_reply_fails():
+ "Ensures ValueError is raised if a Status Enum is not used"
+ with pytest.raises(ValueError) as ex:
+ build_failure_reply("not a status enum", "message")
+
+ assert "Error assigning status to response" in ex.value.args[0]
diff --git a/smartsim/entity/application.py b/smartsim/entity/application.py
index da8ec052cf..3dd37d4afe 100644
--- a/smartsim/entity/application.py
+++ b/smartsim/entity/application.py
@@ -28,14 +28,11 @@
import collections
import copy
-import itertools
-import sys
import textwrap
import typing as t
-import warnings
from .._core.generation.operations.operations import FileSysOperationSet
-from .._core.utils.helpers import _stringify_id, expand_exe_path
+from .._core.utils.helpers import expand_exe_path
from ..log import get_logger
from .entity import SmartSimEntity
@@ -219,11 +216,10 @@ def key_prefixing_enabled(self, value: bool) -> None:
self.key_prefixing_enabled = copy.deepcopy(value)
def as_executable_sequence(self) -> t.Sequence[str]:
- """Converts the executable and its arguments into a sequence
- of program arguments.
+ """Converts the executable and its arguments into a sequence of program
+ arguments.
- :return: a sequence of strings representing the executable and
- its arguments
+ :return: a sequence of strings representing the executable and its arguments
"""
return [self.exe, *self.exe_args]
@@ -251,50 +247,6 @@ def _build_exe_args(exe_args: t.Union[str, t.Sequence[str], None]) -> t.List[str
return list(exe_args)
- @staticmethod
- def _create_pinning_string(
- pin_ids: t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]], cpus: int
- ) -> t.Optional[str]:
- """Create a comma-separated string of CPU ids. By default, ``None``
- returns 0,1,...,cpus-1; an empty iterable will disable pinning
- altogether, and an iterable constructs a comma separated string of
- integers (e.g. ``[0, 2, 5]`` -> ``"0,2,5"``)
-
- :params pin_ids: CPU ids
- :params cpu: number of CPUs
- :raises TypeError: if pin id is not an iterable of ints
- :returns: a comma separated string of CPU ids
- """
-
- try:
- pin_ids = tuple(pin_ids) if pin_ids is not None else None
- except TypeError:
- raise TypeError(
- "Expected a cpu pinning specification of type iterable of ints or "
- f"iterables of ints. Instead got type `{type(pin_ids)}`"
- ) from None
-
- # Deal with MacOSX limitations first. The "None" (default) disables pinning
- # and is equivalent to []. The only invalid option is a non-empty pinning
- if sys.platform == "darwin":
- if pin_ids:
- warnings.warn(
- "CPU pinning is not supported on MacOSX. Ignoring pinning "
- "specification.",
- RuntimeWarning,
- )
- return None
-
- # Flatten the iterable into a list and check to make sure that the resulting
- # elements are all ints
- if pin_ids is None:
- return ",".join(_stringify_id(i) for i in range(cpus))
- if not pin_ids:
- return None
- pin_ids = ((x,) if isinstance(x, int) else x for x in pin_ids)
- to_fmt = itertools.chain.from_iterable(pin_ids)
- return ",".join(sorted({_stringify_id(x) for x in to_fmt}))
-
def __str__(self) -> str: # pragma: no cover
exe_args_str = "\n".join(self.exe_args)
entities_str = "\n".join(str(entity) for entity in self.incoming_entities)
diff --git a/smartsim/entity/dbobject.py b/smartsim/entity/dbobject.py
index f82aeea183..477564e83d 100644
--- a/smartsim/entity/dbobject.py
+++ b/smartsim/entity/dbobject.py
@@ -27,7 +27,8 @@
import typing as t
from pathlib import Path
-from .._core._install.builder import Device
+from smartsim._core.types import Device
+
from ..error import SSUnsupportedError
__all__ = ["FSObject", "FSModel", "FSScript"]
diff --git a/smartsim/experiment.py b/smartsim/experiment.py
index 4dc99975e9..9e6a657d90 100644
--- a/smartsim/experiment.py
+++ b/smartsim/experiment.py
@@ -157,6 +157,13 @@ def __init__(self, name: str, exp_path: str | None = None):
experiment
"""
+ def _set_dragon_server_path(self) -> None:
+ """Set path for dragon server through environment varialbes"""
+ if not "SMARTSIM_DRAGON_SERVER_PATH" in environ:
+ environ["_SMARTSIM_DRAGON_SERVER_PATH_EXP"] = osp.join(
+ self.exp_path, CONFIG.dragon_default_subdir
+ )
+
def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[LaunchedJobID, ...]:
"""Execute a collection of `Job` instances.
@@ -175,7 +182,7 @@ def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[LaunchedJobID, ...]:
jobs_ = list(_helpers.unpack(jobs))
run_id = datetime.datetime.now().replace(microsecond=0).isoformat()
- root = pathlib.Path(self.exp_path, run_id)
+ root = pathlib.Path(self.exp_path, run_id.replace(":", "."))
return self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs_)
def _dispatch(
@@ -202,18 +209,18 @@ def execute_dispatch(generator: Generator, job: Job, idx: int) -> LaunchedJobID:
args = job.launch_settings.launch_args
env = job.launch_settings.env_vars
exe = job.entity.as_executable_sequence()
- dispatch_instance = dispatcher.get_dispatch(args)
+ dispatch_item = dispatcher.get_dispatch(args)
try:
# Check to see if one of the existing launchers can be
# configured to handle the launch arguments ...
- launch_config = dispatch_instance.configure_first_compatible_launcher(
+ launch_config = dispatch_item.configure_first_compatible_launcher(
from_available_launchers=self._launch_history.iter_past_launchers(),
with_arguments=args,
)
except errors.LauncherNotFoundError:
# ... otherwise create a new launcher that _can_ handle the
# launch arguments and configure _that_ one
- launch_config = dispatch_instance.create_new_launcher_configuration(
+ launch_config = dispatch_item.create_new_launcher_configuration(
for_experiment=self, with_arguments=args
)
# Generate the job directory and return the generated job path
@@ -483,8 +490,8 @@ def _append_to_fs_identifier_list(self, fs_identifier: str) -> None:
if fs_identifier in self._fs_identifiers:
logger.warning(
f"A feature store with the identifier {fs_identifier} has already "
- "been made. An error will be raised if multiple Feature Stores "
- "are started with the same identifier"
+ "been made. An error will be raised if multiple Feature Stores are "
+ "with the same identifier"
)
# Otherwise, add
self._fs_identifiers.add(fs_identifier)
diff --git a/smartsim/launchable/mpmd_job.py b/smartsim/launchable/mpmd_job.py
index ab2aa2db6b..de9545032b 100644
--- a/smartsim/launchable/mpmd_job.py
+++ b/smartsim/launchable/mpmd_job.py
@@ -60,7 +60,7 @@ def _check_entity(mpmd_pairs: t.List[MPMDPair]) -> None:
ret: SmartSimEntity | None = None
for mpmd_pair in mpmd_pairs:
if flag == 1:
- if isinstance(ret, type(mpmd_pair.entity)):
+ if type(ret) is type(mpmd_pair.entity):
flag = 0
else:
raise SSUnsupportedError(
@@ -108,7 +108,6 @@ def get_launch_steps(self) -> LaunchCommands:
# TODO: return MPMDJobWarehouseRunner.run(self)
raise NotImplementedError
- # pylint: disable=unnecessary-lambda-assignment
def __str__(self) -> str: # pragma: no cover
"""returns A user-readable string of a MPMD Job"""
fmt = lambda mpmd_pair: textwrap.dedent(
diff --git a/smartsim/log.py b/smartsim/log.py
index 3d6c0860ee..c8fed9329f 100644
--- a/smartsim/log.py
+++ b/smartsim/log.py
@@ -252,16 +252,21 @@ def filter(self, record: logging.LogRecord) -> bool:
return record.levelno <= level_no
-def log_to_file(filename: str, log_level: str = "debug") -> None:
+def log_to_file(
+ filename: str, log_level: str = "debug", logger: t.Optional[logging.Logger] = None
+) -> None:
"""Installs a second filestream handler to the root logger,
allowing subsequent logging calls to be sent to filename.
- :param filename: the name of the desired log file.
- :param log_level: as defined in get_logger. Can be specified
+ :param filename: The name of the desired log file.
+ :param log_level: As defined in get_logger. Can be specified
to allow the file to store more or less verbose
logging information.
+ :param logger: If supplied, a logger to add the file stream logging
+ behavior to. By default, the "SmartSim" logger is used.
"""
- logger = logging.getLogger("SmartSim")
+ if logger is None:
+ logger = logging.getLogger("SmartSim")
stream = open( # pylint: disable=consider-using-with
filename, "w+", encoding="utf-8"
)
diff --git a/smartsim/ml/tf/__init__.py b/smartsim/ml/tf/__init__.py
index 46d89d7336..ee791ea985 100644
--- a/smartsim/ml/tf/__init__.py
+++ b/smartsim/ml/tf/__init__.py
@@ -31,23 +31,12 @@
logger = get_logger(__name__)
vers = Versioner()
-TF_VERSION = vers.TENSORFLOW
try:
import tensorflow as tf
except ImportError: # pragma: no cover
raise ModuleNotFoundError(
- f"TensorFlow {TF_VERSION} is not installed. "
- "Please install it to use smartsim.ml.tf"
- ) from None
-
-try:
- installed_tf = Version_(tf.__version__)
- assert installed_tf >= TF_VERSION
-except AssertionError: # pragma: no cover
- raise SmartSimError(
- f"TensorFlow >= {TF_VERSION} is required for smartsim. "
- f"tf, you have {tf.__version__}"
+ f"TensorFlow is not installed. Please install it to use smartsim.ml.tf"
) from None
diff --git a/smartsim/ml/tf/utils.py b/smartsim/ml/tf/utils.py
index dc66c3b55a..74e39d35b2 100644
--- a/smartsim/ml/tf/utils.py
+++ b/smartsim/ml/tf/utils.py
@@ -29,7 +29,7 @@
import keras
import tensorflow as tf
-from tensorflow.python.framework.convert_to_constants import (
+from tensorflow.python.framework.convert_to_constants import ( # type: ignore[import-not-found,unused-ignore]
convert_variables_to_constants_v2,
)
@@ -58,7 +58,7 @@ def freeze_model(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)
- frozen_func = convert_variables_to_constants_v2(full_model)
+ frozen_func = convert_variables_to_constants_v2(full_model) # type: ignore[no-untyped-call,unused-ignore]
frozen_func.graph.as_graph_def()
input_names = [x.name.split(":")[0] for x in frozen_func.inputs]
@@ -89,7 +89,7 @@ def serialize_model(model: keras.Model) -> t.Tuple[str, t.List[str], t.List[str]
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)
- frozen_func = convert_variables_to_constants_v2(full_model)
+ frozen_func = convert_variables_to_constants_v2(full_model) # type: ignore[no-untyped-call,unused-ignore]
frozen_func.graph.as_graph_def()
input_names = [x.name.split(":")[0] for x in frozen_func.inputs]
diff --git a/smartsim/settings/arguments/launch/dragon.py b/smartsim/settings/arguments/launch/dragon.py
index 17874455ca..e58cf69b1b 100644
--- a/smartsim/settings/arguments/launch/dragon.py
+++ b/smartsim/settings/arguments/launch/dragon.py
@@ -86,6 +86,24 @@ def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None:
raise TypeError("feature_list must be string or list of strings")
self.set("node-feature", ",".join(feature_list))
+ def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None:
+ """Specify the hostlist for this job
+
+ :param host_list: hosts to launch on
+ :raises ValueError: if an empty host list is supplied
+ """
+ if not host_list:
+ raise ValueError("empty hostlist provided")
+
+ if isinstance(host_list, str):
+ host_list = host_list.replace(" ", "").split(",")
+
+ # strip out all whitespace-only values
+ cleaned_list = [host.strip() for host in host_list if host and host.strip()]
+ if len(cleaned_list) != len(host_list):
+ raise ValueError(f"invalid names found in hostlist: {host_list}")
+ self.set("host-list", ",".join(cleaned_list))
+
def set_cpu_affinity(self, devices: t.List[int]) -> None:
"""Set the CPU affinity for this job
diff --git a/smartsim/settings/batch_settings.py b/smartsim/settings/batch_settings.py
index 7489fa8edd..61b69ca8e7 100644
--- a/smartsim/settings/batch_settings.py
+++ b/smartsim/settings/batch_settings.py
@@ -65,7 +65,8 @@ class BatchSettings(BaseSettings):
def __init__(
self,
batch_scheduler: t.Union[BatchSchedulerType, str],
- batch_args: StringArgument = None,
+ batch_args: StringArgument | None = None,
env_vars: StringArgument | None = None,
) -> None:
"""Initialize a BatchSettings instance.
@@ -82,9 +83,9 @@ def __init__(
# OR
sbatch_settings = BatchSettings(batch_scheduler=BatchSchedulerType.Slurm)
- This will assign a SlurmBatchArguments object to ``sbatch_settings.batch_args``.
- Using the object, users may access the child class functions to set
- batch configurations. For example:
+ This will assign a SlurmBatchArguments object to
+ ``sbatch_settings.batch_args``. Using the object, users may access the child
+ class functions to set batch configurations. For example:
.. highlight:: python
.. code-block:: python
@@ -105,9 +106,9 @@ def __init__(
:param batch_scheduler: The type of scheduler to initialize
(e.g., Slurm, PBS, LSF)
- :param batch_args: A dictionary of arguments for the scheduler, where the keys
- are strings and the values can be either strings or None. This argument is
- optional and defaults to None.
+ :param batch_args: A dictionary of arguments for the scheduler, where
+ the keys are strings and the values can be either strings or None.
+ This argument is optional and defaults to None.
:param env_vars: Environment variables for the batch settings, where the keys
are strings and the values can be either strings or None. This argument is
also optional and defaults to None.
@@ -122,7 +123,6 @@ def __init__(
"""The BatchSettings child class based on scheduler type"""
self.env_vars = env_vars or {}
"""The environment configuration"""
- self.batch_args = batch_args or {}
@property
def batch_scheduler(self) -> str:
diff --git a/smartsim/settings/common.py b/smartsim/settings/common.py
index 1d58da90b3..df7eb243aa 100644
--- a/smartsim/settings/common.py
+++ b/smartsim/settings/common.py
@@ -44,7 +44,7 @@ def set_check_input(key: str, value: t.Optional[str]) -> None:
if key.startswith("-"):
key = key.lstrip("-")
logger.warning(
- "One or more leading `-` characters were provided to \
- the run argument. Leading dashes were stripped and \
- the arguments were passed to the run_command."
+ "One or more leading `-` characters were provided to the run argument.\n"
+ "Leading dashes were stripped and the arguments were passed to the \n"
+ "run_command."
)
diff --git a/smartsim/settings/launch_settings.py b/smartsim/settings/launch_settings.py
index 3f878f59dd..136de7638b 100644
--- a/smartsim/settings/launch_settings.py
+++ b/smartsim/settings/launch_settings.py
@@ -114,8 +114,8 @@ def __init__(
:param launcher: The type of launcher to initialize (e.g., Dragon, Slurm,
PALS, ALPS, Local, Mpiexec, Mpirun, Orterun, LSF)
:param launch_args: A dictionary of arguments for the launcher, where the keys
- are strings and the values can be either strings or None.
- This argument is optional and defaults to None.
+ are strings and the values can be either strings or None. This argument is
+ optional and defaults to None.
:param env_vars: Environment variables for the launch settings, where the keys
are strings and the values can be either strings or None. This argument is
also optional and defaults to None.
diff --git a/smartsim/settings/sge_settings.py b/smartsim/settings/sge_settings.py
index 72dbbf5ce2..757d167c64 100644
--- a/smartsim/settings/sge_settings.py
+++ b/smartsim/settings/sge_settings.py
@@ -36,7 +36,7 @@
# ***************************************
# TODO: Remove pylint disable after merge
# ***************************************
-# pylint: disable=no-self-use
+# pylint: disable=no-self-use,no-member
class SgeQsubBatchSettings(BatchSettings):
diff --git a/tests/_legacy/backends/run_torch.py b/tests/_legacy/backends/run_torch.py
index 83c8a9a8e7..1071e740ef 100644
--- a/tests/_legacy/backends/run_torch.py
+++ b/tests/_legacy/backends/run_torch.py
@@ -25,6 +25,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import io
+import typing as t
import numpy as np
import torch
@@ -74,7 +75,7 @@ def calc_svd(input_tensor):
return input_tensor.svd()
-def run(device):
+def run(device: str, num_devices: int) -> t.Any:
# connect a client to the feature store
client = Client(cluster=False)
@@ -92,9 +93,23 @@ def run(device):
net = create_torch_model()
# 20 samples of "image" data
example_forward_input = torch.rand(20, 1, 28, 28)
- client.set_model("cnn", net, "TORCH", device=device)
client.put_tensor("input", example_forward_input.numpy())
- client.run_model("cnn", inputs=["input"], outputs=["output"])
+ if device == "CPU":
+ client.set_model("cnn", net, "TORCH", device=device)
+ client.run_model("cnn", inputs=["input"], outputs=["output"])
+ else:
+ client.set_model_multigpu(
+ "cnn", net, "TORCH", first_gpu=0, num_gpus=num_devices
+ )
+ client.run_model_multigpu(
+ "cnn",
+ offset=1,
+ first_gpu=0,
+ num_gpus=num_devices,
+ inputs=["input"],
+ outputs=["output"],
+ )
+
output = client.get_tensor("output")
print(f"Prediction: {output}")
@@ -106,5 +121,11 @@ def run(device):
parser.add_argument(
"--device", type=str, default="CPU", help="device type for model execution"
)
+ parser.add_argument(
+ "--num-devices",
+ type=int,
+ default=1,
+ help="Number of devices to set the model on",
+ )
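+ # e.g. (illustrative): python run_torch.py --device=GPU --num-devices=2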
args = parser.parse_args()
- run(args.device)
+ run(args.device, args.num_devices)
diff --git a/tests/_legacy/backends/test_cli_mini_exp.py b/tests/_legacy/backends/test_cli_mini_exp.py
index 1fd1107215..83ecfc5b07 100644
--- a/tests/_legacy/backends/test_cli_mini_exp.py
+++ b/tests/_legacy/backends/test_cli_mini_exp.py
@@ -32,7 +32,7 @@
import smartsim._core._cli.validate
import smartsim._core._install.builder as build
-from smartsim._core.utils.helpers import installed_redisai_backends
+from smartsim._core._install.platform import Device
sklearn_available = True
try:
@@ -70,7 +70,7 @@ def _mock_make_managed_local_feature_store(*a, **kw):
"_make_managed_local_feature_store",
_mock_make_managed_local_feature_store,
)
- backends = installed_redisai_backends()
+ backends = [] # todo: update test to replace installed_redisai_backends()
(fs_port,) = fs.ports
smartsim._core._cli.validate.test_install(
@@ -79,7 +79,7 @@ def _mock_make_managed_local_feature_store(*a, **kw):
location=test_dir,
port=fs_port,
# Always test on CPU, heads don't always have GPU
- device=build.Device.CPU,
+ device=Device.CPU,
# Test the backends the dev has installed
with_tf="tensorflow" in backends,
with_pt="torch" in backends,
diff --git a/tests/_legacy/backends/test_dbmodel.py b/tests/_legacy/backends/test_dbmodel.py
index 5c9a253c75..da495004fa 100644
--- a/tests/_legacy/backends/test_dbmodel.py
+++ b/tests/_legacy/backends/test_dbmodel.py
@@ -30,7 +30,6 @@
import pytest
from smartsim import Experiment
-from smartsim._core.utils import installed_redisai_backends
from smartsim.entity import Ensemble
from smartsim.entity.dbobject import FSModel
from smartsim.error.errors import SSUnsupportedError
@@ -70,7 +69,9 @@ def call(self, x):
except:
logger.warning("Could not set TF max memory limit for GPU")
-should_run_tf &= "tensorflow" in installed_redisai_backends()
+should_run_tf &= (
+ "tensorflow" in []
+) # todo: update test to replace installed_redisai_backends()
# Check if PyTorch is available for tests
try:
@@ -107,7 +108,9 @@ def forward(self, x):
return output
-should_run_pt &= "torch" in installed_redisai_backends()
+should_run_pt &= (
+ "torch" in []
+) # todo: update test to replace installed_redisai_backends()
def save_tf_cnn(path, file_name):
diff --git a/tests/_legacy/backends/test_dbscript.py b/tests/_legacy/backends/test_dbscript.py
index 9619b0325f..ec6e2f861c 100644
--- a/tests/_legacy/backends/test_dbscript.py
+++ b/tests/_legacy/backends/test_dbscript.py
@@ -24,18 +24,15 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-import os
import sys
import pytest
from smartredis import *
from smartsim import Experiment
-from smartsim._core.utils import installed_redisai_backends
from smartsim.entity.dbobject import FSScript
from smartsim.error.errors import SSUnsupportedError
from smartsim.log import get_logger
-from smartsim.settings import MpiexecSettings, MpirunSettings
from smartsim.status import JobStatus
logger = get_logger(__name__)
@@ -49,7 +46,7 @@
except ImportError:
should_run = False
-should_run &= "torch" in installed_redisai_backends()
+should_run &= "torch" in [] # todo: update test to replace installed_redisai_backends()
def timestwo(x):
diff --git a/tests/_legacy/backends/test_onnx.py b/tests/_legacy/backends/test_onnx.py
index 3580ec07e3..67c9775aa3 100644
--- a/tests/_legacy/backends/test_onnx.py
+++ b/tests/_legacy/backends/test_onnx.py
@@ -30,8 +30,6 @@
import pytest
-from smartsim import Experiment
-from smartsim._core.utils import installed_redisai_backends
from smartsim.status import JobStatus
sklearn_available = True
@@ -47,7 +45,9 @@
sklearn_available = False
-onnx_backend_available = "onnxruntime" in installed_redisai_backends()
+onnx_backend_available = (
+ "onnxruntime" in []
+) # todo: update test to replace installed_redisai_backends()
should_run = sklearn_available and onnx_backend_available
diff --git a/tests/_legacy/backends/test_tf.py b/tests/_legacy/backends/test_tf.py
index 320fe84721..526c08e29e 100644
--- a/tests/_legacy/backends/test_tf.py
+++ b/tests/_legacy/backends/test_tf.py
@@ -29,8 +29,6 @@
import pytest
-from smartsim import Experiment
-from smartsim._core.utils import installed_redisai_backends
from smartsim.error import SmartSimError
from smartsim.status import JobStatus
@@ -43,7 +41,9 @@
print(e)
tf_available = False
-tf_backend_available = "tensorflow" in installed_redisai_backends()
+tf_backend_available = (
+ "tensorflow" in []
+) # todo: update test to replace installed_redisai_backends()
@pytest.mark.skipif(
diff --git a/tests/_legacy/backends/test_torch.py b/tests/_legacy/backends/test_torch.py
index 2bf6c741a4..2606d08837 100644
--- a/tests/_legacy/backends/test_torch.py
+++ b/tests/_legacy/backends/test_torch.py
@@ -29,8 +29,6 @@
import pytest
-from smartsim import Experiment
-from smartsim._core.utils import installed_redisai_backends
from smartsim.status import JobStatus
torch_available = True
@@ -40,7 +38,9 @@
except ImportError:
torch_available = False
-torch_backend_available = "torch" in installed_redisai_backends()
+torch_backend_available = (
+ "torch" in []
+) # todo: update test to replace installed_redisai_backends()
should_run = torch_available and torch_backend_available
pytestmark = pytest.mark.skipif(
@@ -65,9 +65,11 @@ def test_torch_model_and_script(
fs = prepare_fs(single_fs).featurestore
wlm_experiment.reconnect_feature_store(fs.checkpoint_file)
test_device = mlutils.get_test_device()
+ test_num_gpus = mlutils.get_test_num_gpus() if pytest.test_device == "GPU" else 1
run_settings = wlm_experiment.create_run_settings(
- "python", f"run_torch.py --device={test_device}"
+ "python",
+ ["run_torch.py", f"--device={test_device}", f"--num-devices={test_num_gpus}"],
)
if wlmutils.get_test_launcher() != "local":
run_settings.set_tasks(1)
diff --git a/tests/_legacy/install/test_build.py b/tests/_legacy/install/test_build.py
new file mode 100644
index 0000000000..f8a5c4896b
--- /dev/null
+++ b/tests/_legacy/install/test_build.py
@@ -0,0 +1,148 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import operator
+
+import pytest
+
+from smartsim._core._cli.build import parse_requirement
+from smartsim._core._install.buildenv import Version_
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+
+_SUPPORTED_OPERATORS = ("==", ">=", ">", "<=", "<")
+
+
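+# For orientation (illustrative): parse_requirement splits a pip-style specifier
+# into (name, pin, comparison_fn), e.g.
+#   parse_requirement("foo>=1.2.3+cuda") -> ("foo", ">=1.2.3+cuda", <callable>)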
+@pytest.mark.parametrize(
+ "spec, name, pin",
+ (
+ pytest.param("foo", "foo", None, id="Just Name"),
+ pytest.param("foo==1", "foo", "==1", id="With Major"),
+ pytest.param("foo==1.2", "foo", "==1.2", id="With Minor"),
+ pytest.param("foo==1.2.3", "foo", "==1.2.3", id="With Patch"),
+ pytest.param("foo[with-extras]==1.2.3", "foo", "==1.2.3", id="With Extra"),
+ pytest.param(
+ "foo[with,many,extras]==1.2.3", "foo", "==1.2.3", id="With Many Extras"
+ ),
+ *(
+ pytest.param(
+ f"foo{symbol}1.2.3{tag}",
+ "foo",
+ f"{symbol}1.2.3{tag}",
+ id=f"{symbol=} | {tag=}",
+ )
+ for symbol in _SUPPORTED_OPERATORS
+ for tag in ("", "+cuda", "+rocm", "+cpu")
+ ),
+ ),
+)
+def test_parse_requirement_name_and_version(spec, name, pin):
+ p_name, p_pin, _ = parse_requirement(spec)
+ assert p_name == name
+ assert p_pin == pin
+
+
+# fmt: off
+@pytest.mark.parametrize(
+ "spec, ver, should_pass",
+ (
+ pytest.param("foo" , Version_("1.2.3") , True, id="No spec"),
+ # EQ --------------------------------------------------------------------------
+ pytest.param("foo==1.2.3" , Version_("1.2.3") , True, id="EQ Spec, EQ Version"),
+ pytest.param("foo==1.2.3" , Version_("1.2.5") , False, id="EQ Spec, GT Version"),
+ pytest.param("foo==1.2.3" , Version_("1.2.2") , False, id="EQ Spec, LT Version"),
+ pytest.param("foo==1.2.3+rocm", Version_("1.2.3+rocm"), True, id="EQ Spec, Compatible Version with suffix"),
+ pytest.param("foo==1.2.3" , Version_("1.2.3+cuda"), False, id="EQ Spec, Compatible Version, Extra Suffix"),
+ pytest.param("foo==1.2.3+cuda", Version_("1.2.3") , False, id="EQ Spec, Compatible Version, Missing Suffix"),
+ pytest.param("foo==1.2.3+cuda", Version_("1.2.3+rocm"), False, id="EQ Spec, Compatible Version, Mismatched Suffix"),
+ # LT --------------------------------------------------------------------------
+ pytest.param("foo<1.2.3" , Version_("1.2.3") , False, id="LT Spec, EQ Version"),
+ pytest.param("foo<1.2.3" , Version_("1.2.5") , False, id="LT Spec, GT Version"),
+ pytest.param("foo<1.2.3" , Version_("1.2.2") , True, id="LT Spec, LT Version"),
+ pytest.param("foo<1.2.3+rocm" , Version_("1.2.2+rocm"), True, id="LT Spec, Compatible Version with suffix"),
+ pytest.param("foo<1.2.3" , Version_("1.2.2+cuda"), False, id="LT Spec, Compatible Version, Extra Suffix"),
+ pytest.param("foo<1.2.3+cuda" , Version_("1.2.2") , False, id="LT Spec, Compatible Version, Missing Suffix"),
+ pytest.param("foo<1.2.3+cuda" , Version_("1.2.2+rocm"), False, id="LT Spec, Compatible Version, Mismatched Suffix"),
+ # LE --------------------------------------------------------------------------
+ pytest.param("foo<=1.2.3" , Version_("1.2.3") , True, id="LE Spec, EQ Version"),
+ pytest.param("foo<=1.2.3" , Version_("1.2.5") , False, id="LE Spec, GT Version"),
+ pytest.param("foo<=1.2.3" , Version_("1.2.2") , True, id="LE Spec, LT Version"),
+ pytest.param("foo<=1.2.3+rocm", Version_("1.2.3+rocm"), True, id="LE Spec, Compatible Version with suffix"),
+ pytest.param("foo<=1.2.3" , Version_("1.2.3+cuda"), False, id="LE Spec, Compatible Version, Extra Suffix"),
+ pytest.param("foo<=1.2.3+cuda", Version_("1.2.3") , False, id="LE Spec, Compatible Version, Missing Suffix"),
+ pytest.param("foo<=1.2.3+cuda", Version_("1.2.3+rocm"), False, id="LE Spec, Compatible Version, Mismatched Suffix"),
+ # GT --------------------------------------------------------------------------
+ pytest.param("foo>1.2.3" , Version_("1.2.3") , False, id="GT Spec, EQ Version"),
+ pytest.param("foo>1.2.3" , Version_("1.2.5") , True, id="GT Spec, GT Version"),
+ pytest.param("foo>1.2.3" , Version_("1.2.2") , False, id="GT Spec, LT Version"),
+ pytest.param("foo>1.2.3+rocm" , Version_("1.2.4+rocm"), True, id="GT Spec, Compatible Version with suffix"),
+ pytest.param("foo>1.2.3" , Version_("1.2.4+cuda"), False, id="GT Spec, Compatible Version, Extra Suffix"),
+ pytest.param("foo>1.2.3+cuda" , Version_("1.2.4") , False, id="GT Spec, Compatible Version, Missing Suffix"),
+ pytest.param("foo>1.2.3+cuda" , Version_("1.2.4+rocm"), False, id="GT Spec, Compatible Version, Mismatched Suffix"),
+ # GE --------------------------------------------------------------------------
+ pytest.param("foo>=1.2.3" , Version_("1.2.3") , True, id="GE Spec, EQ Version"),
+ pytest.param("foo>=1.2.3" , Version_("1.2.5") , True, id="GE Spec, GT Version"),
+ pytest.param("foo>=1.2.3" , Version_("1.2.2") , False, id="GE Spec, LT Version"),
+ pytest.param("foo>=1.2.3+rocm", Version_("1.2.3+rocm"), True, id="GE Spec, Compatible Version with suffix"),
+ pytest.param("foo>=1.2.3" , Version_("1.2.3+cuda"), False, id="GE Spec, Compatible Version, Extra Suffix"),
+ pytest.param("foo>=1.2.3+cuda", Version_("1.2.3") , False, id="GE Spec, Compatible Version, Missing Suffix"),
+ pytest.param("foo>=1.2.3+cuda", Version_("1.2.3+rocm"), False, id="GE Spec, Compatible Version, Mismatched Suffix"),
+ )
+)
+# fmt: on
+def test_parse_requirement_comparison_fn(spec, ver, should_pass):
+ _, _, cmp = parse_requirement(spec)
+ assert cmp(ver) == should_pass
+
+
+@pytest.mark.parametrize(
+ "spec, ctx",
+ (
+ *(
+ pytest.param(
+ f"thing{symbol}",
+ pytest.raises(ValueError, match="Invalid requirement string:"),
+ id=f"No version w/ operator {symbol}",
+ )
+ for symbol in _SUPPORTED_OPERATORS
+ ),
+ pytest.param(
+ "thing>=>1.2.3",
+ pytest.raises(ValueError, match="Invalid requirement string:"),
+ id="Operator too long",
+ ),
+ pytest.param(
+ "thing<>1.2.3",
+ pytest.raises(ValueError, match="Unrecognized comparison operator: <>"),
+ id="Nonsense operator",
+ ),
+ ),
+)
+def test_parse_requirement_errors_on_invalid_spec(spec, ctx):
+ with ctx:
+ parse_requirement(spec)
diff --git a/tests/_legacy/install/test_mlpackage.py b/tests/_legacy/install/test_mlpackage.py
new file mode 100644
index 0000000000..d27e69b2ba
--- /dev/null
+++ b/tests/_legacy/install/test_mlpackage.py
@@ -0,0 +1,122 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import pathlib
+from unittest.mock import MagicMock
+
+import pytest
+
+from smartsim._core._install.mlpackages import (
+ MLPackage,
+ MLPackageCollection,
+ RAIPatch,
+ load_platform_configs,
+)
+from smartsim._core._install.platform import Platform
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+mock_platform = MagicMock(spec=Platform)
+
+
+@pytest.fixture
+def mock_ml_packages():
+ foo = MagicMock(spec=MLPackage)
+ foo.name = "foo"
+ bar = MagicMock(spec=MLPackage)
+ bar.name = "bar"
+ yield [foo, bar]
+
+
+@pytest.mark.parametrize(
+ "patch",
+ [MagicMock(spec=RAIPatch), [MagicMock(spec=RAIPatch) for i in range(3)], ()],
+ ids=["one patch", "multiple patches", "no patch"],
+)
+def test_mlpackage_constructor(patch):
+ MLPackage(
+ "foo",
+ "0.0.0",
+ "https://nothing.com",
+ ["bar==0.1", "baz==0.2"],
+ pathlib.Path("/nothing/fake"),
+ patch,
+ )
+
+
+def test_mlpackage_collection_constructor(mock_ml_packages):
+ MLPackageCollection(mock_platform, mock_ml_packages)
+
+
+def test_mlpackage_collection_mutable_mapping_methods(mock_ml_packages):
+ ml_packages = MLPackageCollection(mock_platform, mock_ml_packages)
+ for val in ml_packages._ml_packages.values():
+ val.version = "0.0.0"
+ assert ml_packages._ml_packages == ml_packages
+
+ # Test iter
+ package_names = [pkg.name for pkg in mock_ml_packages]
+ assert [name for name in ml_packages] == package_names
+
+ # Test get item
+ for pkg in mock_ml_packages:
+ assert ml_packages[pkg.name] is pkg
+
+ # Test len
+ assert len(ml_packages) == len(mock_ml_packages)
+
+ # Test delitem
+ key = next(iter(mock_ml_packages)).name
+ del ml_packages[key]
+ with pytest.raises(KeyError):
+ ml_packages[key]
+ assert len(ml_packages) == (len(mock_ml_packages) - 1)
+
+ # Test setitem
+ with pytest.raises(TypeError):
+ ml_packages["baz"] = MagicMock(spec=MLPackage)
+
+ # Test contains
+ name, package = next(iter(ml_packages.items()))
+ assert name in ml_packages
+
+ # Test str
+ assert "Package" in str(ml_packages)
+ assert "Version" in str(ml_packages)
+ assert package.version in str(ml_packages)
+ assert name in str(ml_packages)
+
+
+def test_load_configs_raises_when_dir_dne(test_dir):
+ dne_dir = pathlib.Path(test_dir, "dne")
+ dir_str = os.fspath(dne_dir)
+ with pytest.raises(
+ FileNotFoundError,
+ match=f"Platform configuration directory `{dir_str}` does not exist",
+ ):
+ load_platform_configs(dne_dir)
diff --git a/tests/_legacy/install/test_package_retriever.py b/tests/_legacy/install/test_package_retriever.py
new file mode 100644
index 0000000000..d415ae2358
--- /dev/null
+++ b/tests/_legacy/install/test_package_retriever.py
@@ -0,0 +1,106 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import contextlib
+import filecmp
+import os
+import pathlib
+import random
+import string
+import tarfile
+import zipfile
+
+import pytest
+
+from smartsim._core._install.utils import retrieve
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+
+@contextlib.contextmanager
+def temp_cd(path):
+ original = os.getcwd()
+ os.chdir(path)
+ try:
+ yield
+ finally:
+ os.chdir(original)
+
+
+def make_test_file(test_file):
+ data = "".join(random.choices(string.ascii_letters + string.digits, k=1024))
+ with open(test_file, "w") as f:
+ f.write(data)
+
+
+def test_local_archive_zip(test_dir):
+ with temp_cd(test_dir):
+ test_file = "./test.data"
+ make_test_file(test_file)
+
+ zip_file = "./test.zip"
+ with zipfile.ZipFile(zip_file, "w") as f:
+ f.write(test_file)
+
+ retrieve(zip_file, pathlib.Path("./output"))
+
+ assert filecmp.cmp(
+ test_file, pathlib.Path("./output") / "test.data", shallow=False
+ )
+
+
+def test_local_archive_tgz(test_dir):
+ with temp_cd(test_dir):
+ test_file = "./test.data"
+ make_test_file(test_file)
+
+ tgz_file = "./test.tgz"
+ with tarfile.open(tgz_file, "w:gz") as f:
+ f.add(test_file)
+
+ retrieve(tgz_file, pathlib.Path("./output"))
+
+ assert filecmp.cmp(
+ test_file, pathlib.Path("./output") / "test.data", shallow=False
+ )
+
+
+def test_git(test_dir):
+ retrieve(
+ "https://github.com/CrayLabs/SmartSim.git",
+ f"{test_dir}/smartsim_git",
+ branch="master",
+ )
+ assert pathlib.Path(f"{test_dir}/smartsim_git").is_dir()
+
+
+def test_https(test_dir):
+ output_dir = pathlib.Path(test_dir) / "output"
+ retrieve(
+ "https://github.com/CrayLabs/SmartSim/archive/refs/tags/v0.5.0.zip", output_dir
+ )
+ assert output_dir.exists()
diff --git a/tests/_legacy/install/test_platform.py b/tests/_legacy/install/test_platform.py
new file mode 100644
index 0000000000..76ff3f76b1
--- /dev/null
+++ b/tests/_legacy/install/test_platform.py
@@ -0,0 +1,89 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import os
+import platform
+
+import pytest
+
+from smartsim._core._install.platform import Architecture, Device, OperatingSystem
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+
+def test_device_cpu():
+ cpu_enum = Device.CPU
+ assert not cpu_enum.is_gpu()
+ assert not cpu_enum.is_cuda()
+ assert not cpu_enum.is_rocm()
+
+
+@pytest.mark.parametrize("cuda_device", Device.cuda_enums())
+def test_cuda(monkeypatch, test_dir, cuda_device):
+ version = cuda_device.value.split("-")[1]
+ fake_full_version = version + ".8888" + ".9999"
+ monkeypatch.setenv("CUDA_HOME", test_dir)
+
+ mock_version = dict(cuda=dict(version=fake_full_version))
+ print(mock_version)
+ with open(f"{test_dir}/version.json", "w") as outfile:
+ json.dump(mock_version, outfile)
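+ # the mocked version.json mimics the CUDA toolkit's metadata layout, e.g.
+ # {"cuda": {"version": "12.3.8888.9999"}} (numbers illustrative)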
+
+ assert Device.detect_cuda_version() == cuda_device
+ assert cuda_device.is_gpu()
+ assert cuda_device.is_cuda()
+ assert not cuda_device.is_rocm()
+
+
+@pytest.mark.parametrize("rocm_device", Device.rocm_enums())
+def test_rocm(monkeypatch, test_dir, rocm_device):
+ version = rocm_device.value.split("-")[1]
+ fake_full_version = version + ".8888" + "-9999"
+ monkeypatch.setenv("ROCM_HOME", test_dir)
+ info_dir = f"{test_dir}/.info"
+ os.mkdir(info_dir)
+
+ with open(f"{info_dir}/version", "w") as outfile:
+ outfile.write(fake_full_version)
+
+ assert Device.detect_rocm_version() == rocm_device
+ assert rocm_device.is_gpu()
+ assert not rocm_device.is_cuda()
+ assert rocm_device.is_rocm()
+
+
+@pytest.mark.parametrize("os", ("linux", "darwin"))
+def test_operating_system(monkeypatch, os):
+ monkeypatch.setattr(platform, "system", lambda: os)
+ assert OperatingSystem.autodetect().value == os
+
+
+@pytest.mark.parametrize("arch", ("x86_64", "arm64"))
+def test_architecture(monkeypatch, arch):
+ monkeypatch.setattr(platform, "machine", lambda: arch)
+ assert Architecture.autodetect().value == arch
diff --git a/tests/_legacy/install/test_redisai_builder.py b/tests/_legacy/install/test_redisai_builder.py
new file mode 100644
index 0000000000..81673a7f12
--- /dev/null
+++ b/tests/_legacy/install/test_redisai_builder.py
@@ -0,0 +1,60 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from pathlib import Path
+
+import pytest
+
+from smartsim._core._install.buildenv import BuildEnv
+from smartsim._core._install.mlpackages import (
+ DEFAULT_MLPACKAGE_PATH,
+ MLPackage,
+ load_platform_configs,
+)
+from smartsim._core._install.platform import Platform
+from smartsim._core._install.redisaiBuilder import RedisAIBuilder
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+DEFAULT_MLPACKAGES = load_platform_configs(DEFAULT_MLPACKAGE_PATH)
+
+
+@pytest.mark.parametrize(
+ "platform",
+ [platform for platform in DEFAULT_MLPACKAGES],
+ ids=[str(platform) for platform in DEFAULT_MLPACKAGES],
+)
+def test_backends_to_be_installed(monkeypatch, test_dir, platform):
+ mlpackages = DEFAULT_MLPACKAGES[platform]
+ monkeypatch.setattr(MLPackage, "retrieve", lambda *args, **kwargs: None)
+ builder = RedisAIBuilder(platform, mlpackages, BuildEnv(), Path(test_dir))
+
+ BACKENDS = ["libtorch", "libtensorflow", "onnxruntime"]
+ TOGGLES = ["build_torch", "build_tensorflow", "build_onnxruntime"]
+
+ for backend, toggle in zip(BACKENDS, TOGGLES):
+ assert getattr(builder, toggle) == (backend in mlpackages)
diff --git a/tests/_legacy/on_wlm/test_dragon.py b/tests/_legacy/on_wlm/test_dragon.py
index b685b65020..d835d60ce1 100644
--- a/tests/_legacy/on_wlm/test_dragon.py
+++ b/tests/_legacy/on_wlm/test_dragon.py
@@ -56,7 +56,7 @@ def test_dragon_global_path(global_dragon_teardown, wlmutils, test_dir, monkeypa
def test_dragon_exp_path(global_dragon_teardown, wlmutils, test_dir, monkeypatch):
monkeypatch.delenv("SMARTSIM_DRAGON_SERVER_PATH", raising=False)
- monkeypatch.delenv("SMARTSIM_DRAGON_SERVER_PATH_EXP", raising=False)
+ monkeypatch.delenv("_SMARTSIM_DRAGON_SERVER_PATH_EXP", raising=False)
exp: Experiment = Experiment(
"test_dragon_connection",
exp_path=test_dir,
diff --git a/tests/_legacy/test_cli.py b/tests/_legacy/test_cli.py
index 397f1196c6..c47ea046b7 100644
--- a/tests/_legacy/test_cli.py
+++ b/tests/_legacy/test_cli.py
@@ -436,24 +436,22 @@ def mock_execute(ns: argparse.Namespace, _unparsed: t.Optional[t.List[str]] = No
# fmt: off
@pytest.mark.parametrize(
- "command,mock_location,exp_output,optional_arg,exp_valid,exp_err_msg,check_prop,exp_prop_val",
+ "command, mock_location, exp_output, optional_arg, exp_valid, exp_err_msg, check_prop, exp_prop_val",
[
- pytest.param("build", "build_execute", "verbose mocked-build", "-v", True, "", "v", True, id="verbose 'on'"),
- pytest.param("build", "build_execute", "cpu mocked-build", "--device=cpu", True, "", "device", "cpu", id="device 'cpu'"),
- pytest.param("build", "build_execute", "gpu mocked-build", "--device=gpu", True, "", "device", "gpu", id="device 'gpu'"),
- pytest.param("build", "build_execute", "gpuX mocked-build", "--device=gpux", False, "invalid choice: 'gpux'", "", "", id="set bad device 'gpuX'"),
- pytest.param("build", "build_execute", "no tensorflow mocked-build", "--no_tf", True, "", "no_tf", True, id="set no TF"),
- pytest.param("build", "build_execute", "no torch mocked-build", "--no_pt", True, "", "no_pt", True, id="set no torch"),
- pytest.param("build", "build_execute", "onnx mocked-build", "--onnx", True, "", "onnx", True, id="set w/onnx"),
- pytest.param("build", "build_execute", "torch-dir mocked-build", "--torch_dir /foo/bar", True, "", "torch_dir", "/foo/bar", id="set torch dir"),
- pytest.param("build", "build_execute", "bad-torch-dir mocked-build", "--torch_dir", False, "error: argument --torch_dir", "", "", id="set torch dir, no path"),
- pytest.param("build", "build_execute", "keydb mocked-build", "--keydb", True, "", "keydb", True, id="keydb on"),
- pytest.param("clean", "clean_execute", "clobbering mocked-clean", "--clobber", True, "", "clobber", True, id="clean w/clobber"),
- pytest.param("validate", "validate_execute", "port mocked-validate", "--port=12345", True, "", "port", 12345, id="validate w/ manual port"),
- pytest.param("validate", "validate_execute", "abbrv port mocked-validate", "-p 12345", True, "", "port", 12345, id="validate w/ manual abbreviated port"),
- pytest.param("validate", "validate_execute", "cpu mocked-validate", "--device=cpu", True, "", "device", "cpu", id="validate: device 'cpu'"),
- pytest.param("validate", "validate_execute", "gpu mocked-validate", "--device=gpu", True, "", "device", "gpu", id="validate: device 'gpu'"),
- pytest.param("validate", "validate_execute", "gpuX mocked-validate", "--device=gpux", False, "invalid choice: 'gpux'", "", "", id="validate: set bad device 'gpuX'"),
+ pytest.param( "build", "build_execute", "verbose mocked-build", "-v", True, "", "v", True, id="verbose 'on'"),
+ pytest.param( "build", "build_execute", "cpu mocked-build", "--device=cpu", True, "", "device", "cpu", id="device 'cpu'"),
+ pytest.param( "build", "build_execute", "gpuX mocked-build", "--device=gpux", False, "invalid choice: 'gpux'", "", "", id="set bad device 'gpuX'"),
+ pytest.param( "build", "build_execute", "no tensorflow mocked-build", "--skip-tensorflow", True, "", "no_tf", True, id="Skip TF"),
+ pytest.param( "build", "build_execute", "no torch mocked-build", "--skip-torch", True, "", "no_pt", True, id="Skip Torch"),
+ pytest.param( "build", "build_execute", "onnx mocked-build", "--skip-onnx", True, "", "onnx", True, id="Skip Onnx"),
+ pytest.param( "build", "build_execute", "config-dir mocked-build", "--config-dir /foo/bar", True, "", "config-dir", "/foo/bar", id="set torch dir"),
+ pytest.param( "build", "build_execute", "bad-config-dir mocked-build", "--config-dir", False, "error: argument --config-dir", "", "", id="set config dir w/o path"),
+ pytest.param( "clean", "clean_execute", "clobbering mocked-clean", "--clobber", True, "", "clobber", True, id="clean w/clobber"),
+ pytest.param("validate", "validate_execute", "port mocked-validate", "--port=12345", True, "", "port", 12345, id="validate w/ manual port"),
+ pytest.param("validate", "validate_execute", "abbrv port mocked-validate", "-p 12345", True, "", "port", 12345, id="validate w/ manual abbreviated port"),
+ pytest.param("validate", "validate_execute", "cpu mocked-validate", "--device=cpu", True, "", "device", "cpu", id="validate: device 'cpu'"),
+ pytest.param("validate", "validate_execute", "gpu mocked-validate", "--device=gpu", True, "", "device", "gpu", id="validate: device 'gpu'"),
+ pytest.param("validate", "validate_execute", "gpuX mocked-validate", "--device=gpux", False, "invalid choice: 'gpux'", "", "", id="validate: set bad device 'gpuX'"),
]
)
# fmt: on
@@ -733,19 +731,7 @@ def mock_operation(*args, **kwargs) -> int:
# mock out the internal get_fs_path method so we don't actually do file system ops
monkeypatch.setattr(smartsim._core._cli.build, "tabulate", mock_operation)
- monkeypatch.setattr(
- smartsim._core._cli.build, "build_feature_store", mock_operation
- )
monkeypatch.setattr(smartsim._core._cli.build, "build_redis_ai", mock_operation)
- monkeypatch.setattr(
- smartsim._core._cli.build, "check_py_torch_version", mock_operation
- )
- monkeypatch.setattr(
- smartsim._core._cli.build, "check_py_tf_version", mock_operation
- )
- monkeypatch.setattr(
- smartsim._core._cli.build, "check_py_onnx_version", mock_operation
- )
command = "build"
cfg = MenuItemConfig(
diff --git a/tests/_legacy/test_colo_model_local.py b/tests/_legacy/test_colo_model_local.py
index 1ab97c4cc3..54848907d3 100644
--- a/tests/_legacy/test_colo_model_local.py
+++ b/tests/_legacy/test_colo_model_local.py
@@ -29,7 +29,7 @@
import pytest
from smartsim import Experiment
-from smartsim.entity import Application
+from smartsim._core.utils.helpers import _create_pinning_string
from smartsim.error import SSUnsupportedError
from smartsim.status import JobStatus
@@ -116,7 +116,7 @@ def test_unsupported_custom_pinning(fileutils, test_dir, coloutils, custom_pinni
],
)
def test_create_pinning_string(pin_list, num_cpus, expected):
- assert Application._create_pinning_string(pin_list, num_cpus) == expected
+ assert _create_pinning_string(pin_list, num_cpus) == expected
@pytest.mark.parametrize("fs_type", supported_fss)
diff --git a/tests/_legacy/test_config.py b/tests/_legacy/test_config.py
index 00a1fcdd36..5a84103ffd 100644
--- a/tests/_legacy/test_config.py
+++ b/tests/_legacy/test_config.py
@@ -66,9 +66,9 @@ def get_redisai_env(
"""
env = os.environ.copy()
if rai_path is not None:
- env["RAI_PATH"] = rai_path
+ env["SMARTSIM_RAI_LIB"] = rai_path
else:
- env.pop("RAI_PATH", None)
+ env.pop("SMARTSIM_RAI_LIB", None)
if lib_path is not None:
env["SMARTSIM_DEP_INSTALL_PATH"] = lib_path
@@ -85,7 +85,7 @@ def make_file(filepath: str) -> None:
def test_redisai_invalid_rai_path(test_dir, monkeypatch):
- """An invalid RAI_PATH and valid SMARTSIM_DEP_INSTALL_PATH should fail"""
+ """An invalid SMARTSIM_RAI_LIB and valid SMARTSIM_DEP_INSTALL_PATH should fail"""
rai_file_path = os.path.join(test_dir, "lib", "mock-redisai.so")
make_file(os.path.join(test_dir, "lib", "redisai.so"))
@@ -94,7 +94,7 @@ def test_redisai_invalid_rai_path(test_dir, monkeypatch):
config = Config()
- # Fail when no file exists @ RAI_PATH
+ # Fail when no file exists @ SMARTSIM_RAI_LIB
with pytest.raises(SSConfigError) as ex:
_ = config.redisai
@@ -102,7 +102,7 @@ def test_redisai_invalid_rai_path(test_dir, monkeypatch):
def test_redisai_valid_rai_path(test_dir, monkeypatch):
- """A valid RAI_PATH should override valid SMARTSIM_DEP_INSTALL_PATH and succeed"""
+ """A valid SMARTSIM_RAI_LIB should override valid SMARTSIM_DEP_INSTALL_PATH and succeed"""
rai_file_path = os.path.join(test_dir, "lib", "mock-redisai.so")
make_file(rai_file_path)
@@ -117,7 +117,7 @@ def test_redisai_valid_rai_path(test_dir, monkeypatch):
def test_redisai_invalid_lib_path(test_dir, monkeypatch):
- """Invalid RAI_PATH and invalid SMARTSIM_DEP_INSTALL_PATH should fail"""
+ """Invalid SMARTSIM_RAI_LIB and invalid SMARTSIM_DEP_INSTALL_PATH should fail"""
rai_file_path = f"{test_dir}/railib/redisai.so"
@@ -133,7 +133,7 @@ def test_redisai_invalid_lib_path(test_dir, monkeypatch):
def test_redisai_valid_lib_path(test_dir, monkeypatch):
- """Valid RAI_PATH and invalid SMARTSIM_DEP_INSTALL_PATH should succeed"""
+ """Valid SMARTSIM_RAI_LIB and invalid SMARTSIM_DEP_INSTALL_PATH should succeed"""
rai_file_path = os.path.join(test_dir, "lib", "mock-redisai.so")
make_file(rai_file_path)
@@ -147,7 +147,7 @@ def test_redisai_valid_lib_path(test_dir, monkeypatch):
def test_redisai_valid_lib_path_null_rai(test_dir, monkeypatch):
- """Missing RAI_PATH and valid SMARTSIM_DEP_INSTALL_PATH should succeed"""
+ """Missing SMARTSIM_RAI_LIB and valid SMARTSIM_DEP_INSTALL_PATH should succeed"""
rai_file_path: t.Optional[str] = None
lib_file_path = os.path.join(test_dir, "lib", "redisai.so")
@@ -166,11 +166,11 @@ def test_redis_conf():
assert Path(config.database_conf).is_file()
assert isinstance(config.database_conf, str)
- os.environ["REDIS_CONF"] = "not/a/path"
+ os.environ["SMARTSIM_REDIS_CONF"] = "not/a/path"
config = Config()
with pytest.raises(SSConfigError):
config.database_conf
- os.environ.pop("REDIS_CONF")
+ os.environ.pop("SMARTSIM_REDIS_CONF")
def test_redis_exe():
@@ -178,11 +178,11 @@ def test_redis_exe():
assert Path(config.database_exe).is_file()
assert isinstance(config.database_exe, str)
- os.environ["REDIS_PATH"] = "not/a/path"
+ os.environ["SMARTSIM_REDIS_SERVER_EXE"] = "not/a/path"
config = Config()
with pytest.raises(SSConfigError):
config.database_exe
- os.environ.pop("REDIS_PATH")
+ os.environ.pop("SMARTSIM_REDIS_SERVER_EXE")
def test_redis_cli():
@@ -190,11 +190,11 @@ def test_redis_cli():
assert Path(config.redisai).is_file()
assert isinstance(config.redisai, str)
- os.environ["REDIS_CLI_PATH"] = "not/a/path"
+ os.environ["SMARTSIM_REDIS_CLI_EXE"] = "not/a/path"
config = Config()
with pytest.raises(SSConfigError):
config.database_cli
- os.environ.pop("REDIS_CLI_PATH")
+ os.environ.pop("SMARTSIM_REDIS_CLI_EXE")
@pytest.mark.parametrize(
diff --git a/tests/_legacy/test_dragon_installer.py b/tests/_legacy/test_dragon_installer.py
index b23a1a7ef0..8ce7404c5f 100644
--- a/tests/_legacy/test_dragon_installer.py
+++ b/tests/_legacy/test_dragon_installer.py
@@ -31,12 +31,17 @@
from collections import namedtuple
import pytest
+from github.GitRelease import GitRelease
from github.GitReleaseAsset import GitReleaseAsset
from github.Requester import Requester
import smartsim
+import smartsim._core._install.utils
import smartsim._core.utils.helpers as helpers
from smartsim._core._cli.scripts.dragon_install import (
+ DEFAULT_DRAGON_REPO,
+ DEFAULT_DRAGON_VERSION,
+ DragonInstallRequest,
cleanup,
create_dotenv,
install_dragon,
@@ -58,14 +63,25 @@
def test_archive(test_dir: str, archive_path: pathlib.Path) -> pathlib.Path:
"""Fixture for returning a simple tarfile to test on"""
num_files = 10
+
+ archive_name = archive_path.name
+ archive_name = archive_name.replace(".tar.gz", "")
+
with tarfile.TarFile.open(archive_path, mode="w:gz") as tar:
- mock_whl = pathlib.Path(test_dir) / "mock.whl"
+ mock_whl = pathlib.Path(test_dir) / archive_name / f"{archive_name}.whl"
+ mock_whl.parent.mkdir(parents=True, exist_ok=True)
mock_whl.touch()
+ tar.add(mock_whl)
+
for i in range(num_files):
- content = pathlib.Path(test_dir) / f"{i:04}.txt"
+ content = pathlib.Path(test_dir) / archive_name / f"{i:04}.txt"
content.write_text(f"i am file {i}\n")
tar.add(content)
+ content.unlink()
+
+ mock_whl.unlink()
+
return archive_path
@@ -118,11 +134,41 @@ def test_assets(monkeypatch: pytest.MonkeyPatch) -> t.Dict[str, GitReleaseAsset]
_git_attr(value=f"http://foo/{archive_name}"),
)
monkeypatch.setattr(asset, "_name", _git_attr(value=archive_name))
+ monkeypatch.setattr(asset, "_id", _git_attr(value=123))
assets.append(asset)
return assets
+@pytest.fixture
+def test_releases(monkeypatch: pytest.MonkeyPatch) -> t.List[GitRelease]:
+ requester = Requester(
+ auth=None,
+ base_url="https://github.com",
+ user_agent="mozilla",
+ per_page=10,
+ verify=False,
+ timeout=1,
+ retry=1,
+ pool_size=1,
+ )
+ headers = {"mock-header": "mock-value"}
+ attributes = {"title": "mock-title"}
+ completed = True
+
+ releases: t.List[GitRelease] = []
+
+ for python_version in ["py3.9", "py3.10", "py3.11"]:
+ for dragon_version in ["dragon-0.8", "dragon-0.9", "dragon-0.10"]:
+ attributes = {
+ "title": f"{python_version}-{dragon_version}-release",
+ "tag_name": f"v{dragon_version}-weekly",
+ }
+ releases.append(GitRelease(requester, headers, attributes, completed))
+
+ return releases
+
+
def test_cleanup_no_op(archive_path: pathlib.Path) -> None:
"""Ensure that the cleanup method doesn't bomb when called with
missing archive path; simulate a failed download"""
@@ -143,17 +189,25 @@ def test_cleanup_archive_exists(test_archive: pathlib.Path) -> None:
assert not test_archive.exists()
-def test_retrieve_cached(
- test_dir: str,
- # archive_path: pathlib.Path,
+@pytest.mark.skip("Deprecated due to builder.py changes")
+def test_retrieve_updated(
test_archive: pathlib.Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
- """Verify that a previously retrieved asset archive is re-used"""
- with tarfile.TarFile.open(test_archive) as tar:
- tar.extractall(test_dir)
+ """Verify that a previously retrieved asset archive is not re-used if a new
+ version is found"""
- ts1 = test_archive.parent.stat().st_ctime
+ old_asset_id = 100
+ asset_id = 123
+
+ def mock__retrieve_archive(source_, destination_) -> None:
+ mock_extraction_dir = pathlib.Path(destination_)
+ with tarfile.TarFile.open(test_archive) as tar:
+ tar.extractall(mock_extraction_dir)
+
+ # we'll use the mock extract to create the files that would normally be downloaded
+ expected_output_dir = test_archive.parent / str(asset_id)
+ old_output_dir = test_archive.parent / str(old_asset_id)
requester = Requester(
auth=None,
@@ -174,14 +228,22 @@ def test_retrieve_cached(
# ensure mocked asset has values that we use...
monkeypatch.setattr(asset, "_browser_download_url", _git_attr(value="http://foo"))
monkeypatch.setattr(asset, "_name", _git_attr(value=mock_archive_name))
+ monkeypatch.setattr(asset, "_id", _git_attr(value=asset_id))
+ monkeypatch.setattr(
+ smartsim._core._install.utils,
+ "retrieve",
+ lambda s_, d_: mock__retrieve_archive(s_, expected_output_dir),
+ ) # mock the retrieval of the updated archive
+
+ # tell it to retrieve. it should return the path to the new download, not the old one
+ request = DragonInstallRequest(test_archive.parent)
+ asset_path = retrieve_asset(request, asset)
- asset_path = retrieve_asset(test_archive.parent, asset)
- ts2 = asset_path.stat().st_ctime
+ # sanity check we don't have the same paths
+ assert old_output_dir != expected_output_dir
- assert (
- asset_path == test_archive.parent
- ) # show that the expected path matches the output path
- assert ts1 == ts2 # show that the file wasn't changed...
+ # verify the "cached" copy wasn't used
+ assert asset_path == expected_output_dir
@pytest.mark.parametrize(
@@ -214,11 +276,13 @@ def test_retrieve_cached(
)
def test_retrieve_asset_info(
test_assets: t.Collection[GitReleaseAsset],
+ test_releases: t.Collection[GitRelease],
monkeypatch: pytest.MonkeyPatch,
dragon_pin: str,
pyv: str,
is_found: bool,
is_crayex: bool,
+ test_dir: str,
) -> None:
"""Verify that an information is retrieved correctly based on the python
version, platform (e.g. CrayEX, !CrayEx), and target dragon pin"""
@@ -234,20 +298,23 @@ def test_retrieve_asset_info(
"is_crayex_platform",
lambda: is_crayex,
)
+ # avoid hitting github API
ctx.setattr(
smartsim._core._cli.scripts.dragon_install,
- "dragon_pin",
- lambda: dragon_pin,
+ "_get_all_releases",
+ lambda x: test_releases,
)
# avoid hitting github API
ctx.setattr(
smartsim._core._cli.scripts.dragon_install,
"_get_release_assets",
- lambda: test_assets,
+ lambda x: test_assets,
)
+ request = DragonInstallRequest(test_dir, version=dragon_pin)
+
if is_found:
- chosen_asset = retrieve_asset_info()
+ chosen_asset = retrieve_asset_info(request)
assert chosen_asset
assert pyv in chosen_asset.name
@@ -259,7 +326,7 @@ def test_retrieve_asset_info(
assert "crayex" not in chosen_asset.name.lower()
else:
with pytest.raises(SmartSimCLIActionCancelled):
- retrieve_asset_info()
+ retrieve_asset_info(request)
def test_check_for_utility_missing(test_dir: str) -> None:
@@ -357,23 +424,56 @@ def mock_util_check(util: str) -> bool:
assert is_cray == platform_result
-def test_install_package_no_wheel(extraction_dir: pathlib.Path):
+def test_install_package_no_wheel(test_dir: str, extraction_dir: pathlib.Path):
"""Verify that a missing wheel does not blow up and has a failure retcode"""
exp_path = extraction_dir
+ request = DragonInstallRequest(test_dir)
- result = install_package(exp_path)
+ result = install_package(request, exp_path)
assert result != 0
def test_install_macos(monkeypatch: pytest.MonkeyPatch, extraction_dir: pathlib.Path):
- """Verify that installation exits cleanly if installing on unsupported platform"""
+ """Verify that installation exits cleanly if installing on unsupported platform."""
with monkeypatch.context() as ctx:
ctx.setattr(sys, "platform", "darwin")
- result = install_dragon(extraction_dir)
+ request = DragonInstallRequest(extraction_dir)
+
+ result = install_dragon(request)
assert result == 1
+@pytest.mark.parametrize(
+ "version, exp_result",
+ [
+ pytest.param("0.9", 2, id="0.9 DNE In Public Repo"),
+ pytest.param("0.91", 2, id="0.91 DNE In Public Repo"),
+ pytest.param("0.10", 0, id="0.10 Exists In Public Repo"),
+ pytest.param("0.19", 2, id="0.19 DNE In Public Repo"),
+ ],
+)
+def test_install_specify_asset_version(
+ monkeypatch: pytest.MonkeyPatch,
+ extraction_dir: pathlib.Path,
+ version: str,
+ exp_result: int,
+):
+ """Verify that installation completes as expected when fed a variety of
+ version numbers that can or cannot be found on release assets of the
+ public dragon repository.
+
+ :param extraction_dir: file system path where the dragon package should
+ be downloaded and extracted
+ :param version: Dragon version number to attempt to install
+ :param exp_result: Expected return code from the call to `install_dragon`
+ """
+ request = DragonInstallRequest(extraction_dir, version=version)
+
+ result = install_dragon(request)
+ assert result == exp_result
+
+
def test_create_dotenv(monkeypatch: pytest.MonkeyPatch, test_dir: str):
"""Verify that attempting to create a .env file without any existing
file or container directory works"""
@@ -387,7 +487,7 @@ def test_create_dotenv(monkeypatch: pytest.MonkeyPatch, test_dir: str):
# ensure no .env exists before trying to create it.
assert not exp_env_path.exists()
- create_dotenv(mock_dragon_root)
+ create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
# ensure the .env is created as side-effect of create_dotenv
assert exp_env_path.exists()
@@ -409,7 +509,7 @@ def test_create_dotenv_existing_dir(monkeypatch: pytest.MonkeyPatch, test_dir: s
# ensure no .env exists before trying to create it.
assert not exp_env_path.exists()
- create_dotenv(mock_dragon_root)
+ create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
# ensure the .env is created as side-effect of create_dotenv
assert exp_env_path.exists()
@@ -434,17 +534,25 @@ def test_create_dotenv_existing_dotenv(monkeypatch: pytest.MonkeyPatch, test_dir
# ensure .env exists so we can update it
assert exp_env_path.exists()
- create_dotenv(mock_dragon_root)
+ create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
# ensure the .env is created as side-effect of create_dotenv
assert exp_env_path.exists()
# ensure file was overwritten and env vars are not duplicated
dotenv_content = exp_env_path.read_text(encoding="utf-8")
- split_content = dotenv_content.split(var_name)
-
- # split to confirm env var only appars once
- assert len(split_content) == 2
+ lines = [
+ line for line in dotenv_content.split("\n") if line and "#" not in line
+ ]
+ for line in lines:
+ if line.startswith(var_name):
+ # make sure the var isn't defined recursively
+ # DRAGON_BASE_DIR=$DRAGON_BASE_DIR
+ assert var_name not in line[len(var_name) + 1 :]
+ else:
+ # make sure any values reference the original base dir var
+ if var_name in line:
+ assert f"${var_name}" in line
def test_create_dotenv_format(monkeypatch: pytest.MonkeyPatch, test_dir: str):
@@ -456,13 +564,13 @@ def test_create_dotenv_format(monkeypatch: pytest.MonkeyPatch, test_dir: str):
with monkeypatch.context() as ctx:
ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path)
- create_dotenv(mock_dragon_root)
+ create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
# ensure the .env is created as side-effect of create_dotenv
content = exp_env_path.read_text(encoding="utf-8")
# ensure we have values written, but ignore empty lines
- lines = [line for line in content.split("\n") if line]
+ lines = [line for line in content.split("\n") if line and "#" not in line]
assert lines
# ensure each line is formatted as key=value
diff --git a/tests/_legacy/test_dragon_launcher.py b/tests/_legacy/test_dragon_launcher.py
index 77f094b7d7..c4f241b24b 100644
--- a/tests/_legacy/test_dragon_launcher.py
+++ b/tests/_legacy/test_dragon_launcher.py
@@ -37,7 +37,10 @@
import zmq
import smartsim._core.config
-from smartsim._core._cli.scripts.dragon_install import create_dotenv
+from smartsim._core._cli.scripts.dragon_install import (
+ DEFAULT_DRAGON_VERSION,
+ create_dotenv,
+)
from smartsim._core.config.config import get_config
from smartsim._core.launcher.dragon.dragon_launcher import (
DragonConnector,
@@ -494,7 +497,7 @@ def test_load_env_env_file_created(monkeypatch: pytest.MonkeyPatch, test_dir: st
with monkeypatch.context() as ctx:
ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path)
- create_dotenv(mock_dragon_root)
+ create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
dragon_conf = smartsim._core.config.CONFIG.dragon_dotenv
# verify config does exist
@@ -507,7 +510,26 @@ def test_load_env_env_file_created(monkeypatch: pytest.MonkeyPatch, test_dir: st
assert loaded_env
# confirm .env was parsed as expected by inspecting a key
+ assert "DRAGON_BASE_DIR" in loaded_env
+ base_dir = loaded_env["DRAGON_BASE_DIR"]
+
assert "DRAGON_ROOT_DIR" in loaded_env
+ assert loaded_env["DRAGON_ROOT_DIR"] == base_dir
+
+ assert "DRAGON_INCLUDE_DIR" in loaded_env
+ assert loaded_env["DRAGON_INCLUDE_DIR"] == f"{base_dir}/include"
+
+ assert "DRAGON_LIB_DIR" in loaded_env
+ assert loaded_env["DRAGON_LIB_DIR"] == f"{base_dir}/lib"
+
+ assert "DRAGON_VERSION" in loaded_env
+ assert loaded_env["DRAGON_VERSION"] == DEFAULT_DRAGON_VERSION
+
+ assert "PATH" in loaded_env
+ assert loaded_env["PATH"] == f"{base_dir}/bin"
+
+ assert "LD_LIBRARY_PATH" in loaded_env
+ assert loaded_env["LD_LIBRARY_PATH"] == f"{base_dir}/lib"
def test_load_env_cached_env(monkeypatch: pytest.MonkeyPatch, test_dir: str):
@@ -517,7 +539,7 @@ def test_load_env_cached_env(monkeypatch: pytest.MonkeyPatch, test_dir: str):
with monkeypatch.context() as ctx:
ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path)
- create_dotenv(mock_dragon_root)
+ create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
# load config w/launcher
connector = DragonConnector()
@@ -541,7 +563,7 @@ def test_merge_env(monkeypatch: pytest.MonkeyPatch, test_dir: str):
with monkeypatch.context() as ctx:
ctx.setattr(smartsim._core.config.CONFIG, "conf_dir", test_path)
- create_dotenv(mock_dragon_root)
+ create_dotenv(mock_dragon_root, DEFAULT_DRAGON_VERSION)
# load config w/launcher
connector = DragonConnector()
@@ -593,11 +615,14 @@ def test_run_step_fail(test_dir: str) -> None:
step0 = DragonStep("step0", test_dir, rs)
step0.meta["status_dir"] = status_dir
- mock_connector = MagicMock() # DragonConnector()
+ mock_connector = MagicMock(spec=DragonConnector)
mock_connector.is_connected = True
mock_connector.send_request = MagicMock(
return_value=DragonRunResponse(step_id=step0.name, error_message="mock fail!")
)
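+ # stub merge_persisted_env with arbitrary env values the launcher passes to the step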
+ mock_connector.merge_persisted_env = MagicMock(
+ return_value={"FOO": "bar", "BAZ": "boop"}
+ )
launcher = DragonLauncher()
launcher._connector = mock_connector
@@ -676,7 +701,7 @@ def test_run_step_success(test_dir: str) -> None:
step0 = DragonStep("step0", test_dir, rs)
step0.meta["status_dir"] = status_dir
- mock_connector = MagicMock() # DragonConnector()
+ mock_connector = MagicMock(spec=DragonConnector)
mock_connector.is_connected = True
mock_connector.send_request = MagicMock(
return_value=DragonRunResponse(step_id=step0.name)
@@ -684,6 +709,9 @@ def test_run_step_success(test_dir: str) -> None:
launcher = DragonLauncher()
launcher._connector = mock_connector
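+ # stub merge_persisted_env with arbitrary env values the launcher passes to the step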
+ mock_connector.merge_persisted_env = MagicMock(
+ return_value={"FOO": "bar", "BAZ": "boop"}
+ )
result = launcher.run(step0)
diff --git a/tests/_legacy/test_dragon_run_policy.py b/tests/_legacy/test_dragon_run_policy.py
index 5da84bf305..14219f9a32 100644
--- a/tests/_legacy/test_dragon_run_policy.py
+++ b/tests/_legacy/test_dragon_run_policy.py
@@ -114,9 +114,6 @@ def test_create_run_policy_non_run_request(dragon_request: DragonRequest) -> Non
policy = DragonBackend.create_run_policy(dragon_request, "localhost")
assert policy is not None, "Default policy was not returned"
- assert (
- policy.device == Policy.Device.DEFAULT
- ), "Default device was not Device.DEFAULT"
assert policy.cpu_affinity == [], "Default cpu affinity was not empty"
assert policy.gpu_affinity == [], "Default gpu affinity was not empty"
@@ -140,10 +137,8 @@ def test_create_run_policy_run_request_no_run_policy() -> None:
policy = DragonBackend.create_run_policy(run_req, "localhost")
- assert policy.device == Policy.Device.DEFAULT
assert set(policy.cpu_affinity) == set()
assert policy.gpu_affinity == []
- assert policy.affinity == Policy.Affinity.DEFAULT
@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
@@ -167,7 +162,6 @@ def test_create_run_policy_run_request_default_run_policy() -> None:
assert set(policy.cpu_affinity) == set()
assert set(policy.gpu_affinity) == set()
- assert policy.affinity == Policy.Affinity.DEFAULT
@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
@@ -192,7 +186,6 @@ def test_create_run_policy_run_request_cpu_affinity_no_device() -> None:
assert set(policy.cpu_affinity) == affinity
assert policy.gpu_affinity == []
- assert policy.affinity == Policy.Affinity.SPECIFIC
@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
@@ -216,7 +209,6 @@ def test_create_run_policy_run_request_cpu_affinity() -> None:
assert set(policy.cpu_affinity) == affinity
assert policy.gpu_affinity == []
- assert policy.affinity == Policy.Affinity.SPECIFIC
@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
@@ -240,7 +232,6 @@ def test_create_run_policy_run_request_gpu_affinity() -> None:
assert policy.cpu_affinity == []
assert set(policy.gpu_affinity) == set(affinity)
- assert policy.affinity == Policy.Affinity.SPECIFIC
@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
diff --git a/tests/_legacy/test_dragon_run_request.py b/tests/_legacy/test_dragon_run_request.py
index f5fdc73a06..a1c1e495f3 100644
--- a/tests/_legacy/test_dragon_run_request.py
+++ b/tests/_legacy/test_dragon_run_request.py
@@ -30,63 +30,23 @@
import time
from unittest.mock import MagicMock
+import pydantic.error_wrappers
import pytest
-from pydantic import ValidationError
# The tests in this file belong to the group_b group
pytestmark = pytest.mark.group_b
-
-try:
- import dragon
-
- dragon_loaded = True
-except:
- dragon_loaded = False
+dragon = pytest.importorskip("dragon")
from smartsim._core.config import CONFIG
+from smartsim._core.launcher.dragon.dragon_backend import (
+ DragonBackend,
+ ProcessGroupInfo,
+)
+from smartsim._core.launcher.dragon.pqueue import NodePrioritizer
from smartsim._core.schemas.dragon_requests import *
from smartsim._core.schemas.dragon_responses import *
-from smartsim._core.utils.helpers import create_short_id_str
from smartsim.status import TERMINAL_STATUSES, InvalidJobStatus, JobStatus
-if t.TYPE_CHECKING:
- from smartsim._core.launcher.dragon.dragon_backend import (
- DragonBackend,
- ProcessGroupInfo,
- )
-
-
-class NodeMock(MagicMock):
- def __init__(
- self, name: t.Optional[str] = None, num_gpus: int = 2, num_cpus: int = 8
- ) -> None:
- super().__init__()
- self._mock_id = name
- NodeMock._num_gpus = num_gpus
- NodeMock._num_cpus = num_cpus
-
- @property
- def hostname(self) -> str:
- if self._mock_id:
- return self._mock_id
- return create_short_id_str()
-
- @property
- def num_cpus(self) -> str:
- return NodeMock._num_cpus
-
- @property
- def num_gpus(self) -> str:
- return NodeMock._num_gpus
-
- def _set_id(self, value: str) -> None:
- self._mock_id = value
-
- def gpus(self, parent: t.Any = None) -> t.List[str]:
- if self._num_gpus:
- return [f"{self.hostname}-gpu{i}" for i in range(NodeMock._num_gpus)]
- return []
-
class GroupStateMock(MagicMock):
def Running(self) -> MagicMock:
@@ -102,59 +62,57 @@ class ProcessGroupMock(MagicMock):
puids = [121, 122]
-def node_mock() -> NodeMock:
- return NodeMock()
-
-
def get_mock_backend(
- monkeypatch: pytest.MonkeyPatch, num_gpus: int = 2
+ monkeypatch: pytest.MonkeyPatch, num_cpus: int, num_gpus: int
) -> "DragonBackend":
-
+ # create all the necessary namespaces as raw magic mocks
+ monkeypatch.setitem(sys.modules, "dragon.data.ddict.ddict", MagicMock())
+ monkeypatch.setitem(sys.modules, "dragon.native.machine", MagicMock())
+ monkeypatch.setitem(sys.modules, "dragon.native.group_state", MagicMock())
+ monkeypatch.setitem(sys.modules, "dragon.native.process_group", MagicMock())
+ monkeypatch.setitem(sys.modules, "dragon.native.process", MagicMock())
+ monkeypatch.setitem(sys.modules, "dragon.infrastructure.connection", MagicMock())
+ monkeypatch.setitem(sys.modules, "dragon.infrastructure.policy", MagicMock())
+ monkeypatch.setitem(sys.modules, "dragon.infrastructure.process_desc", MagicMock())
+ monkeypatch.setitem(sys.modules, "dragon.data.ddict.ddict", MagicMock())
+
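+ # mock cluster topology consumed by the patched dragon System/Node classes below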
+ node_list = ["node1", "node2", "node3"]
+ system_mock = MagicMock(return_value=MagicMock(nodes=node_list))
+ node_mock = lambda x: MagicMock(hostname=x, num_cpus=num_cpus, num_gpus=num_gpus)
+ process_group_mock = MagicMock(return_value=ProcessGroupMock())
process_mock = MagicMock(returncode=0)
- process_group_mock = MagicMock(**{"Process.return_value": ProcessGroupMock()})
- process_module_mock = MagicMock()
- process_module_mock.Process = process_mock
- node_mock = NodeMock(num_gpus=num_gpus)
- system_mock = MagicMock(nodes=["node1", "node2", "node3"])
+ policy_mock = MagicMock(return_value=MagicMock())
+ group_state_mock = GroupStateMock()
+
+ # customize members that must perform specific actions within the namespaces
monkeypatch.setitem(
sys.modules,
"dragon",
MagicMock(
**{
- "native.machine.Node.return_value": node_mock,
- "native.machine.System.return_value": system_mock,
- "native.group_state": GroupStateMock(),
- "native.process_group.ProcessGroup.return_value": ProcessGroupMock(),
+ "native.machine.Node": node_mock,
+ "native.machine.System": system_mock,
+ "native.group_state": group_state_mock,
+ "native.process_group.ProcessGroup": process_group_mock,
+ "native.process_group.Process": process_mock,
+ "native.process.Process": process_mock,
+ "infrastructure.policy.Policy": policy_mock,
}
),
)
- monkeypatch.setitem(
- sys.modules,
- "dragon.infrastructure.connection",
- MagicMock(),
- )
- monkeypatch.setitem(
- sys.modules,
- "dragon.infrastructure.policy",
- MagicMock(**{"Policy.return_value": MagicMock()}),
- )
- monkeypatch.setitem(sys.modules, "dragon.native.process", process_module_mock)
- monkeypatch.setitem(sys.modules, "dragon.native.process_group", process_group_mock)
-
- monkeypatch.setitem(sys.modules, "dragon.native.group_state", GroupStateMock())
- monkeypatch.setitem(
- sys.modules,
- "dragon.native.machine",
- MagicMock(
- **{"System.return_value": system_mock, "Node.return_value": node_mock}
- ),
- )
- from smartsim._core.launcher.dragon.dragon_backend import DragonBackend
dragon_backend = DragonBackend(pid=99999)
- monkeypatch.setattr(
- dragon_backend, "_free_hosts", collections.deque(dragon_backend._hosts)
+
+ # NOTE: we're manually updating these values due to issue w/mocking namespaces
+ dragon_backend._prioritizer = NodePrioritizer(
+ [
+ MagicMock(num_cpus=num_cpus, num_gpus=num_gpus, hostname=node)
+ for node in node_list
+ ],
+ dragon_backend._queue_lock,
)
+ dragon_backend._cpus = [num_cpus] * len(node_list)
+ dragon_backend._gpus = [num_gpus] * len(node_list)
return dragon_backend
@@ -212,16 +170,14 @@ def set_mock_group_infos(
}
monkeypatch.setattr(dragon_backend, "_group_infos", group_infos)
- monkeypatch.setattr(dragon_backend, "_free_hosts", collections.deque(hosts[1:3]))
- monkeypatch.setattr(dragon_backend, "_allocated_hosts", {hosts[0]: "abc123-1"})
+ monkeypatch.setattr(dragon_backend, "_allocated_hosts", {hosts[0]: {"abc123-1"}})
monkeypatch.setattr(dragon_backend, "_running_steps", ["abc123-1"])
return group_infos
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_handshake_request(monkeypatch: pytest.MonkeyPatch) -> None:
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
handshake_req = DragonHandshakeRequest()
handshake_resp = dragon_backend.process_request(handshake_req)
@@ -230,9 +186,8 @@ def test_handshake_request(monkeypatch: pytest.MonkeyPatch) -> None:
assert handshake_resp.dragon_pid == 99999
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None:
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
run_req = DragonRunRequest(
exe="sleep",
exe_args=["5"],
@@ -259,9 +214,9 @@ def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None:
assert dragon_backend._running_steps == [step_id]
assert len(dragon_backend._queued_steps) == 0
- assert len(dragon_backend._free_hosts) == 1
- assert dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id
- assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id
+ assert len(dragon_backend.free_hosts) == 1
+ assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]]
+ assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]]
monkeypatch.setattr(
dragon_backend._group_infos[step_id].process_group, "status", "Running"
@@ -271,9 +226,9 @@ def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None:
assert dragon_backend._running_steps == [step_id]
assert len(dragon_backend._queued_steps) == 0
- assert len(dragon_backend._free_hosts) == 1
- assert dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id
- assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id
+ assert len(dragon_backend.free_hosts) == 1
+ assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]]
+ assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]]
dragon_backend._group_infos[step_id].status = JobStatus.CANCELLED
@@ -281,9 +236,8 @@ def test_run_request(monkeypatch: pytest.MonkeyPatch) -> None:
assert not dragon_backend._running_steps
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_deny_run_request(monkeypatch: pytest.MonkeyPatch) -> None:
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
dragon_backend._shutdown_requested = True
@@ -309,7 +263,7 @@ def test_deny_run_request(monkeypatch: pytest.MonkeyPatch) -> None:
def test_run_request_with_empty_policy(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify that a policy is applied to a run request"""
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
run_req = DragonRunRequest(
exe="sleep",
exe_args=["5"],
@@ -325,10 +279,9 @@ def test_run_request_with_empty_policy(monkeypatch: pytest.MonkeyPatch) -> None:
assert run_req.policy is None
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify that a policy is applied to a run request"""
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
run_req = DragonRunRequest(
exe="sleep",
exe_args=["5"],
@@ -356,9 +309,9 @@ def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None:
assert dragon_backend._running_steps == [step_id]
assert len(dragon_backend._queued_steps) == 0
- assert len(dragon_backend._free_hosts) == 1
- assert dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id
- assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id
+ assert len(dragon_backend._prioritizer.unassigned()) == 1
+ assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]]
+ assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]]
monkeypatch.setattr(
dragon_backend._group_infos[step_id].process_group, "status", "Running"
@@ -368,9 +321,9 @@ def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None:
assert dragon_backend._running_steps == [step_id]
assert len(dragon_backend._queued_steps) == 0
- assert len(dragon_backend._free_hosts) == 1
- assert dragon_backend._allocated_hosts[dragon_backend.hosts[0]] == step_id
- assert dragon_backend._allocated_hosts[dragon_backend.hosts[1]] == step_id
+ assert len(dragon_backend._prioritizer.unassigned()) == 1
+ assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[0]]
+ assert step_id in dragon_backend._allocated_hosts[dragon_backend.hosts[1]]
dragon_backend._group_infos[step_id].status = JobStatus.CANCELLED
@@ -378,9 +331,8 @@ def test_run_request_with_policy(monkeypatch: pytest.MonkeyPatch) -> None:
assert not dragon_backend._running_steps
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_udpate_status_request(monkeypatch: pytest.MonkeyPatch) -> None:
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
group_infos = set_mock_group_infos(monkeypatch, dragon_backend)
@@ -395,9 +347,8 @@ def test_udpate_status_request(monkeypatch: pytest.MonkeyPatch) -> None:
}
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_stop_request(monkeypatch: pytest.MonkeyPatch) -> None:
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
group_infos = set_mock_group_infos(monkeypatch, dragon_backend)
running_steps = [
@@ -421,10 +372,9 @@ def test_stop_request(monkeypatch: pytest.MonkeyPatch) -> None:
assert dragon_backend._group_infos[step_id_to_stop].status == JobStatus.CANCELLED
assert len(dragon_backend._allocated_hosts) == 0
- assert len(dragon_backend._free_hosts) == 3
+ assert len(dragon_backend._prioritizer.unassigned()) == 3
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
@pytest.mark.parametrize(
"immediate, kill_jobs, frontend_shutdown",
[
@@ -443,7 +393,7 @@ def test_shutdown_request(
frontend_shutdown: bool,
) -> None:
monkeypatch.setenv("SMARTSIM_FLAG_TELEMETRY", "0")
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
monkeypatch.setattr(dragon_backend, "_cooldown_period", 1)
set_mock_group_infos(monkeypatch, dragon_backend)
@@ -483,11 +433,10 @@ def test_shutdown_request(
assert dragon_backend._has_cooled_down == kill_jobs
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
@pytest.mark.parametrize("telemetry_flag", ["0", "1"])
def test_cooldown_is_set(monkeypatch: pytest.MonkeyPatch, telemetry_flag: str) -> None:
monkeypatch.setenv("SMARTSIM_FLAG_TELEMETRY", telemetry_flag)
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
expected_cooldown = (
2 * CONFIG.telemetry_frequency + 5 if int(telemetry_flag) > 0 else 5
@@ -499,19 +448,17 @@ def test_cooldown_is_set(monkeypatch: pytest.MonkeyPatch, telemetry_flag: str) -
assert dragon_backend.cooldown_period == expected_cooldown
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_heartbeat_and_time(monkeypatch: pytest.MonkeyPatch) -> None:
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
first_heartbeat = dragon_backend.last_heartbeat
assert dragon_backend.current_time > first_heartbeat
dragon_backend._heartbeat()
assert dragon_backend.last_heartbeat > first_heartbeat
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
@pytest.mark.parametrize("num_nodes", [1, 3, 100])
def test_can_honor(monkeypatch: pytest.MonkeyPatch, num_nodes: int) -> None:
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
run_req = DragonRunRequest(
exe="sleep",
exe_args=["5"],
@@ -524,18 +471,42 @@ def test_can_honor(monkeypatch: pytest.MonkeyPatch, num_nodes: int) -> None:
pmi_enabled=False,
)
- assert dragon_backend._can_honor(run_req)[0] == (
- num_nodes <= len(dragon_backend._hosts)
- )
+ can_honor, error_msg = dragon_backend._can_honor(run_req)
+
+ nodes_in_range = num_nodes <= len(dragon_backend._hosts)
+ assert can_honor == nodes_in_range
+ assert (error_msg is None) if nodes_in_range else (error_msg is not None)
+
+
+@pytest.mark.parametrize("num_nodes", [-10, -1, 0])
+def test_can_honor_invalid_num_nodes(
+ monkeypatch: pytest.MonkeyPatch, num_nodes: int
+) -> None:
+ """Verify that requests for invalid numbers of nodes (negative, zero) are rejected"""
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
+
+ with pytest.raises(pydantic.error_wrappers.ValidationError) as ex:
+ DragonRunRequest(
+ exe="sleep",
+ exe_args=["5"],
+ path="/a/fake/path",
+ nodes=num_nodes,
+ tasks=1,
+ tasks_per_node=1,
+ env={},
+ current_env={},
+ pmi_enabled=False,
+ )
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
@pytest.mark.parametrize("affinity", [[0], [0, 1], list(range(8))])
def test_can_honor_cpu_affinity(
monkeypatch: pytest.MonkeyPatch, affinity: t.List[int]
) -> None:
"""Verify that valid CPU affinities are accepted"""
- dragon_backend = get_mock_backend(monkeypatch)
+ num_cpus, num_gpus = 8, 0
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=num_cpus, num_gpus=num_gpus)
+
run_req = DragonRunRequest(
exe="sleep",
exe_args=["5"],
@@ -552,11 +523,10 @@ def test_can_honor_cpu_affinity(
assert dragon_backend._can_honor(run_req)[0]
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_can_honor_cpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify that invalid CPU affinities are NOT accepted
NOTE: negative values are captured by the Pydantic schema"""
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
run_req = DragonRunRequest(
exe="sleep",
exe_args=["5"],
@@ -573,13 +543,15 @@ def test_can_honor_cpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) ->
assert not dragon_backend._can_honor(run_req)[0]
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
@pytest.mark.parametrize("affinity", [[0], [0, 1]])
def test_can_honor_gpu_affinity(
monkeypatch: pytest.MonkeyPatch, affinity: t.List[int]
) -> None:
"""Verify that valid GPU affinities are accepted"""
- dragon_backend = get_mock_backend(monkeypatch)
+
+ num_cpus, num_gpus = 8, 2
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=num_cpus, num_gpus=num_gpus)
+
run_req = DragonRunRequest(
exe="sleep",
exe_args=["5"],
@@ -596,11 +568,10 @@ def test_can_honor_gpu_affinity(
assert dragon_backend._can_honor(run_req)[0]
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_can_honor_gpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify that invalid GPU affinities are NOT accepted
NOTE: negative values are captured by the Pydantic schema"""
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
run_req = DragonRunRequest(
exe="sleep",
exe_args=["5"],
@@ -617,46 +588,46 @@ def test_can_honor_gpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) ->
assert not dragon_backend._can_honor(run_req)[0]
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_can_honor_gpu_device_not_available(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify that a request for a GPU if none exists is not accepted"""
# create a mock node class that always reports no GPUs available
- dragon_backend = get_mock_backend(monkeypatch, num_gpus=0)
-
- run_req = DragonRunRequest(
- exe="sleep",
- exe_args=["5"],
- path="/a/fake/path",
- nodes=2,
- tasks=1,
- tasks_per_node=1,
- env={},
- current_env={},
- pmi_enabled=False,
- # specify GPU device w/no affinity
- policy=DragonRunPolicy(gpu_affinity=[0]),
- )
-
- assert not dragon_backend._can_honor(run_req)[0]
+ with monkeypatch.context() as ctx:
+ dragon_backend = get_mock_backend(ctx, num_cpus=8, num_gpus=0)
+
+ run_req = DragonRunRequest(
+ exe="sleep",
+ exe_args=["5"],
+ path="/a/fake/path",
+ nodes=2,
+ tasks=1,
+ tasks_per_node=1,
+ env={},
+ current_env={},
+ pmi_enabled=False,
+ # specify GPU device w/no affinity
+ policy=DragonRunPolicy(gpu_affinity=[0]),
+ )
+ can_honor, _ = dragon_backend._can_honor(run_req)
+ assert not can_honor
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_get_id(monkeypatch: pytest.MonkeyPatch) -> None:
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
step_id = next(dragon_backend._step_ids)
assert step_id.endswith("0")
assert step_id != next(dragon_backend._step_ids)
-@pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems")
def test_view(monkeypatch: pytest.MonkeyPatch) -> None:
- dragon_backend = get_mock_backend(monkeypatch)
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
set_mock_group_infos(monkeypatch, dragon_backend)
hosts = dragon_backend.hosts
+ dragon_backend._prioritizer.increment(hosts[0])
- expected_message = textwrap.dedent(f"""\
+ expected_msg = textwrap.dedent(
+ f"""\
Dragon server backend update
| Host | Status |
|--------|----------|
@@ -664,15 +635,120 @@ def test_view(monkeypatch: pytest.MonkeyPatch) -> None:
| {hosts[1]} | Free |
| {hosts[2]} | Free |
| Step | Status | Hosts | Return codes | Num procs |
- |----------|--------------|-------------|----------------|-------------|
+ |----------|--------------|-----------------|----------------|-------------|
| abc123-1 | Running | {hosts[0]} | | 1 |
| del999-2 | Cancelled | {hosts[1]} | -9 | 1 |
| c101vz-3 | Completed | {hosts[1]},{hosts[2]} | 0 | 2 |
| 0ghjk1-4 | Failed | {hosts[2]} | -1 | 1 |
- | ljace0-5 | NeverStarted | | | 0 |""")
+ | ljace0-5 | NeverStarted | | | 0 |"""
+ )
# get rid of white space to make the comparison easier
actual_msg = dragon_backend.status_message.replace(" ", "")
- expected_message = expected_message.replace(" ", "")
+ expected_msg = expected_msg.replace(" ", "")
+
+ # ignore dashes in separators (hostname changes may cause column expansion)
+ while actual_msg.find("--") > -1:
+ actual_msg = actual_msg.replace("--", "-")
+ while expected_msg.find("--") > -1:
+ expected_msg = expected_msg.replace("--", "-")
+
+ assert actual_msg == expected_msg
+
+
+def test_can_honor_hosts_unavailable_hosts(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Verify that requesting nodes with invalid names causes number of available
+ nodes check to fail due to valid # of named nodes being under num_nodes"""
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
+
+ # let's supply 2 invalid and 1 valid hostname
+ actual_hosts = list(dragon_backend._hosts)
+ actual_hosts[0] = f"x{actual_hosts[0]}"
+ actual_hosts[1] = f"x{actual_hosts[1]}"
+
+ host_list = ",".join(actual_hosts)
+
+ run_req = DragonRunRequest(
+ exe="sleep",
+ exe_args=["5"],
+ path="/a/fake/path",
+ nodes=2, # <----- requesting 2 of 3 available nodes
+ hostlist=host_list, # <--- only one valid name available
+ tasks=1,
+ tasks_per_node=1,
+ env={},
+ current_env={},
+ pmi_enabled=False,
+ policy=DragonRunPolicy(),
+ )
+
+ can_honor, error_msg = dragon_backend._can_honor(run_req)
+
+ # confirm the failure is indicated
+ assert not can_honor
+ # confirm failure message indicates number of nodes requested as cause
+ assert "named hosts" in error_msg
+
+
+def test_can_honor_hosts_unavailable_hosts_ok(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Verify that requesting nodes with invalid names causes number of available
+ nodes check to be reduced but still passes if enough valid named nodes are passed"""
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
+
+ # let's supply 2 valid and 1 invalid hostname
+ actual_hosts = list(dragon_backend._hosts)
+ actual_hosts[0] = f"x{actual_hosts[0]}"
+
+ host_list = ",".join(actual_hosts)
+
+ run_req = DragonRunRequest(
+ exe="sleep",
+ exe_args=["5"],
+ path="/a/fake/path",
+ nodes=2, # <----- requesting 2 of 3 available nodes
+ hostlist=host_list, # <--- two valid names are available
+ tasks=1,
+ tasks_per_node=1,
+ env={},
+ current_env={},
+ pmi_enabled=False,
+ policy=DragonRunPolicy(),
+ )
+
+ can_honor, error_msg = dragon_backend._can_honor(run_req)
+
+ # confirm the request is honored
+ assert can_honor, error_msg
+ # confirm no error message is returned on success
+ assert error_msg is None, error_msg
+
+
+def test_can_honor_hosts_1_hosts_requested(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Verify that requesting nodes with invalid names causes number of available
+ nodes check to be reduced but still passes if enough valid named nodes are passed"""
+ dragon_backend = get_mock_backend(monkeypatch, num_cpus=8, num_gpus=0)
+
+ # let's supply 2 valid and 1 invalid hostname
+ actual_hosts = list(dragon_backend._hosts)
+ actual_hosts[0] = f"x{actual_hosts[0]}"
+
+ host_list = ",".join(actual_hosts)
+
+ run_req = DragonRunRequest(
+ exe="sleep",
+ exe_args=["5"],
+ path="/a/fake/path",
+ nodes=1, # <----- requesting 1 of 3 available nodes
+ hostlist=host_list, # <--- two valid names are available
+ tasks=1,
+ tasks_per_node=1,
+ env={},
+ current_env={},
+ pmi_enabled=False,
+ policy=DragonRunPolicy(),
+ )
+
+ can_honor, error_msg = dragon_backend._can_honor(run_req)
- assert actual_msg == expected_message
+ # confirm the request is honored
+ assert can_honor, error_msg
diff --git a/tests/_legacy/test_dragon_run_request_nowlm.py b/tests/_legacy/test_dragon_run_request_nowlm.py
index 2b5526c69e..98f5b706da 100644
--- a/tests/_legacy/test_dragon_run_request_nowlm.py
+++ b/tests/_legacy/test_dragon_run_request_nowlm.py
@@ -101,5 +101,5 @@ def test_run_request_with_negative_affinity(
),
)
- assert f"{device}_affinity" in str(ex.value.args[0])
- assert "NumberNotGeError" in str(ex.value.args[0])
+ assert f"{device}_affinity" in str(ex.value)
+ assert "greater than or equal to 0" in str(ex.value)
diff --git a/tests/_legacy/test_dragon_step.py b/tests/_legacy/test_dragon_step.py
index 17279a33c6..3dbdf114ea 100644
--- a/tests/_legacy/test_dragon_step.py
+++ b/tests/_legacy/test_dragon_step.py
@@ -73,12 +73,18 @@ def dragon_batch_step(test_dir: str) -> DragonBatchStep:
cpu_affinities = [[], [0, 1, 2], [], [3, 4, 5, 6]]
gpu_affinities = [[], [], [0, 1, 2], [3, 4, 5, 6]]
+ # specify 3 hostnames to select from but require only 2 nodes
+ num_nodes = 2
+ hostnames = ["host1", "host2", "host3"]
+
# assign some unique affinities to each run setting instance
for index, rs in enumerate(settings):
if gpu_affinities[index]:
rs.set_node_feature("gpu")
rs.set_cpu_affinity(cpu_affinities[index])
rs.set_gpu_affinity(gpu_affinities[index])
+ rs.set_hostlist(hostnames)
+ rs.set_nodes(num_nodes)
steps = list(
DragonStep(name_, test_dir, rs_) for name_, rs_ in zip(names, settings)
@@ -374,6 +380,11 @@ def test_dragon_batch_step_write_request_file(
cpu_affinities = [[], [0, 1, 2], [], [3, 4, 5, 6]]
gpu_affinities = [[], [], [0, 1, 2], [3, 4, 5, 6]]
+ hostnames = ["host1", "host2", "host3"]
+ num_nodes = 2
+
+ # parse requests file path from the launch command
+ # e.g. dragon python
launch_cmd = dragon_batch_step.get_launch_cmd()
requests_file = get_request_path_from_batch_script(launch_cmd)
@@ -392,3 +403,5 @@ def test_dragon_batch_step_write_request_file(
assert run_request
assert run_request.policy.cpu_affinity == cpu_affinities[index]
assert run_request.policy.gpu_affinity == gpu_affinities[index]
+ assert run_request.nodes == num_nodes
+ assert run_request.hostlist == ",".join(hostnames)
diff --git a/tests/_legacy/test_model.py b/tests/_legacy/test_model.py
index 5adf8070f1..f8a84deb8d 100644
--- a/tests/_legacy/test_model.py
+++ b/tests/_legacy/test_model.py
@@ -24,8 +24,10 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import typing as t
from uuid import uuid4
+import numpy as np
import pytest
from smartsim import Experiment
@@ -35,7 +37,10 @@
from smartsim.entity import Application
from smartsim.error import EntityExistsError, SSUnsupportedError
from smartsim.settings import RunSettings, SbatchSettings, SrunSettings
-from smartsim.settings.mpiSettings import _BaseMPISettings
+
+# from smartsim.settings.mpiSettings import
+
+_BaseMPISettings = t.Any
# The tests in this file belong to the slow_tests group
pytestmark = pytest.mark.slow_tests
diff --git a/tests/_legacy/test_preview.py b/tests/_legacy/test_preview.py
index 82d443fb3e..6f029aab8f 100644
--- a/tests/_legacy/test_preview.py
+++ b/tests/_legacy/test_preview.py
@@ -359,7 +359,7 @@ def test_model_preview_properties(test_dir, wlmutils):
assert hw_rs == hello_world_model.run_settings.exe_args[0]
assert None == hello_world_model.batch_settings
assert "port" in list(hello_world_model.params.items())[0]
- assert hw_port in list(hello_world_model.params.items())[0]
+ assert str(hw_port) in list(hello_world_model.params.items())[0]
assert "password" in list(hello_world_model.params.items())[1]
assert hw_password in list(hello_world_model.params.items())[1]
@@ -983,7 +983,7 @@ def test_preview_active_infrastructure_feature_store_error(
exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher)
monkeypatch.setattr(
- smartsim.database.orchestrator.FeatureStore, "is_active", lambda x: True
+ smartsim.database.feature_store.FeatureStore, "is_active", lambda x: True
)
orc = exp.create_feature_store(
diff --git a/tests/_legacy/test_smartredis.py b/tests/_legacy/test_smartredis.py
index f09cc8ca89..d4ac0ceebc 100644
--- a/tests/_legacy/test_smartredis.py
+++ b/tests/_legacy/test_smartredis.py
@@ -27,10 +27,7 @@
import pytest
-from smartsim import Experiment
-from smartsim._core.utils import installed_redisai_backends
from smartsim.builders import Ensemble
-from smartsim.database import FeatureStore
from smartsim.entity import Application
from smartsim.status import JobStatus
@@ -51,7 +48,9 @@
except ImportError:
shouldrun = False
-torch_available = "torch" in installed_redisai_backends()
+torch_available = (
+ "torch" in []
+) # todo: update test to replace installed_redisai_backends()
shouldrun &= torch_available
diff --git a/tests/backends/test_ml_init.py b/tests/backends/test_ml_init.py
index 445ee8c444..7f5c6f9864 100644
--- a/tests/backends/test_ml_init.py
+++ b/tests/backends/test_ml_init.py
@@ -28,7 +28,13 @@
import pytest
-pytestmark = [pytest.mark.group_a, pytest.mark.group_b, pytest.mark.slow_tests]
+try:
+ import tensorflow
+ import torch
+except ImportError:
+ pytestmark = pytest.mark.skip("tensorflow or torch were not available")
+else:
+ pytestmark = [pytest.mark.group_a, pytest.mark.group_b, pytest.mark.slow_tests]
def test_import_ss_ml(monkeypatch):
diff --git a/tests/dragon_wlm/__init__.py b/tests/dragon_wlm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/dragon_wlm/channel.py b/tests/dragon_wlm/channel.py
new file mode 100644
index 0000000000..4c46359c2d
--- /dev/null
+++ b/tests/dragon_wlm/channel.py
@@ -0,0 +1,125 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import pathlib
+import threading
+import typing as t
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class FileSystemCommChannel(CommChannelBase):
+ """Passes messages by writing to a file"""
+
+ def __init__(self, key: pathlib.Path) -> None:
+ """Initialize the FileSystemCommChannel instance.
+
+ :param key: a path to the root directory of the feature store
+ """
+ self._lock = threading.RLock()
+
+ super().__init__(key.as_posix())
+ self._file_path = key
+
+ if not self._file_path.parent.exists():
+ self._file_path.parent.mkdir(parents=True)
+
+ self._file_path.touch()
+
+ def send(self, value: bytes, timeout: float = 0) -> None:
+ """Send a message throuh the underlying communication channel.
+
+ :param value: The value to send
+ :param timeout: maximum time to wait (in seconds) for messages to send
+ """
+ with self._lock:
+ # write as text so we can add newlines as delimiters
+ with open(self._file_path, "a") as fp:
+ encoded_value = base64.b64encode(value).decode("utf-8")
+ fp.write(f"{encoded_value}\n")
+ logger.debug(f"FileSystemCommChannel {self._file_path} sent message")
+
+ def recv(self, timeout: float = 0) -> t.List[bytes]:
+ """Receives message(s) through the underlying communication channel.
+
+ :param timeout: maximum time to wait (in seconds) for messages to arrive
+ :returns: the received message
+ :raises SmartSimError: if the descriptor points to a missing file
+ """
+ with self._lock:
+ messages: t.List[bytes] = []
+ if not self._file_path.exists():
+ raise SmartSimError("Empty channel")
+
+ # read as text so we can split on newlines
+ with open(self._file_path, "r") as fp:
+ lines = fp.readlines()
+
+ if lines:
+ line = lines.pop(0)
+ event_bytes = base64.b64decode(line.encode("utf-8"))
+ messages.append(event_bytes)
+
+ self.clear()
+
+ # remove the first message only, write remainder back...
+ if len(lines) > 0:
+ with open(self._file_path, "w") as fp:
+ fp.writelines(lines)
+
+ logger.debug(
+ f"FileSystemCommChannel {self._file_path} received message"
+ )
+
+ return messages
+
+ def clear(self) -> None:
+ """Create an empty file for events."""
+ if self._file_path.exists():
+ self._file_path.unlink()
+ self._file_path.touch()
+
+ @classmethod
+ def from_descriptor(
+ cls,
+ descriptor: str,
+ ) -> "FileSystemCommChannel":
+ """A factory method that creates an instance from a descriptor string.
+
+ :param descriptor: The descriptor that uniquely identifies the resource
+ :returns: An attached FileSystemCommChannel
+ """
+ try:
+ path = pathlib.Path(descriptor)
+ return FileSystemCommChannel(path)
+ except Exception:
+ logger.warning(f"failed to create fs comm channel: {descriptor}")
+ raise
diff --git a/tests/dragon_wlm/conftest.py b/tests/dragon_wlm/conftest.py
new file mode 100644
index 0000000000..bdec40b7e5
--- /dev/null
+++ b/tests/dragon_wlm/conftest.py
@@ -0,0 +1,126 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import annotations
+
+import os
+import socket
+import typing as t
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+import dragon.infrastructure.policy as dragon_policy
+import dragon.infrastructure.process_desc as dragon_process_desc
+import dragon.native.process as dragon_process
+
+from dragon.fli import FLInterface
+
+# isort: on
+
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.storage import dragon_util
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@pytest.fixture(scope="module")
+def the_storage() -> dragon_ddict.DDict:
+ """Fixture to instantiate a dragon distributed dictionary."""
+ return dragon_util.create_ddict(1, 2, 32 * 1024**2)
+
+
+@pytest.fixture(scope="module")
+def the_worker_channel() -> DragonFLIChannel:
+ """Fixture to create a valid descriptor for a worker channel
+ that can be attached to."""
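+ # create a local channel and wrap it in an FLI so it can be shared via descriptor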
+ channel_ = create_local()
+ fli_ = FLInterface(main_ch=channel_, manager_ch=None)
+ comm_channel = DragonFLIChannel(fli_)
+ return comm_channel
+
+
+@pytest.fixture(scope="module")
+def the_backbone(
+ the_storage: t.Any, the_worker_channel: DragonFLIChannel
+) -> BackboneFeatureStore:
+ """Fixture to create a distributed dragon dictionary and wrap it
+ in a BackboneFeatureStore.
+
+ :param the_storage: The dragon storage engine to use
+ :param the_worker_channel: Pre-configured worker channel
+ """
+
+ backbone = BackboneFeatureStore(the_storage, allow_reserved_writes=True)
+ backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] = the_worker_channel.descriptor
+
+ return backbone
+
+
+@pytest.fixture(scope="module")
+def backbone_descriptor(the_backbone: BackboneFeatureStore) -> str:
+ # create a shared backbone featurestore
+ return the_backbone.descriptor
+
+
+def function_as_dragon_proc(
+ entrypoint_fn: t.Callable[[t.Any], None],
+ args: t.List[t.Any],
+ cpu_affinity: t.List[int],
+ gpu_affinity: t.List[int],
+) -> dragon_process.Process:
+ """Execute a function as an independent dragon process.
+
+ :param entrypoint_fn: The function to execute
+ :param args: The arguments for the entrypoint function
+ :param cpu_affinity: The cpu affinity for the process
+ :param gpu_affinity: The gpu affinity for the process
+ :returns: The dragon process handle
+ """
+ options = dragon_process_desc.ProcessOptions(make_inf_channels=True)
+ local_policy = dragon_policy.Policy(
+ placement=dragon_policy.Policy.Placement.HOST_NAME,
+ host_name=socket.gethostname(),
+ cpu_affinity=cpu_affinity,
+ gpu_affinity=gpu_affinity,
+ )
+ return dragon_process.Process(
+ target=entrypoint_fn,
+ args=args,
+ cwd=os.getcwd(),
+ policy=local_policy,
+ options=options,
+ stderr=dragon_process.Popen.STDOUT,
+ stdout=dragon_process.Popen.STDOUT,
+ )
diff --git a/tests/dragon_wlm/feature_store.py b/tests/dragon_wlm/feature_store.py
new file mode 100644
index 0000000000..d06b0b334e
--- /dev/null
+++ b/tests/dragon_wlm/feature_store.py
@@ -0,0 +1,152 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pathlib
+import typing as t
+
+import smartsim.error as sse
+from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class MemoryFeatureStore(FeatureStore):
+ """A feature store with values persisted only in local memory"""
+
+ def __init__(
+ self, storage: t.Optional[t.Dict[str, t.Union[str, bytes]]] = None
+ ) -> None:
+ """Initialize the MemoryFeatureStore instance"""
+ super().__init__("in-memory-fs")
+ if storage is None:
+ storage = {"_": "abc"}
+ self._storage = storage
+
+ def _get(self, key: str) -> t.Union[str, bytes]:
+ """Retrieve a value from the underlying storage mechanism
+
+ :param key: The unique key that identifies the resource
+ :returns: the value identified by the key
+ :raises KeyError: if the key has not been used to store a value"""
+ return self._storage[key]
+
+ def _set(self, key: str, value: t.Union[str, bytes]) -> None:
+ """Store a value into the underlying storage mechanism
+
+ :param key: The unique key that identifies the resource
+ :param value: The value to store
+ :returns: the value identified by the key
+ :raises KeyError: if the key has not been used to store a value"""
+ self._storage[key] = value
+
+ def _contains(self, key: str) -> bool:
+ """Determine if the storage mechanism contains a given key
+
+ :param key: The unique key that identifies the resource
+ :returns: True if the key is defined, False otherwise"""
+ return key in self._storage
+
+
+class FileSystemFeatureStore(FeatureStore):
+ """Alternative feature store implementation for testing. Stores all
+ data on the file system"""
+
+ def __init__(self, storage_dir: t.Union[pathlib.Path, str]) -> None:
+ """Initialize the FileSystemFeatureStore instance
+
+ :param storage_dir: (optional) root directory to store all data relative to"""
+ if isinstance(storage_dir, str):
+ storage_dir = pathlib.Path(storage_dir)
+ self._storage_dir = storage_dir
+ super().__init__(storage_dir.as_posix())
+
+ def _get(self, key: str) -> t.Union[str, bytes]:
+ """Retrieve a value from the underlying storage mechanism
+
+ :param key: The unique key that identifies the resource
+ :returns: the value identified by the key
+ :raises KeyError: if the key has not been used to store a value"""
+ path = self._key_path(key)
+ if not path.exists():
+ raise sse.SmartSimError(f"{path} not found in feature store")
+ return path.read_bytes()
+
+ def _set(self, key: str, value: t.Union[str, bytes]) -> None:
+ """Store a value into the underlying storage mechanism
+
+ :param key: The unique key that identifies the resource
+ :param value: The value to store
+ :returns: the value identified by the key
+ :raises KeyError: if the key has not been used to store a value"""
+ path = self._key_path(key, create=True)
+ if isinstance(value, str):
+ value = value.encode("utf-8")
+ path.write_bytes(value)
+
+ def _contains(self, key: str) -> bool:
+ """Determine if the storage mechanism contains a given key
+
+ :param key: The unique key that identifies the resource
+ :returns: True if the key is defined, False otherwise"""
+ path = self._key_path(key)
+ return path.exists()
+
+ def _key_path(self, key: str, create: bool = False) -> pathlib.Path:
+ """Given a key, return a path that is optionally combined with a base
+ directory used by the FileSystemFeatureStore.
+
+ :param key: Unique key of an item to retrieve from the feature store"""
+ value = pathlib.Path(key)
+
+ if self._storage_dir is not None:
+ value = self._storage_dir / key
+
+ if create:
+ value.parent.mkdir(parents=True, exist_ok=True)
+
+ return value
+
+ @classmethod
+ def from_descriptor(
+ cls,
+ descriptor: str,
+ ) -> "FileSystemFeatureStore":
+ """A factory method that creates an instance from a descriptor string
+
+ :param descriptor: The descriptor that uniquely identifies the resource
+ :returns: An attached FileSystemFeatureStore"""
+ try:
+ path = pathlib.Path(descriptor)
+ path.mkdir(parents=True, exist_ok=True)
+ if not path.is_dir():
+ raise ValueError("FileSystemFeatureStore requires a directory path")
+ if not path.exists():
+ path.mkdir(parents=True, exist_ok=True)
+ return FileSystemFeatureStore(path)
+ except Exception:
+ logger.error(f"Error while creating FileSystemFeatureStore: {descriptor}")
+ raise
diff --git a/tests/dragon_wlm/test_core_machine_learning_worker.py b/tests/dragon_wlm/test_core_machine_learning_worker.py
new file mode 100644
index 0000000000..f9295d9e86
--- /dev/null
+++ b/tests/dragon_wlm/test_core_machine_learning_worker.py
@@ -0,0 +1,377 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pathlib
+import time
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+import torch
+
+import smartsim.error as sse
+from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey, TensorKey
+from smartsim._core.mli.infrastructure.worker.worker import (
+ InferenceRequest,
+ MachineLearningWorkerCore,
+ RequestBatch,
+ TransformOutputResult,
+)
+
+from .feature_store import FileSystemFeatureStore, MemoryFeatureStore
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+# retrieved from pytest fixtures
+is_dragon = (
+ pytest.test_launcher == "dragon" if hasattr(pytest, "test_launcher") else False
+)
+torch_available = (
+ "torch" in []
+) # todo: update test to replace installed_redisai_backends()
+
+
+@pytest.fixture
+def persist_torch_model(test_dir: str) -> pathlib.Path:
+ ts_start = time.time_ns()
+ print("Starting model file creation...")
+ test_path = pathlib.Path(test_dir)
+ model_path = test_path / "basic.pt"
+
+ model = torch.nn.Linear(2, 1)
+ torch.save(model, model_path)
+ ts_end = time.time_ns()
+
+ ts_elapsed = (ts_end - ts_start) / 1000000000
+ print(f"Model file creation took {ts_elapsed} seconds")
+ return model_path
+
+
+@pytest.fixture
+def persist_torch_tensor(test_dir: str) -> pathlib.Path:
+ ts_start = time.time_ns()
+ print("Starting model file creation...")
+ test_path = pathlib.Path(test_dir)
+ file_path = test_path / "tensor.pt"
+
+ tensor = torch.randn((100, 100, 2))
+ torch.save(tensor, file_path)
+ ts_end = time.time_ns()
+
+ ts_elapsed = (ts_end - ts_start) / 1000000000
+ print(f"Tensor file creation took {ts_elapsed} seconds")
+ return file_path
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_model_disk(persist_torch_model: pathlib.Path, test_dir: str) -> None:
+ """Verify that the ML worker successfully retrieves a model
+ when given a valid (file system) key"""
+ worker = MachineLearningWorkerCore
+ key = str(persist_torch_model)
+ feature_store = FileSystemFeatureStore(test_dir)
+ fsd = feature_store.descriptor
+ feature_store[str(persist_torch_model)] = persist_torch_model.read_bytes()
+
+ model_key = ModelKey(key=key, descriptor=fsd)
+ request = InferenceRequest(model_key=model_key)
+ batch = RequestBatch([request], None, model_key)
+
+ fetch_result = worker.fetch_model(batch, {fsd: feature_store})
+ assert fetch_result.model_bytes
+ assert fetch_result.model_bytes == persist_torch_model.read_bytes()
+
+
+def test_fetch_model_disk_missing() -> None:
+ """Verify that the ML worker fails to retrieves a model
+ when given an invalid (file system) key"""
+ worker = MachineLearningWorkerCore
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+
+ key = "/path/that/doesnt/exist"
+
+ model_key = ModelKey(key=key, descriptor=fsd)
+ request = InferenceRequest(model_key=model_key)
+ batch = RequestBatch([request], None, model_key)
+
+ with pytest.raises(sse.SmartSimError) as ex:
+ worker.fetch_model(batch, {fsd: feature_store})
+
+ # ensure the error message includes key-identifying information
+ assert key in ex.value.args[0]
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_model_feature_store(persist_torch_model: pathlib.Path) -> None:
+ """Verify that the ML worker successfully retrieves a model
+    when given a valid (feature store) key"""
+ worker = MachineLearningWorkerCore
+
+ # create a key to retrieve from the feature store
+ key = "test-model"
+
+ # put model bytes into the feature store
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+ feature_store[key] = persist_torch_model.read_bytes()
+
+ model_key = ModelKey(key=key, descriptor=feature_store.descriptor)
+ request = InferenceRequest(model_key=model_key)
+ batch = RequestBatch([request], None, model_key)
+
+ fetch_result = worker.fetch_model(batch, {fsd: feature_store})
+ assert fetch_result.model_bytes
+ assert fetch_result.model_bytes == persist_torch_model.read_bytes()
+
+
+def test_fetch_model_feature_store_missing() -> None:
+ """Verify that the ML worker fails to retrieves a model
+ when given an invalid (feature store) key"""
+ worker = MachineLearningWorkerCore
+
+ key = "some-key"
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+
+ model_key = ModelKey(key=key, descriptor=feature_store.descriptor)
+ request = InferenceRequest(model_key=model_key)
+ batch = RequestBatch([request], None, model_key)
+
+ # todo: consider that raising this exception shows impl. replace...
+ with pytest.raises(sse.SmartSimError) as ex:
+ worker.fetch_model(batch, {fsd: feature_store})
+
+ # ensure the error message includes key-identifying information
+ assert key in ex.value.args[0]
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_model_memory(persist_torch_model: pathlib.Path) -> None:
+ """Verify that the ML worker successfully retrieves a model
+    when given a valid (in-memory) key"""
+ worker = MachineLearningWorkerCore
+
+ key = "test-model"
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+ feature_store[key] = persist_torch_model.read_bytes()
+
+ model_key = ModelKey(key=key, descriptor=feature_store.descriptor)
+ request = InferenceRequest(model_key=model_key)
+ batch = RequestBatch([request], None, model_key)
+
+ fetch_result = worker.fetch_model(batch, {fsd: feature_store})
+ assert fetch_result.model_bytes
+ assert fetch_result.model_bytes == persist_torch_model.read_bytes()
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_input_disk(persist_torch_tensor: pathlib.Path) -> None:
+ """Verify that the ML worker successfully retrieves a tensor/input
+ when given a valid (file system) key"""
+ tensor_name = str(persist_torch_tensor)
+
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+ request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)])
+
+ model_key = ModelKey(key="test-model", descriptor=fsd)
+ batch = RequestBatch([request], None, model_key)
+
+ worker = MachineLearningWorkerCore
+
+ feature_store[tensor_name] = persist_torch_tensor.read_bytes()
+
+ fetch_result = worker.fetch_inputs(batch, {fsd: feature_store})
+ assert fetch_result[0].inputs is not None
+
+
+def test_fetch_input_disk_missing() -> None:
+ """Verify that the ML worker fails to retrieves a tensor/input
+ when given an invalid (file system) key"""
+ worker = MachineLearningWorkerCore
+
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+ key = "/path/that/doesnt/exist"
+
+ request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)])
+
+ model_key = ModelKey(key="test-model", descriptor=fsd)
+ batch = RequestBatch([request], None, model_key)
+
+ with pytest.raises(sse.SmartSimError) as ex:
+ worker.fetch_inputs(batch, {fsd: feature_store})
+
+ # ensure the error message includes key-identifying information
+    assert key in ex.value.args[0]
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_input_feature_store(persist_torch_tensor: pathlib.Path) -> None:
+ """Verify that the ML worker successfully retrieves a tensor/input
+ when given a valid (feature store) key"""
+ worker = MachineLearningWorkerCore
+
+ tensor_name = "test-tensor"
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+
+ request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)])
+
+ # put model bytes into the feature store
+ feature_store[tensor_name] = persist_torch_tensor.read_bytes()
+
+ model_key = ModelKey(key="test-model", descriptor=fsd)
+ batch = RequestBatch([request], None, model_key)
+
+ fetch_result = worker.fetch_inputs(batch, {fsd: feature_store})
+ assert fetch_result[0].inputs
+ assert (
+ list(fetch_result[0].inputs)[0][:10] == persist_torch_tensor.read_bytes()[:10]
+ )
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_multi_input_feature_store(persist_torch_tensor: pathlib.Path) -> None:
+ """Verify that the ML worker successfully retrieves multiple tensor/input
+ when given a valid collection of (feature store) keys"""
+ worker = MachineLearningWorkerCore
+
+ tensor_name = "test-tensor"
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+
+ # put model bytes into the feature store
+ body1 = persist_torch_tensor.read_bytes()
+ feature_store[tensor_name + "1"] = body1
+
+ body2 = b"abcdefghijklmnopqrstuvwxyz"
+ feature_store[tensor_name + "2"] = body2
+
+ body3 = b"mnopqrstuvwxyzabcdefghijkl"
+ feature_store[tensor_name + "3"] = body3
+
+ request = InferenceRequest(
+ input_keys=[
+ TensorKey(key=tensor_name + "1", descriptor=fsd),
+ TensorKey(key=tensor_name + "2", descriptor=fsd),
+ TensorKey(key=tensor_name + "3", descriptor=fsd),
+ ]
+ )
+
+ model_key = ModelKey(key="test-model", descriptor=fsd)
+ batch = RequestBatch([request], None, model_key)
+
+ fetch_result = worker.fetch_inputs(batch, {fsd: feature_store})
+
+ raw_bytes = list(fetch_result[0].inputs)
+ assert raw_bytes
+ assert raw_bytes[0][:10] == persist_torch_tensor.read_bytes()[:10]
+ assert raw_bytes[1][:10] == body2[:10]
+ assert raw_bytes[2][:10] == body3[:10]
+
+
+def test_fetch_input_feature_store_missing() -> None:
+ """Verify that the ML worker fails to retrieves a tensor/input
+ when given an invalid (feature store) key"""
+ worker = MachineLearningWorkerCore
+
+ key = "bad-key"
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+ request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)])
+
+ model_key = ModelKey(key="test-model", descriptor=fsd)
+ batch = RequestBatch([request], None, model_key)
+
+ with pytest.raises(sse.SmartSimError) as ex:
+ worker.fetch_inputs(batch, {fsd: feature_store})
+
+ # ensure the error message includes key-identifying information
+ assert key in ex.value.args[0]
+
+
+@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+def test_fetch_input_memory(persist_torch_tensor: pathlib.Path) -> None:
+ """Verify that the ML worker successfully retrieves a tensor/input
+    when given a valid (in-memory) key"""
+ worker = MachineLearningWorkerCore
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+
+ key = "test-model"
+ feature_store[key] = persist_torch_tensor.read_bytes()
+ request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)])
+
+ model_key = ModelKey(key="test-model", descriptor=fsd)
+ batch = RequestBatch([request], None, model_key)
+
+ fetch_result = worker.fetch_inputs(batch, {fsd: feature_store})
+ assert fetch_result[0].inputs is not None
+
+
+def test_place_outputs() -> None:
+ """Verify outputs are shared using the feature store"""
+ worker = MachineLearningWorkerCore
+
+ key_name = "test-model"
+ feature_store = MemoryFeatureStore()
+ fsd = feature_store.descriptor
+
+ # create a key to retrieve from the feature store
+ keys = [
+ TensorKey(key=key_name + "1", descriptor=fsd),
+ TensorKey(key=key_name + "2", descriptor=fsd),
+ TensorKey(key=key_name + "3", descriptor=fsd),
+ ]
+ data = [b"abcdef", b"ghijkl", b"mnopqr"]
+
+ for fsk, v in zip(keys, data):
+ feature_store[fsk.key] = v
+
+ request = InferenceRequest(output_keys=keys)
+ transform_result = TransformOutputResult(data, [1], "c", "float32")
+
+ worker.place_output(request, transform_result, {fsd: feature_store})
+
+ for i in range(3):
+ assert feature_store[keys[i].key] == data[i]
+
+
+@pytest.mark.parametrize(
+ "key, descriptor",
+ [
+ pytest.param("", "desc", id="invalid key"),
+ pytest.param("key", "", id="invalid descriptor"),
+ ],
+)
+def test_invalid_tensorkey(key: str, descriptor: str) -> None:
+    """Verify that TensorKey rejects an empty key or descriptor."""
+    with pytest.raises(ValueError):
+        TensorKey(key, descriptor)
diff --git a/tests/dragon_wlm/test_device_manager.py b/tests/dragon_wlm/test_device_manager.py
new file mode 100644
index 0000000000..d270e921cb
--- /dev/null
+++ b/tests/dragon_wlm/test_device_manager.py
@@ -0,0 +1,186 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.infrastructure.control.device_manager import (
+ DeviceManager,
+ WorkerDevice,
+)
+from smartsim._core.mli.infrastructure.storage.feature_store import (
+ FeatureStore,
+ ModelKey,
+ TensorKey,
+)
+from smartsim._core.mli.infrastructure.worker.worker import (
+ ExecuteResult,
+ FetchInputResult,
+ FetchModelResult,
+ InferenceRequest,
+ LoadModelResult,
+ MachineLearningWorkerBase,
+ RequestBatch,
+ TransformInputResult,
+ TransformOutputResult,
+)
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+class MockWorker(MachineLearningWorkerBase):
+ @staticmethod
+ def fetch_model(
+ batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore]
+ ) -> FetchModelResult:
+ if batch.has_raw_model:
+ return FetchModelResult(batch.raw_model)
+ return FetchModelResult(b"fetched_model")
+
+ @staticmethod
+ def load_model(
+ batch: RequestBatch, fetch_result: FetchModelResult, device: str
+ ) -> LoadModelResult:
+ return LoadModelResult(fetch_result.model_bytes)
+
+ @staticmethod
+ def transform_input(
+ batch: RequestBatch,
+ fetch_results: list[FetchInputResult],
+ mem_pool: "MemoryPool",
+ ) -> TransformInputResult:
+ return TransformInputResult(b"result", [slice(0, 1)], [[1, 2]], ["float32"])
+
+ @staticmethod
+ def execute(
+ batch: RequestBatch,
+ load_result: LoadModelResult,
+ transform_result: TransformInputResult,
+ device: str,
+ ) -> ExecuteResult:
+ return ExecuteResult(b"result", [slice(0, 1)])
+
+ @staticmethod
+ def transform_output(
+ batch: RequestBatch, execute_result: ExecuteResult
+ ) -> t.List[TransformOutputResult]:
+ return [TransformOutputResult(b"result", None, "c", "float32")]
+
+
+def test_worker_device():
+ worker_device = WorkerDevice("gpu:0")
+ assert worker_device.name == "gpu:0"
+
+ model_key = "my_model_key"
+ model = b"the model"
+
+ worker_device.add_model(model_key, model)
+
+ assert model_key in worker_device
+ assert worker_device.get_model(model_key) == model
+ worker_device.remove_model(model_key)
+
+ assert model_key not in worker_device
+
+
+def test_device_manager_model_in_request():
+
+ worker_device = WorkerDevice("gpu:0")
+ device_manager = DeviceManager(worker_device)
+
+ worker = MockWorker()
+
+ tensor_key = TensorKey(key="key", descriptor="desc")
+ output_key = TensorKey(key="key", descriptor="desc")
+ model_key = ModelKey(key="model key", descriptor="desc")
+
+ request = InferenceRequest(
+ model_key=model_key,
+ callback=None,
+ raw_inputs=None,
+ input_keys=[tensor_key],
+ input_meta=None,
+ output_keys=[output_key],
+ raw_model=b"raw model",
+ batch_size=0,
+ )
+
+ request_batch = RequestBatch(
+ [request],
+ TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]),
+ model_id=model_key,
+ )
+
+ with device_manager.get_device(
+ worker=worker, batch=request_batch, feature_stores={}
+ ) as returned_device:
+
+ assert returned_device == worker_device
+ assert worker_device.get_model(model_key.key) == b"raw model"
+
+ assert model_key.key not in worker_device
+
+
+def test_device_manager_model_key():
+
+ worker_device = WorkerDevice("gpu:0")
+ device_manager = DeviceManager(worker_device)
+
+ worker = MockWorker()
+
+ tensor_key = TensorKey(key="key", descriptor="desc")
+ output_key = TensorKey(key="key", descriptor="desc")
+ model_key = ModelKey(key="model key", descriptor="desc")
+
+ request = InferenceRequest(
+ model_key=model_key,
+ callback=None,
+ raw_inputs=None,
+ input_keys=[tensor_key],
+ input_meta=None,
+ output_keys=[output_key],
+ raw_model=None,
+ batch_size=0,
+ )
+
+ request_batch = RequestBatch(
+ [request],
+ TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]),
+ model_id=model_key,
+ )
+
+ with device_manager.get_device(
+ worker=worker, batch=request_batch, feature_stores={}
+ ) as returned_device:
+
+ assert returned_device == worker_device
+ assert worker_device.get_model(model_key.key) == b"fetched_model"
+
+ assert model_key.key in worker_device
diff --git a/tests/dragon_wlm/test_dragon_backend.py b/tests/dragon_wlm/test_dragon_backend.py
new file mode 100644
index 0000000000..dc98f5de75
--- /dev/null
+++ b/tests/dragon_wlm/test_dragon_backend.py
@@ -0,0 +1,308 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import time
+import uuid
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+
+from smartsim._core.launcher.dragon.dragon_backend import DragonBackend
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.infrastructure.comm.event import (
+ OnCreateConsumer,
+ OnShutdownRequested,
+)
+from smartsim._core.mli.infrastructure.control.listener import (
+ ConsumerRegistrationListener,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim.log import get_logger
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+logger = get_logger(__name__)
+
+
+@pytest.fixture(scope="module")
+def the_backend() -> DragonBackend:
+ return DragonBackend(pid=9999)
+
+
+@pytest.mark.skip("Test is unreliable on build agent and may hang. TODO: Fix")
+def test_dragonbackend_start_listener(the_backend: DragonBackend):
+ """Verify the background process listening to consumer registration events
+ is up and processing messages as expected."""
+
+ # We need to let the backend create the backbone to continue
+ backbone = the_backend._create_backbone()
+ backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS)
+ backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+ os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor
+
+ with pytest.raises(KeyError) as ex:
+ # we expect the value of the consumer to be empty until
+ # the listener start-up completes.
+ backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+
+ assert "not found" in ex.value.args[0]
+
+ drg_process = the_backend.start_event_listener(cpu_affinity=[], gpu_affinity=[])
+
+    # confirm there is a process still running
+ logger.info(f"Dragon process started: {drg_process}")
+ assert drg_process is not None, "Backend was unable to start event listener"
+ assert drg_process.puid != 0, "Process unique ID is empty"
+ assert drg_process.returncode is None, "Listener terminated early"
+
+ # wait for the event listener to come up
+ try:
+ config = backbone.wait_for(
+ [BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], timeout=30
+ )
+ # verify result was in the returned configuration map
+ assert config[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+ except Exception:
+ raise KeyError(
+ f"Unable to locate {BackboneFeatureStore.MLI_REGISTRAR_CONSUMER}"
+ "in the backbone"
+ )
+
+ # wait_for ensures the normal retrieval will now work, error-free
+ descriptor = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+ assert descriptor is not None
+
+ # register a new listener channel
+ comm_channel = DragonCommChannel.from_descriptor(descriptor)
+ mock_descriptor = str(uuid.uuid4())
+ event = OnCreateConsumer("test_dragonbackend_start_listener", mock_descriptor, [])
+
+ event_bytes = bytes(event)
+ comm_channel.send(event_bytes)
+
+ subscriber_list = []
+
+ # Give the channel time to write the message and the listener time to handle it
+ for i in range(20):
+ time.sleep(1)
+ # Retrieve the subscriber list from the backbone and verify it is updated
+ if subscriber_list := backbone.notification_channels:
+ logger.debug(f"The subscriber list was populated after {i} iterations")
+ break
+
+ assert mock_descriptor in subscriber_list
+
+    # check whether the listener exited on its own
+    return_code = drg_process.returncode
+
+    # clean up if the listener process is still running
+ if return_code is None and drg_process.is_alive:
+ drg_process.kill()
+ drg_process.join()
+
+
+def test_dragonbackend_backend_consumer(the_backend: DragonBackend):
+ """Verify the listener background process updates the appropriate
+ value in the backbone."""
+
+ # We need to let the backend create the backbone to continue
+ backbone = the_backend._create_backbone()
+ backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS)
+ backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+ assert backbone._allow_reserved_writes
+
+ # create listener with `as_service=False` to perform a single loop iteration
+ listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=False)
+
+ logger.debug(f"backbone loaded? {listener._backbone}")
+ logger.debug(f"listener created? {listener}")
+
+ try:
+ # call the service execute method directly to trigger
+ # the entire service lifecycle
+ listener.execute()
+
+ consumer_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+ logger.debug(f"MLI_REGISTRAR_CONSUMER: {consumer_desc}")
+
+ assert consumer_desc
+    except Exception as ex:
+        logger.exception(f"Unexpected error while verifying the registrar consumer: {ex}")
+        raise
+ finally:
+ listener._on_shutdown()
+
+
+def test_dragonbackend_event_handled(the_backend: DragonBackend):
+ """Verify the event listener process updates the appropriate
+ value in the backbone when an event is received and again on shutdown.
+ """
+ # We need to let the backend create the backbone to continue
+ backbone = the_backend._create_backbone()
+ backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS)
+ backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+ # create the listener to be tested
+ listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=False)
+
+ assert listener._backbone, "The listener is not attached to a backbone"
+
+ try:
+ # set up the listener but don't let the service event loop start
+ listener._create_eventing() # listener.execute()
+
+ # grab the channel descriptor so we can simulate registrations
+ channel_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+ comm_channel = DragonCommChannel.from_descriptor(channel_desc)
+
+ num_events = 5
+ events = []
+ for i in range(num_events):
+ # register some mock consumers using the backend channel
+ event = OnCreateConsumer(
+ "test_dragonbackend_event_handled",
+ f"mock-consumer-descriptor-{uuid.uuid4()}",
+ [],
+ )
+ event_bytes = bytes(event)
+ comm_channel.send(event_bytes)
+ events.append(event)
+
+        # run a few iterations of the event loop in case it takes a few cycles to write
+        for i in range(20):
+            listener._on_iteration()
+            # Grab the value that should be getting updated
+            notify_consumers = set(backbone.notification_channels)
+            if len(notify_consumers) == len(events):
+                logger.info(f"Retrieved all consumers after {i} listen cycles")
+ break
+
+ # ... and confirm that all the mock consumer descriptors are registered
+ assert set([e.descriptor for e in events]) == set(notify_consumers)
+ logger.info(f"Number of registered consumers: {len(notify_consumers)}")
+
+ except Exception as ex:
+ logger.exception(f"test_dragonbackend_event_handled - exception occurred: {ex}")
+ assert False
+ finally:
+ # shutdown should unregister a registration listener
+ listener._on_shutdown()
+
+ for i in range(10):
+ if BackboneFeatureStore.MLI_REGISTRAR_CONSUMER not in backbone:
+ logger.debug(f"The listener was removed after {i} iterations")
+ channel_desc = None
+ break
+
+ # we should see that there is no listener registered
+ assert not channel_desc, "Listener shutdown failed to clean up the backbone"
+
+
+def test_dragonbackend_shutdown_event(the_backend: DragonBackend):
+ """Verify the background process shuts down when it receives a
+ shutdown request."""
+
+ # We need to let the backend create the backbone to continue
+ backbone = the_backend._create_backbone()
+ backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS)
+ backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+ listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=True)
+
+ # set up the listener but don't let the listener loop start
+ listener._create_eventing() # listener.execute()
+
+ # grab the channel descriptor so we can publish to it
+ channel_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+ comm_channel = DragonCommChannel.from_descriptor(channel_desc)
+
+ assert listener._consumer.listening, "Listener isn't ready to listen"
+
+ # send a shutdown request...
+ event = OnShutdownRequested("test_dragonbackend_shutdown_event")
+ event_bytes = bytes(event)
+ comm_channel.send(event_bytes, 0.1)
+
+ # execute should encounter the shutdown and exit
+ listener.execute()
+
+ # ...and confirm the listener is now cancelled
+ assert not listener._consumer.listening
+
+
+@pytest.mark.parametrize("health_check_frequency", [10, 20])
+def test_dragonbackend_shutdown_on_health_check(
+ the_backend: DragonBackend,
+ health_check_frequency: float,
+):
+ """Verify that the event listener automatically shuts down when
+ a new listener is registered in its place.
+
+ :param health_check_frequency: The expected frequency of service health check
+ invocations"""
+
+ # We need to let the backend create the backbone to continue
+ backbone = the_backend._create_backbone()
+ backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS)
+ backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+ listener = ConsumerRegistrationListener(
+ backbone,
+ 1.0,
+ 1.0,
+ as_service=True, # allow service to run long enough to health check
+ health_check_frequency=health_check_frequency,
+ )
+
+ # set up the listener but don't let the listener loop start
+ listener._create_eventing() # listener.execute()
+ assert listener._consumer.listening, "Listener wasn't ready to listen"
+
+ # Replace the consumer descriptor in the backbone to trigger
+ # an automatic shutdown
+ backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = str(uuid.uuid4())
+
+ # set the last health check manually to verify the duration
+ start_at = time.time()
+ listener._last_health_check = time.time()
+
+ # run execute to let the service trigger health checks
+ listener.execute()
+ elapsed = time.time() - start_at
+
+ # confirm the frequency of the health check was honored
+ assert elapsed >= health_check_frequency
+
+ # ...and confirm the listener is now cancelled
+ assert (
+ not listener._consumer.listening
+ ), "Listener was not automatically shutdown by the health check"
diff --git a/tests/dragon_wlm/test_dragon_comm_utils.py b/tests/dragon_wlm/test_dragon_comm_utils.py
new file mode 100644
index 0000000000..a6f9c206a4
--- /dev/null
+++ b/tests/dragon_wlm/test_dragon_comm_utils.py
@@ -0,0 +1,257 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import pathlib
+import uuid
+
+import pytest
+
+from smartsim.error.errors import SmartSimError
+
+dragon = pytest.importorskip("dragon")
+
+# isort: off
+import dragon.channels as dch
+import dragon.infrastructure.parameters as dp
+import dragon.managed_memory as dm
+import dragon.fli as fli
+
+# isort: on
+
+from smartsim._core.mli.comm.channel import dragon_util
+from smartsim.log import get_logger
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+logger = get_logger(__name__)
+
+
+@pytest.fixture(scope="function")
+def the_pool() -> dm.MemoryPool:
+ """Creates a memory pool."""
+ raw_pool_descriptor = dp.this_process.default_pd
+ descriptor_ = base64.b64decode(raw_pool_descriptor)
+
+ pool = dm.MemoryPool.attach(descriptor_)
+ return pool
+
+
+@pytest.fixture(scope="function")
+def the_channel() -> dch.Channel:
+ """Creates a Channel attached to the local memory pool."""
+ channel = dch.Channel.make_process_local()
+ return channel
+
+
+@pytest.fixture(scope="function")
+def the_fli(the_channel) -> fli.FLInterface:
+ """Creates an FLI attached to the local memory pool."""
+ fli_ = fli.FLInterface(main_ch=the_channel, manager_ch=None)
+ return fli_
+
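+# The round trip exercised throughout this module: serialize a live object to a
+# descriptor string, then re-attach from that string. A minimal sketch using only the
+# helpers imported above (mirrors the happy-path tests below):
+#
+#     descriptor = dragon_util.channel_to_descriptor(channel)
+#     reattached = dragon_util.descriptor_to_channel(descriptor)
+#     assert dragon_util.channel_to_descriptor(reattached) == descriptor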
+
+def test_descriptor_to_channel_empty() -> None:
+ """Verify that `descriptor_to_channel` raises an exception when
+ provided with an empty descriptor."""
+ descriptor = ""
+
+ with pytest.raises(ValueError) as ex:
+ dragon_util.descriptor_to_channel(descriptor)
+
+ assert "empty" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+ "descriptor",
+ ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()],
+)
+def test_descriptor_to_channel_b64fail(descriptor: str) -> None:
+ """Verify that `descriptor_to_channel` raises an exception when
+ provided with an incorrectly encoded descriptor.
+
+ :param descriptor: A descriptor that is not properly base64 encoded
+ """
+
+ with pytest.raises(ValueError) as ex:
+ dragon_util.descriptor_to_channel(descriptor)
+
+ assert "base64" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+ "descriptor",
+ [str(uuid.uuid4())],
+)
+def test_descriptor_to_channel_channel_fail(descriptor: str) -> None:
+ """Verify that `descriptor_to_channel` raises an exception when a correctly
+ formatted descriptor that does not describe a real channel is passed.
+
+    :param descriptor: A well-formed descriptor that does not reference an existing channel
+ """
+
+ with pytest.raises(SmartSimError) as ex:
+ dragon_util.descriptor_to_channel(descriptor)
+
+ # ensure we're receiving the right exception
+ assert "address" in ex.value.args[0]
+ assert "channel" in ex.value.args[0]
+
+
+def test_descriptor_to_channel_channel_not_available(the_channel: dch.Channel) -> None:
+ """Verify that `descriptor_to_channel` raises an exception when a channel
+ is no longer available.
+
+ :param the_channel: A dragon channel
+ """
+
+ # get a good descriptor & wipe out the channel so it can't be attached
+ descriptor = dragon_util.channel_to_descriptor(the_channel)
+ the_channel.destroy()
+
+ with pytest.raises(SmartSimError) as ex:
+ dragon_util.descriptor_to_channel(descriptor)
+
+ assert "address" in ex.value.args[0]
+
+
+def test_descriptor_to_channel_happy_path(the_channel: dch.Channel) -> None:
+ """Verify that `descriptor_to_channel` works as expected when provided
+ a valid descriptor
+
+ :param the_channel: A dragon channel
+ """
+
+ # get a good descriptor
+ descriptor = dragon_util.channel_to_descriptor(the_channel)
+
+ reattached = dragon_util.descriptor_to_channel(descriptor)
+ assert reattached
+
+ # and just make sure creation of the descriptor is transitive
+ assert dragon_util.channel_to_descriptor(reattached) == descriptor
+
+
+def test_descriptor_to_fli_empty() -> None:
+ """Verify that `descriptor_to_fli` raises an exception when
+ provided with an empty descriptor."""
+ descriptor = ""
+
+ with pytest.raises(ValueError) as ex:
+ dragon_util.descriptor_to_fli(descriptor)
+
+ assert "empty" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+ "descriptor",
+ ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()],
+)
+def test_descriptor_to_fli_b64fail(descriptor: str) -> None:
+ """Verify that `descriptor_to_fli` raises an exception when
+ provided with an incorrectly encoded descriptor.
+
+ :param descriptor: A descriptor that is not properly base64 encoded
+ """
+
+ with pytest.raises(ValueError) as ex:
+ dragon_util.descriptor_to_fli(descriptor)
+
+ assert "base64" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+ "descriptor",
+ [str(uuid.uuid4())],
+)
+def test_descriptor_to_fli_fli_fail(descriptor: str) -> None:
+ """Verify that `descriptor_to_fli` raises an exception when a correctly
+ formatted descriptor that does not describe a real FLI is passed.
+
+    :param descriptor: A well-formed descriptor that does not reference an existing FLI
+ """
+
+ with pytest.raises(SmartSimError) as ex:
+ dragon_util.descriptor_to_fli(descriptor)
+
+ # ensure we're receiving the right exception
+ assert "address" in ex.value.args[0]
+ assert "fli" in ex.value.args[0].lower()
+
+
+def test_descriptor_to_fli_fli_not_available(
+ the_fli: fli.FLInterface, the_channel: dch.Channel
+) -> None:
+ """Verify that `descriptor_to_fli` raises an exception when a channel
+ is no longer available.
+
+ :param the_fli: A dragon FLInterface
+ :param the_channel: A dragon channel
+ """
+
+ # get a good descriptor & wipe out the FLI so it can't be attached
+ descriptor = dragon_util.channel_to_descriptor(the_fli)
+ the_fli.destroy()
+ the_channel.destroy()
+
+ with pytest.raises(SmartSimError) as ex:
+ dragon_util.descriptor_to_fli(descriptor)
+
+ # ensure we're receiving the right exception
+ assert "address" in ex.value.args[0]
+
+
+def test_descriptor_to_fli_happy_path(the_fli: fli.FLInterface) -> None:
+ """Verify that `descriptor_to_fli` works as expected when provided
+ a valid descriptor
+
+ :param the_fli: A dragon FLInterface
+ """
+
+ # get a good descriptor
+ descriptor = dragon_util.channel_to_descriptor(the_fli)
+
+ reattached = dragon_util.descriptor_to_fli(descriptor)
+ assert reattached
+
+ # and just make sure creation of the descriptor is transitive
+ assert dragon_util.channel_to_descriptor(reattached) == descriptor
+
+
+def test_pool_to_descriptor_empty() -> None:
+ """Verify that `pool_to_descriptor` raises an exception when
+ provided with a null pool."""
+
+    with pytest.raises(ValueError):
+        dragon_util.pool_to_descriptor(None)
+
+
+def test_pool_to_descriptor_happy_path(the_pool: dm.MemoryPool) -> None:
+ """Verify that `pool_to_descriptor` creates a descriptor
+ when supplied with a valid memory pool."""
+
+ descriptor = dragon_util.pool_to_descriptor(the_pool)
+ assert descriptor
diff --git a/tests/dragon_wlm/test_dragon_ddict_utils.py b/tests/dragon_wlm/test_dragon_ddict_utils.py
new file mode 100644
index 0000000000..c8bf687ef1
--- /dev/null
+++ b/tests/dragon_wlm/test_dragon_ddict_utils.py
@@ -0,0 +1,117 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+
+# isort: on
+
+from smartsim._core.mli.infrastructure.storage import dragon_util
+from smartsim.log import get_logger
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+logger = get_logger(__name__)
+
+
+@pytest.mark.parametrize(
+ "num_nodes, num_managers, mem_per_node",
+ [
+ pytest.param(1, 1, 3 * 1024**2, id="3MB, Bare minimum allocation"),
+ pytest.param(2, 2, 128 * 1024**2, id="128 MB allocation, 2 nodes, 2 mgr"),
+ pytest.param(2, 1, 512 * 1024**2, id="512 MB allocation, 2 nodes, 1 mgr"),
+ ],
+)
+def test_dragon_storage_util_create_ddict(
+ num_nodes: int,
+ num_managers: int,
+ mem_per_node: int,
+):
+ """Verify that a dragon dictionary is successfully created.
+
+ :param num_nodes: Number of ddict nodes to attempt to create
+ :param num_managers: Number of managers per node to request
+    :param mem_per_node: Memory to allocate per node
+ """
+ ddict = dragon_util.create_ddict(num_nodes, num_managers, mem_per_node)
+
+ assert ddict is not None
+
+
+@pytest.mark.parametrize(
+ "num_nodes, num_managers, mem_per_node",
+ [
+ pytest.param(-1, 1, 3 * 1024**2, id="Negative Node Count"),
+ pytest.param(0, 1, 3 * 1024**2, id="Invalid Node Count"),
+ pytest.param(1, -1, 3 * 1024**2, id="Negative Mgr Count"),
+ pytest.param(1, 0, 3 * 1024**2, id="Invalid Mgr Count"),
+ pytest.param(1, 1, -3 * 1024**2, id="Negative Mem Per Node"),
+ pytest.param(1, 1, (3 * 1024**2) - 1, id="Invalid Mem Per Node"),
+ pytest.param(1, 1, 0 * 1024**2, id="No Mem Per Node"),
+ ],
+)
+def test_dragon_storage_util_create_ddict_validators(
+ num_nodes: int,
+ num_managers: int,
+ mem_per_node: int,
+):
+ """Verify that a dragon dictionary is successfully created.
+
+ :param num_nodes: Number of ddict nodes to attempt to create
+ :param num_managers: Number of managers per node to request
+ :param num_managers: Memory to allocate per node
+ """
+ with pytest.raises(ValueError):
+ dragon_util.create_ddict(num_nodes, num_managers, mem_per_node)
+
+
+def test_dragon_storage_util_get_ddict_descriptor(the_storage: dragon_ddict.DDict):
+ """Verify that a descriptor is created.
+
+ :param the_storage: A pre-allocated ddict
+ """
+ value = dragon_util.ddict_to_descriptor(the_storage)
+
+ assert isinstance(value, str)
+ assert len(value) > 0
+
+
+def test_dragon_storage_util_get_ddict_from_descriptor(the_storage: dragon_ddict.DDict):
+ """Verify that a ddict is created from a descriptor.
+
+ :param the_storage: A pre-allocated ddict
+ """
+ descriptor = dragon_util.ddict_to_descriptor(the_storage)
+
+ value = dragon_util.descriptor_to_ddict(descriptor)
+
+ assert value is not None
+ assert isinstance(value, dragon_ddict.DDict)
+ assert dragon_util.ddict_to_descriptor(value) == descriptor
diff --git a/tests/dragon_wlm/test_environment_loader.py b/tests/dragon_wlm/test_environment_loader.py
new file mode 100644
index 0000000000..07b2a45c1c
--- /dev/null
+++ b/tests/dragon_wlm/test_environment_loader.py
@@ -0,0 +1,147 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+import dragon.data.ddict.ddict as dragon_ddict
+import dragon.utils as du
+from dragon.fli import FLInterface
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ DragonFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+@pytest.mark.parametrize(
+ "content",
+ [
+ pytest.param(b"a"),
+ pytest.param(b"new byte string"),
+ ],
+)
+def test_environment_loader_attach_fli(content: bytes, monkeypatch: pytest.MonkeyPatch):
+ """A descriptor can be stored, loaded, and reattached."""
+ chan = create_local()
+ queue = FLInterface(main_ch=chan)
+ monkeypatch.setenv(
+ EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR,
+ du.B64.bytes_to_str(queue.serialize()),
+ )
+
+ config = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=DragonCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+ config_queue = config.get_queue()
+
+ _ = config_queue.send(content)
+
+ old_recv = queue.recvh()
+ result, _ = old_recv.recv_bytes()
+ assert result == content
+
+
+def test_environment_loader_serialize_fli(monkeypatch: pytest.MonkeyPatch):
+ """The serialized descriptors of a loaded and unloaded
+ queue are the same."""
+ chan = create_local()
+ queue = FLInterface(main_ch=chan)
+ monkeypatch.setenv(
+ EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR,
+ du.B64.bytes_to_str(queue.serialize()),
+ )
+
+ config = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=DragonCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+ config_queue = config.get_queue()
+ assert config_queue._fli.serialize() == queue.serialize()
+
+
+def test_environment_loader_flifails(monkeypatch: pytest.MonkeyPatch):
+ """An incorrect serialized descriptor will fails to attach."""
+
+ monkeypatch.setenv(EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, "randomstring")
+
+ config = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=None,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+
+ with pytest.raises(SmartSimError):
+ config.get_queue()
+
+
+def test_environment_loader_backbone_load_dfs(
+ monkeypatch: pytest.MonkeyPatch, the_storage: dragon_ddict.DDict
+):
+ """Verify the dragon feature store is loaded correctly by the
+ EnvironmentConfigLoader to demonstrate featurestore_factory correctness."""
+ feature_store = DragonFeatureStore(the_storage)
+ monkeypatch.setenv(
+ EnvironmentConfigLoader.BACKBONE_ENV_VAR, feature_store.descriptor
+ )
+
+ config = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=None,
+ queue_factory=None,
+ )
+
+ print(f"calling config.get_backbone: `{feature_store.descriptor}`")
+
+ backbone = config.get_backbone()
+ assert backbone is not None
+
+
+def test_environment_variables_not_set(monkeypatch: pytest.MonkeyPatch):
+ """EnvironmentConfigLoader getters return None when environment
+ variables are not set."""
+ with monkeypatch.context() as patch:
+ patch.setenv(EnvironmentConfigLoader.BACKBONE_ENV_VAR, "")
+ patch.setenv(EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, "")
+
+ config = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=DragonCommChannel.from_descriptor,
+ queue_factory=DragonCommChannel.from_descriptor,
+ )
+ assert config.get_backbone() is None
+ assert config.get_queue() is None
diff --git a/tests/dragon_wlm/test_error_handling.py b/tests/dragon_wlm/test_error_handling.py
new file mode 100644
index 0000000000..aacd47b556
--- /dev/null
+++ b/tests/dragon_wlm/test_error_handling.py
@@ -0,0 +1,511 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+from unittest.mock import MagicMock
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+import multiprocessing as mp
+
+from dragon.channels import Channel
+from dragon.data.ddict.ddict import DDict
+from dragon.fli import FLInterface
+from dragon.mpbridge.queues import DragonQueue
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.infrastructure.control.request_dispatcher import (
+ RequestDispatcher,
+)
+from smartsim._core.mli.infrastructure.control.worker_manager import (
+ WorkerManager,
+ exception_handler,
+)
+from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
+ DragonFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.feature_store import (
+ FeatureStore,
+ ModelKey,
+ TensorKey,
+)
+from smartsim._core.mli.infrastructure.worker.worker import (
+ ExecuteResult,
+ FetchInputResult,
+ FetchModelResult,
+ InferenceRequest,
+ LoadModelResult,
+ MachineLearningWorkerBase,
+ RequestBatch,
+ TransformInputResult,
+ TransformOutputResult,
+)
+from smartsim._core.mli.message_handler import MessageHandler
+from smartsim._core.mli.mli_schemas.response.response_capnp import ResponseBuilder
+
+from .utils.channel import FileSystemCommChannel
+from .utils.worker import IntegratedTorchWorker
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+@pytest.fixture(scope="module")
+def app_feature_store(the_storage) -> FeatureStore:
+ # create a standalone feature store to mimic a user application putting
+ # data into an application-owned resource (app should not access backbone)
+ app_fs = DragonFeatureStore(the_storage)
+ return app_fs
+
+
+@pytest.fixture
+def setup_worker_manager_model_bytes(
+ test_dir: str,
+ monkeypatch: pytest.MonkeyPatch,
+ backbone_descriptor: str,
+ app_feature_store: FeatureStore,
+ the_worker_channel: DragonFLIChannel,
+):
+ integrated_worker_type = IntegratedTorchWorker
+
+ monkeypatch.setenv(
+ BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor
+ )
+ # Put backbone descriptor into env var for the `EnvironmentConfigLoader`
+ monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor)
+
+ config_loader = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=FileSystemCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+
+ dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0)
+
+ worker_manager = WorkerManager(
+ config_loader=config_loader,
+ worker_type=integrated_worker_type,
+ dispatcher_queue=dispatcher_task_queue,
+ as_service=False,
+ cooldown=3,
+ )
+
+ tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor)
+ output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor)
+
+ inf_request = InferenceRequest(
+ model_key=None,
+ callback=None,
+ raw_inputs=None,
+ input_keys=[tensor_key],
+ input_meta=None,
+ output_keys=[output_key],
+ raw_model=b"model",
+ batch_size=0,
+ )
+
+ model_id = ModelKey(key="key", descriptor=app_feature_store.descriptor)
+
+ request_batch = RequestBatch(
+ [inf_request],
+ TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]),
+ model_id=model_id,
+ )
+
+ dispatcher_task_queue.put(request_batch)
+ return worker_manager, integrated_worker_type
+
+
+@pytest.fixture
+def setup_worker_manager_model_key(
+ test_dir: str,
+ monkeypatch: pytest.MonkeyPatch,
+ backbone_descriptor: str,
+ app_feature_store: FeatureStore,
+ the_worker_channel: DragonFLIChannel,
+):
+ integrated_worker_type = IntegratedTorchWorker
+
+ monkeypatch.setenv(
+ BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor
+ )
+ # Put backbone descriptor into env var for the `EnvironmentConfigLoader`
+ monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor)
+
+ config_loader = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=FileSystemCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+
+ dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0)
+
+ worker_manager = WorkerManager(
+ config_loader=config_loader,
+ worker_type=integrated_worker_type,
+ dispatcher_queue=dispatcher_task_queue,
+ as_service=False,
+ cooldown=3,
+ )
+
+ tensor_key = TensorKey(key="key", descriptor=app_feature_store.descriptor)
+ output_key = TensorKey(key="key", descriptor=app_feature_store.descriptor)
+ model_id = ModelKey(key="model key", descriptor=app_feature_store.descriptor)
+
+ request = InferenceRequest(
+ model_key=model_id,
+ callback=None,
+ raw_inputs=None,
+ input_keys=[tensor_key],
+ input_meta=None,
+ output_keys=[output_key],
+ raw_model=b"model",
+ batch_size=0,
+ )
+ request_batch = RequestBatch(
+ [request],
+ TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]),
+ model_id=model_id,
+ )
+
+ dispatcher_task_queue.put(request_batch)
+ return worker_manager, integrated_worker_type
+
+
+@pytest.fixture
+def setup_request_dispatcher_model_bytes(
+ test_dir: str,
+ monkeypatch: pytest.MonkeyPatch,
+ backbone_descriptor: str,
+ app_feature_store: FeatureStore,
+ the_worker_channel: DragonFLIChannel,
+):
+ integrated_worker_type = IntegratedTorchWorker
+
+ monkeypatch.setenv(
+ BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor
+ )
+ # Put backbone descriptor into env var for the `EnvironmentConfigLoader`
+ monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor)
+
+ config_loader = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=FileSystemCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+
+ request_dispatcher = RequestDispatcher(
+ batch_timeout=0,
+ batch_size=0,
+ config_loader=config_loader,
+ worker_type=integrated_worker_type,
+ )
+ request_dispatcher._on_start()
+
+ tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor)
+ output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor)
+ model = MessageHandler.build_model(b"model", "model name", "v 0.0.1")
+ request = MessageHandler.build_request(
+ test_dir, model, [tensor_key], [output_key], [], None
+ )
+ ser_request = MessageHandler.serialize_request(request)
+
+ request_dispatcher._incoming_channel.send(ser_request)
+
+ return request_dispatcher, integrated_worker_type
+
+
+@pytest.fixture
+def setup_request_dispatcher_model_key(
+ test_dir: str,
+ monkeypatch: pytest.MonkeyPatch,
+ backbone_descriptor: str,
+ app_feature_store: FeatureStore,
+ the_worker_channel: DragonFLIChannel,
+):
+ integrated_worker_type = IntegratedTorchWorker
+
+ monkeypatch.setenv(
+ BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor
+ )
+ # Put backbone descriptor into env var for the `EnvironmentConfigLoader`
+ monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor)
+
+ config_loader = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=FileSystemCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+
+ request_dispatcher = RequestDispatcher(
+ batch_timeout=0,
+ batch_size=0,
+ config_loader=config_loader,
+ worker_type=integrated_worker_type,
+ )
+ request_dispatcher._on_start()
+
+ tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor)
+ output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor)
+ model_key = MessageHandler.build_model_key(
+ key="model key", descriptor=app_feature_store.descriptor
+ )
+ request = MessageHandler.build_request(
+ test_dir, model_key, [tensor_key], [output_key], [], None
+ )
+ ser_request = MessageHandler.serialize_request(request)
+
+ request_dispatcher._incoming_channel.send(ser_request)
+
+ return request_dispatcher, integrated_worker_type
+
+
+def mock_pipeline_stage(
+ monkeypatch: pytest.MonkeyPatch,
+ integrated_worker: MachineLearningWorkerBase,
+ stage: str,
+) -> t.Callable[[t.Any], ResponseBuilder]:
+ def mock_stage(*args: t.Any, **kwargs: t.Any) -> None:
+ raise ValueError(f"Simulated error in {stage}")
+
+ monkeypatch.setattr(integrated_worker, stage, mock_stage)
+ mock_reply_fn = MagicMock()
+ mock_response = MagicMock()
+ mock_response.schema.node.displayName = "Response"
+ mock_reply_fn.return_value = mock_response
+
+ monkeypatch.setattr(
+ "smartsim._core.mli.infrastructure.control.error_handling.build_failure_reply",
+ mock_reply_fn,
+ )
+
+ mock_reply_channel = MagicMock()
+ mock_reply_channel.send = MagicMock()
+
+ def mock_exception_handler(
+ exc: Exception, reply_channel: CommChannelBase, failure_message: str
+ ) -> None:
+ exception_handler(exc, mock_reply_channel, failure_message)
+
+ monkeypatch.setattr(
+ "smartsim._core.mli.infrastructure.control.worker_manager.exception_handler",
+ mock_exception_handler,
+ )
+
+ monkeypatch.setattr(
+ "smartsim._core.mli.infrastructure.control.request_dispatcher.exception_handler",
+ mock_exception_handler,
+ )
+
+ return mock_reply_fn
+
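+# Typical use of the helper above, as in the tests that follow (`integrated_worker`,
+# `worker_manager`, and `monkeypatch` come from this module's fixtures):
+#
+#     mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, "execute")
+#     worker_manager._on_iteration()
+#     mock_reply_fn.assert_called_once()
+#     mock_reply_fn.assert_called_with("fail", "Error while executing.")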
+
+@pytest.mark.parametrize(
+ "setup_worker_manager",
+ [
+ pytest.param("setup_worker_manager_model_bytes"),
+ pytest.param("setup_worker_manager_model_key"),
+ ],
+)
+@pytest.mark.parametrize(
+ "stage, error_message",
+ [
+ pytest.param(
+ "fetch_model",
+ "Error loading model on device or getting device.",
+ id="fetch model",
+ ),
+ pytest.param(
+ "load_model",
+ "Error loading model on device or getting device.",
+ id="load model",
+ ),
+ pytest.param("execute", "Error while executing.", id="execute"),
+ pytest.param(
+ "transform_output",
+ "Error while transforming the output.",
+ id="transform output",
+ ),
+ pytest.param(
+ "place_output", "Error while placing the output.", id="place output"
+ ),
+ ],
+)
+def test_wm_pipeline_stage_errors_handled(
+ request: pytest.FixtureRequest,
+ setup_worker_manager: str,
+ monkeypatch: pytest.MonkeyPatch,
+ stage: str,
+ error_message: str,
+) -> None:
+ """Ensures that the worker manager does not crash after a failure in various pipeline stages"""
+ worker_manager, integrated_worker_type = request.getfixturevalue(
+ setup_worker_manager
+ )
+ integrated_worker = worker_manager._worker
+
+ worker_manager._on_start()
+ device = worker_manager._device_manager._device
+ mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage)
+
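+    # mock the stages that run before the failing stage so the pipeline reaches it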
+ if stage not in ["fetch_model"]:
+ monkeypatch.setattr(
+ integrated_worker,
+ "fetch_model",
+ MagicMock(return_value=FetchModelResult(b"result_bytes")),
+ )
+ if stage not in ["fetch_model", "load_model"]:
+ monkeypatch.setattr(
+ integrated_worker,
+ "load_model",
+ MagicMock(return_value=LoadModelResult(b"result_bytes")),
+ )
+ monkeypatch.setattr(
+ device,
+ "get_model",
+ MagicMock(return_value=b"result_bytes"),
+ )
+ if stage not in [
+ "fetch_model",
+ "execute",
+ ]:
+ monkeypatch.setattr(
+ integrated_worker,
+ "execute",
+ MagicMock(return_value=ExecuteResult(b"result_bytes", [slice(0, 1)])),
+ )
+ if stage not in [
+ "fetch_model",
+ "execute",
+ "transform_output",
+ ]:
+ monkeypatch.setattr(
+ integrated_worker,
+ "transform_output",
+ MagicMock(
+ return_value=[TransformOutputResult(b"result", [], "c", "float32")]
+ ),
+ )
+
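+    # run a single iteration; the injected failure should be handled, not raised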
+ worker_manager._on_iteration()
+
+ mock_reply_fn.assert_called_once()
+ mock_reply_fn.assert_called_with("fail", error_message)
+
+
+@pytest.mark.parametrize(
+ "setup_request_dispatcher",
+ [
+ pytest.param("setup_request_dispatcher_model_bytes"),
+ pytest.param("setup_request_dispatcher_model_key"),
+ ],
+)
+@pytest.mark.parametrize(
+ "stage, error_message",
+ [
+ pytest.param(
+ "fetch_inputs",
+ "Error fetching input.",
+ id="fetch input",
+ ),
+ pytest.param(
+ "transform_input",
+ "Error transforming input.",
+ id="transform input",
+ ),
+ ],
+)
+def test_dispatcher_pipeline_stage_errors_handled(
+ request: pytest.FixtureRequest,
+ setup_request_dispatcher: str,
+ monkeypatch: pytest.MonkeyPatch,
+ stage: str,
+ error_message: str,
+) -> None:
+ """Ensures that the request dispatcher does not crash after a failure in various pipeline stages"""
+ request_dispatcher, integrated_worker_type = request.getfixturevalue(
+ setup_request_dispatcher
+ )
+ integrated_worker = request_dispatcher._worker
+
+ mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage)
+
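+    # mock fetch_inputs when it is not the failing stage so transform_input is reached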
+ if stage not in ["fetch_inputs"]:
+ monkeypatch.setattr(
+ integrated_worker,
+ "fetch_inputs",
+ MagicMock(return_value=[FetchInputResult(result=[b"result"], meta=None)]),
+ )
+
+ request_dispatcher._on_iteration()
+
+ mock_reply_fn.assert_called_once()
+ mock_reply_fn.assert_called_with("fail", error_message)
+
+
+def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Ensures that the worker manager does not crash after a failure in the
+ execute pipeline stage"""
+
+ mock_reply_channel = MagicMock()
+ mock_reply_channel.send = MagicMock()
+
+ mock_reply_fn = MagicMock()
+
+ mock_response = MagicMock()
+ mock_response.schema.node.displayName = "Response"
+ mock_reply_fn.return_value = mock_response
+
+ monkeypatch.setattr(
+ "smartsim._core.mli.infrastructure.control.error_handling.build_failure_reply",
+ mock_reply_fn,
+ )
+
+ test_exception = ValueError("Test ValueError")
+ exception_handler(
+ test_exception, mock_reply_channel, "Failure while fetching the model."
+ )
+
+ mock_reply_fn.assert_called_once()
+ mock_reply_fn.assert_called_with("fail", "Failure while fetching the model.")
+
+
+def test_dragon_feature_store_invalid_storage():
+ """Verify that attempting to create a DragonFeatureStore without storage fails."""
+ storage = None
+
+ with pytest.raises(ValueError) as ex:
+ DragonFeatureStore(storage)
+
+ assert "storage" in ex.value.args[0].lower()
+ assert "required" in ex.value.args[0].lower()
diff --git a/tests/dragon_wlm/test_event_consumer.py b/tests/dragon_wlm/test_event_consumer.py
new file mode 100644
index 0000000000..8a241bab19
--- /dev/null
+++ b/tests/dragon_wlm/test_event_consumer.py
@@ -0,0 +1,386 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import time
+import typing as t
+from unittest import mock
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
+from smartsim._core.mli.infrastructure.comm.event import (
+ OnCreateConsumer,
+ OnShutdownRequested,
+ OnWriteFeatureStore,
+)
+from smartsim._core.mli.infrastructure.control.listener import (
+ ConsumerRegistrationListener,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+# isort: off
+from dragon import fli
+from dragon.channels import Channel
+
+# isort: on
+
+if t.TYPE_CHECKING:
+ import conftest
+
+
+# The tests in this file must run in a dragon environment
+pytestmark = pytest.mark.dragon
+
+
+def test_eventconsumer_eventpublisher_integration(
+ the_backbone: t.Any, test_dir: str
+) -> None:
+ """Verify that the publisher and consumer integrate as expected when
+ multiple publishers and consumers are sending simultaneously. This
+ test closely tracks the test in tests/test_featurestore_base.py also named
+ test_eventconsumer_eventpublisher_integration but requires dragon entities.
+
+ :param the_backbone: The BackboneFeatureStore to use
+ :param test_dir: Automatically generated unique working
+ directories for individual test outputs
+ """
+
+ wmgr_channel = DragonCommChannel(create_local())
+ capp_channel = DragonCommChannel(create_local())
+ back_channel = DragonCommChannel(create_local())
+
+ wmgr_consumer_descriptor = wmgr_channel.descriptor
+ capp_consumer_descriptor = capp_channel.descriptor
+ back_consumer_descriptor = back_channel.descriptor
+
+ # create some consumers to receive messages
+ wmgr_consumer = EventConsumer(
+ wmgr_channel,
+ the_backbone,
+ filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+ )
+ capp_consumer = EventConsumer(
+ capp_channel,
+ the_backbone,
+ )
+ back_consumer = EventConsumer(
+ back_channel,
+ the_backbone,
+ filters=[OnCreateConsumer.CONSUMER_CREATED],
+ )
+
+ # create some broadcasters to publish messages
+ mock_worker_mgr = EventBroadcaster(
+ the_backbone,
+ channel_factory=DragonCommChannel.from_descriptor,
+ )
+ mock_client_app = EventBroadcaster(
+ the_backbone,
+ channel_factory=DragonCommChannel.from_descriptor,
+ )
+
+    # register all of the consumers, even though the OnCreateConsumer event
+    # should really trigger registration on its own; event processing is
+    # tested elsewhere
+ the_backbone.notification_channels = [
+ wmgr_consumer_descriptor,
+ capp_consumer_descriptor,
+ back_consumer_descriptor,
+ ]
+
+ # simulate worker manager sending a notification to backend that it's alive
+ event_1 = OnCreateConsumer(
+ "test_eventconsumer_eventpublisher_integration",
+ wmgr_consumer_descriptor,
+ filters=[],
+ )
+ mock_worker_mgr.send(event_1)
+
+ # simulate the app updating a model a few times
+ for key in ["key-1", "key-2", "key-1"]:
+ event = OnWriteFeatureStore(
+ "test_eventconsumer_eventpublisher_integration",
+ the_backbone.descriptor,
+ key,
+ )
+ mock_client_app.send(event, timeout=0.1)
+
+    # worker manager should only get updates about feature store writes
+ wmgr_messages = wmgr_consumer.recv()
+ assert len(wmgr_messages) == 3
+
+ # the backend should only receive messages about consumer creation
+ back_messages = back_consumer.recv()
+ assert len(back_messages) == 1
+
+ # hypothetical app has no filters and will get all events
+ app_messages = capp_consumer.recv()
+ assert len(app_messages) == 4
+
+
+@pytest.mark.parametrize(
+ " timeout, batch_timeout, exp_err_msg",
+ [(-1, 1, " timeout"), (1, -1, "batch_timeout")],
+)
+def test_eventconsumer_invalid_timeout(
+ timeout: float,
+ batch_timeout: float,
+ exp_err_msg: str,
+ test_dir: str,
+ the_backbone: BackboneFeatureStore,
+) -> None:
+ """Verify that the event consumer raises an exception
+ when provided an invalid request timeout.
+
+ :param timeout: The request timeout for the event consumer recv call
+ :param batch_timeout: The batch timeout for the event consumer recv call
+ :param exp_err_msg: A unique value from the error message that should be raised
+    :param the_backbone: The BackboneFeatureStore to use
+ :param test_dir: Automatically generated unique working
+ directories for individual test outputs
+ """
+
+ wmgr_channel = DragonCommChannel(create_local())
+
+ # create some consumers to receive messages
+ wmgr_consumer = EventConsumer(
+ wmgr_channel,
+ the_backbone,
+ filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+ )
+
+ # the consumer should report an error for the invalid timeout value
+ with pytest.raises(ValueError) as ex:
+ wmgr_consumer.recv(timeout=timeout, batch_timeout=batch_timeout)
+
+ assert exp_err_msg in ex.value.args[0]
+
+
+def test_eventconsumer_no_event_handler_registered(
+ the_backbone: t.Any, test_dir: str
+) -> None:
+ """Verify that a consumer discards messages when
+ on a channel if no handler is registered.
+
+ :param the_backbone: The BackboneFeatureStore to use
+ :param test_dir: Automatically generated unique working
+ directories for individual test outputs
+ """
+
+ wmgr_channel = DragonCommChannel(create_local())
+
+ # create a consumer to receive messages
+ wmgr_consumer = EventConsumer(wmgr_channel, the_backbone, event_handler=None)
+
+    # create a broadcaster to publish messages
+ mock_worker_mgr = EventBroadcaster(
+ the_backbone,
+ channel_factory=DragonCommChannel.from_descriptor,
+ )
+
+ # manually register the consumers since we don't have a backend running
+ the_backbone.notification_channels = [wmgr_channel.descriptor]
+
+ # simulate the app updating a model a few times
+ for key in ["key-1", "key-2", "key-1"]:
+ event = OnWriteFeatureStore(
+ "test_eventconsumer_no_event_handler_registered",
+ the_backbone.descriptor,
+ key,
+ )
+ mock_worker_mgr.send(event, timeout=0.1)
+
+ # run the handler and let it discard messages
+ for _ in range(15):
+ wmgr_consumer.listen_once(0.2, 2.0)
+
+ assert wmgr_consumer.listening
+
+
+def test_eventconsumer_no_event_handler_registered_shutdown(
+ the_backbone: t.Any, test_dir: str
+) -> None:
+ """Verify that a consumer without an event handler
+ registered still honors shutdown requests.
+
+ :param the_backbone: The BackboneFeatureStore to use
+ :param test_dir: Automatically generated unique working
+ directories for individual test outputs
+ """
+
+ wmgr_channel = DragonCommChannel(create_local())
+ capp_channel = DragonCommChannel(create_local())
+
+    # create a consumer to receive messages
+ wmgr_consumer = EventConsumer(wmgr_channel, the_backbone)
+
+ # create a broadcaster to publish messages
+ mock_worker_mgr = EventBroadcaster(
+ the_backbone,
+ channel_factory=DragonCommChannel.from_descriptor,
+ )
+
+ # manually register the consumers since we don't have a backend running
+ the_backbone.notification_channels = [
+ wmgr_channel.descriptor,
+ capp_channel.descriptor,
+ ]
+
+ # simulate the app updating a model a few times
+ for key in ["key-1", "key-2", "key-1"]:
+ event = OnWriteFeatureStore(
+ "test_eventconsumer_no_event_handler_registered_shutdown",
+ the_backbone.descriptor,
+ key,
+ )
+ mock_worker_mgr.send(event, timeout=0.1)
+
+ event = OnShutdownRequested(
+ "test_eventconsumer_no_event_handler_registered_shutdown"
+ )
+ mock_worker_mgr.send(event, timeout=0.1)
+
+ # wmgr will stop listening to messages when it is told to stop listening
+ wmgr_consumer.listen(timeout=0.1, batch_timeout=2.0)
+
+ for _ in range(15):
+ wmgr_consumer.listen_once(timeout=0.1, batch_timeout=2.0)
+
+ # confirm the messages were processed, discarded, and the shutdown was received
+    assert not wmgr_consumer.listening
+
+
+def test_eventconsumer_registration(
+ the_backbone: t.Any, test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+ """Verify that a consumer is correctly registered in
+    the backbone after sending a registration request. Then confirm
+    the consumer is unregistered after sending the unregister request.
+
+ :param the_backbone: The BackboneFeatureStore to use
+ :param test_dir: Automatically generated unique working
+ directories for individual test outputs
+ """
+
+ with monkeypatch.context() as patch:
+ registrar = ConsumerRegistrationListener(
+ the_backbone, 1.0, 2.0, as_service=False
+ )
+
+ # NOTE: service.execute(as_service=False) will complete the service life-
+ # cycle and remove the registrar from the backbone, so mock _on_shutdown
+ disabled_shutdown = mock.MagicMock()
+ patch.setattr(registrar, "_on_shutdown", disabled_shutdown)
+
+        # initialize registrar resources
+ registrar.execute()
+
+ # create a consumer that will be registered
+ wmgr_channel = DragonCommChannel(create_local())
+ wmgr_consumer = EventConsumer(wmgr_channel, the_backbone)
+
+ registered_channels = the_backbone.notification_channels
+
+ # trigger the consumer-to-registrar handshake
+ wmgr_consumer.register()
+
+ current_registrations: t.List[str] = []
+
+ # have the registrar run a few times to pick up the msg
+ for i in range(15):
+ registrar.execute()
+ current_registrations = the_backbone.notification_channels
+ if len(current_registrations) != len(registered_channels):
+ logger.debug(f"The event was processed on iteration {i}")
+ break
+
+ # confirm the consumer is registered
+ assert wmgr_channel.descriptor in current_registrations
+
+ # copy old list so we can compare against it.
+ registered_channels = list(current_registrations)
+
+ # trigger the consumer removal
+ wmgr_consumer.unregister()
+
+ # have the registrar run a few times to pick up the msg
+ for i in range(15):
+ registrar.execute()
+ current_registrations = the_backbone.notification_channels
+ if len(current_registrations) != len(registered_channels):
+ logger.debug(f"The event was processed on iteration {i}")
+ break
+
+ # confirm the consumer is no longer registered
+ assert wmgr_channel.descriptor not in current_registrations
+
+
+def test_registrar_teardown(
+ the_backbone: t.Any, test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+ """Verify that the consumer registrar removes itself from
+ the backbone when it shuts down.
+
+ :param the_backbone: The BackboneFeatureStore to use
+ :param test_dir: Automatically generated unique working
+ directories for individual test outputs
+ """
+
+ with monkeypatch.context() as patch:
+ registrar = ConsumerRegistrationListener(
+ the_backbone, 1.0, 2.0, as_service=False
+ )
+
+        # directly initialize registrar resources to avoid the service life-cycle
+ registrar._create_eventing()
+
+ # confirm the registrar is published to the backbone
+ cfg = the_backbone.wait_for([BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], 10)
+ assert BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in cfg
+
+ # execute the entire service lifecycle 1x
+ registrar.execute()
+
+ consumer_found = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in the_backbone
+
+ for i in range(15):
+ time.sleep(0.1)
+ consumer_found = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in the_backbone
+ if not consumer_found:
+ logger.debug(f"Registrar removed from the backbone on iteration {i}")
+ break
+
+ assert BackboneFeatureStore.MLI_REGISTRAR_CONSUMER not in the_backbone
diff --git a/tests/dragon_wlm/test_featurestore.py b/tests/dragon_wlm/test_featurestore.py
new file mode 100644
index 0000000000..019dcde7a0
--- /dev/null
+++ b/tests/dragon_wlm/test_featurestore.py
@@ -0,0 +1,327 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+import multiprocessing as mp
+import random
+import time
+import typing as t
+import unittest.mock as mock
+import uuid
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ time as bbtime,
+)
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+# isort: off
+from dragon import fli
+from dragon.channels import Channel
+
+# isort: on
+
+if t.TYPE_CHECKING:
+ import conftest
+
+
+# The tests in this file must run in a dragon environment
+pytestmark = pytest.mark.dragon
+
+
+def test_backbone_wait_for_no_keys(
+ the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch
+) -> None:
+ """Verify that asking the backbone to wait for a value succeeds
+ immediately and does not cause a wait to occur if the supplied key
+ list is empty.
+
+    :param the_backbone: the storage engine to use, prepopulated by the test fixture
+ """
+    # an empty key list should return immediately without waiting
+
+ with monkeypatch.context() as ctx:
+ # all keys should be found and the timeout should never be checked.
+ ctx.setattr(bbtime, "sleep", mock.MagicMock())
+
+ values = the_backbone.wait_for([])
+ assert len(values) == 0
+
+ # confirm that no wait occurred
+ bbtime.sleep.assert_not_called()
+
+
+def test_backbone_wait_for_prepopulated(
+ the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch
+) -> None:
+ """Verify that asking the backbone to wait for a value succeed
+ immediately and do not cause a wait to occur if the data exists.
+
+ :param the_backbone: the storage engine to use, prepopulated with
+ """
+ # set a very low timeout to confirm that it does not wait
+
+ with monkeypatch.context() as ctx:
+ # all keys should be found and the timeout should never be checked.
+ ctx.setattr(bbtime, "sleep", mock.MagicMock())
+
+ values = the_backbone.wait_for([BackboneFeatureStore.MLI_WORKER_QUEUE], 0.1)
+
+ # confirm that wait_for with one key returns one value
+ assert len(values) == 1
+
+ # confirm that the descriptor is non-null w/some non-trivial value
+ assert len(values[BackboneFeatureStore.MLI_WORKER_QUEUE]) > 5
+
+ # confirm that no wait occurred
+ bbtime.sleep.assert_not_called()
+
+
+def test_backbone_wait_for_prepopulated_dupe(
+ the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch
+) -> None:
+ """Verify that asking the backbone to wait for keys that are duplicated
+ results in a single value being returned for each key.
+
+    :param the_backbone: the storage engine to use, prepopulated by the test fixture
+ """
+    # the keys are already set, so wait_for should return without waiting
+
+ key1, key2 = "key-1", "key-2"
+ value1, value2 = "i-am-value-1", "i-am-value-2"
+ the_backbone[key1] = value1
+ the_backbone[key2] = value2
+
+ with monkeypatch.context() as ctx:
+ # all keys should be found and the timeout should never be checked.
+ ctx.setattr(bbtime, "sleep", mock.MagicMock())
+
+ values = the_backbone.wait_for([key1, key2, key1]) # key1 is duplicated
+
+        # confirm that duplicate keys are de-duplicated and each unique key returns one value
+ assert len(values) == 2
+ assert key1 in values
+ assert key2 in values
+
+ assert values[key1] == value1
+ assert values[key2] == value2
+
+
+def set_value_after_delay(
+ descriptor: str, key: str, value: str, delay: float = 5
+) -> None:
+ """Helper method to persist a random value into the backbone
+
+ :param descriptor: the backbone feature store descriptor to attach to
+ :param key: the key to write to
+ :param value: a value to write to the key
+ :param delay: amount of delay to apply before writing the key
+ """
+ time.sleep(delay)
+
+ backbone = BackboneFeatureStore.from_descriptor(descriptor)
+ backbone[key] = value
+ logger.debug(f"set_value_after_delay wrote `{value} to backbone[`{key}`]")
+
+
+@pytest.mark.parametrize(
+ "delay",
+ [
+ pytest.param(
+ 0,
+ marks=pytest.mark.skip(
+ "Must use entrypoint instead of mp.Process to run on build agent"
+ ),
+ ),
+ pytest.param(
+ 1,
+ marks=pytest.mark.skip(
+ "Must use entrypoint instead of mp.Process to run on build agent"
+ ),
+ ),
+ pytest.param(
+ 2,
+ marks=pytest.mark.skip(
+ "Must use entrypoint instead of mp.Process to run on build agent"
+ ),
+ ),
+ pytest.param(
+ 4,
+ marks=pytest.mark.skip(
+ "Must use entrypoint instead of mp.Process to run on build agent"
+ ),
+ ),
+ pytest.param(
+ 8,
+ marks=pytest.mark.skip(
+ "Must use entrypoint instead of mp.Process to run on build agent"
+ ),
+ ),
+ ],
+)
+def test_backbone_wait_for_partial_prepopulated(
+ the_backbone: BackboneFeatureStore, delay: float
+) -> None:
+ """Verify that when data is not all in the backbone, the `wait_for` operation
+ continues to poll until it finds everything it needs.
+
+    :param the_backbone: the storage engine to use, prepopulated by the test fixture
+ :param delay: the number of seconds the second process will wait before
+ setting the target value in the backbone featurestore
+ """
+    # allow enough time for the delayed writer process to set the value
+ wait_timeout = 10
+
+ key, value = str(uuid.uuid4()), str(random.random() * 10)
+
+ logger.debug(f"Starting process to write {key} after {delay}s")
+ p = mp.Process(
+ target=set_value_after_delay, args=(the_backbone.descriptor, key, value, delay)
+ )
+ p.start()
+
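+    # a second process waits for both the prepopulated key and the delayed key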
+ p2 = mp.Process(
+ target=the_backbone.wait_for,
+ args=([BackboneFeatureStore.MLI_WORKER_QUEUE, key],),
+ kwargs={"timeout": wait_timeout},
+ )
+ p2.start()
+
+ p.join()
+ p2.join()
+
+ # both values should be written at this time
+ ret_vals = the_backbone.wait_for(
+ [key, BackboneFeatureStore.MLI_WORKER_QUEUE, key], 0.1
+ )
+ # confirm that wait_for with two keys returns two values
+ assert len(ret_vals) == 2, "values should contain values for both awaited keys"
+
+ # confirm the pre-populated value has the correct output
+ assert (
+ ret_vals[BackboneFeatureStore.MLI_WORKER_QUEUE] == "12345"
+ ) # mock descriptor value from fixture
+
+ # confirm the population process completed and the awaited value is correct
+    assert ret_vals[key] == value, "awaited value does not match the written value"
+
+
+@pytest.mark.parametrize(
+ "num_keys",
+ [
+ pytest.param(
+ 0,
+ marks=pytest.mark.skip(
+ "Must use entrypoint instead of mp.Process to run on build agent"
+ ),
+ ),
+ pytest.param(
+ 1,
+ marks=pytest.mark.skip(
+ "Must use entrypoint instead of mp.Process to run on build agent"
+ ),
+ ),
+ pytest.param(
+ 3,
+ marks=pytest.mark.skip(
+ "Must use entrypoint instead of mp.Process to run on build agent"
+ ),
+ ),
+ pytest.param(
+ 7,
+ marks=pytest.mark.skip(
+ "Must use entrypoint instead of mp.Process to run on build agent"
+ ),
+ ),
+ pytest.param(
+ 11,
+ marks=pytest.mark.skip(
+ "Must use entrypoint instead of mp.Process to run on build agent"
+ ),
+ ),
+ ],
+)
+def test_backbone_wait_for_multikey(
+ the_backbone: BackboneFeatureStore,
+ num_keys: int,
+ test_dir: str,
+) -> None:
+ """Verify that asking the backbone to wait for multiple keys results
+ in that number of values being returned.
+
+    :param the_backbone: the storage engine to use, prepopulated by the test fixture
+ :param num_keys: the number of extra keys to set & request in the backbone
+ """
+ # maximum delay allowed for setter processes
+ max_delay = 5
+
+ extra_keys = [str(uuid.uuid4()) for _ in range(num_keys)]
+ extra_values = [str(uuid.uuid4()) for _ in range(num_keys)]
+ extras = dict(zip(extra_keys, extra_values))
+ delays = [random.random() * max_delay for _ in range(num_keys)]
+ processes = []
+
+ for key, value, delay in zip(extra_keys, extra_values, delays):
+ assert delay < max_delay, "write delay exceeds test timeout"
+ logger.debug(f"Delaying {key} write by {delay} seconds")
+ p = mp.Process(
+ target=set_value_after_delay,
+ args=(the_backbone.descriptor, key, value, delay),
+ )
+ p.start()
+ processes.append(p)
+
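+    # a separate process waits for all of the keys while the setters write them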
+ p2 = mp.Process(
+ target=the_backbone.wait_for,
+ args=(extra_keys,),
+ kwargs={"timeout": max_delay * 2},
+ )
+ p2.start()
+ for p in processes:
+ p.join(timeout=max_delay * 2)
+    p2.join(timeout=max_delay * 2)  # allow the waiter the same window as the setters
+
+    # use a near-zero timeout to verify all values have already been written
+    num_keys = len(extra_keys)
+    actual_values = the_backbone.wait_for(extra_keys, timeout=0.01)
+
+ # confirm that wait_for returns all the expected values
+ assert len(actual_values) == num_keys
+
+ # confirm that the returned values match (e.g. are returned in the right order)
+ for k in extras:
+ assert extras[k] == actual_values[k]
diff --git a/tests/dragon_wlm/test_featurestore_base.py b/tests/dragon_wlm/test_featurestore_base.py
new file mode 100644
index 0000000000..6daceb9061
--- /dev/null
+++ b/tests/dragon_wlm/test_featurestore_base.py
@@ -0,0 +1,844 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import pathlib
+import time
+import typing as t
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
+from smartsim._core.mli.infrastructure.comm.event import (
+ OnCreateConsumer,
+ OnWriteFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
+ DragonFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.feature_store import ReservedKeys
+from smartsim.error import SmartSimError
+
+from .channel import FileSystemCommChannel
+from .feature_store import MemoryFeatureStore
+
+if t.TYPE_CHECKING:
+ import conftest
+
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+def boom(*args, **kwargs) -> None:
+ """Helper function that blows up when used to mock up
+ some other function."""
+ raise Exception(f"you shall not pass! {args}, {kwargs}")
+
+
+def test_event_uid() -> None:
+ """Verify that all events include a unique identifier."""
+ uids: t.Set[str] = set()
+ num_iters = 1000
+
+ # generate a bunch of events and keep track all the IDs
+ for i in range(num_iters):
+ event_a = OnCreateConsumer("test_event_uid", str(i), filters=[])
+ event_b = OnWriteFeatureStore("test_event_uid", "test_event_uid", str(i))
+
+ uids.add(event_a.uid)
+ uids.add(event_b.uid)
+
+ # verify each event created a unique ID
+ assert len(uids) == 2 * num_iters
+
+
+def test_mli_reserved_keys_conversion() -> None:
+ """Verify that conversion from a string to an enum member
+ works as expected."""
+
+ for reserved_key in ReservedKeys:
+ # iterate through all keys and verify `from_string` works
+ assert ReservedKeys.contains(reserved_key.value)
+
+        # show that the enum member name (unlike the value, which is the
+        # actual key) is not incorrectly identified as reserved
+ assert not ReservedKeys.contains(str(reserved_key).split(".")[1])
+
+
+def test_mli_reserved_keys_writes() -> None:
+ """Verify that attempts to write to reserved keys are blocked from a
+ standard DragonFeatureStore but enabled with the BackboneFeatureStore."""
+
+ mock_storage = {}
+ dfs = DragonFeatureStore(mock_storage)
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ other = MemoryFeatureStore(mock_storage)
+
+ expected_value = "value"
+
+ for reserved_key in ReservedKeys:
+ # we expect every reserved key to fail using DragonFeatureStore...
+ with pytest.raises(SmartSimError) as ex:
+ dfs[reserved_key] = expected_value
+
+ assert "reserved key" in ex.value.args[0]
+
+ # ... and expect other feature stores to respect reserved keys
+ with pytest.raises(SmartSimError) as ex:
+ other[reserved_key] = expected_value
+
+ assert "reserved key" in ex.value.args[0]
+
+ # ...and those same keys to succeed on the backbone
+ backbone[reserved_key] = expected_value
+ actual_value = backbone[reserved_key]
+ assert actual_value == expected_value
+
+
+def test_mli_consumers_read_by_key() -> None:
+ """Verify that the value returned from the mli consumers method is written
+ to the correct key and reads are allowed via standard dragon feature store."""
+
+ mock_storage = {}
+ dfs = DragonFeatureStore(mock_storage)
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ other = MemoryFeatureStore(mock_storage)
+
+ expected_value = "value"
+
+ # write using backbone that has permission to write reserved keys
+ backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] = expected_value
+
+ # confirm read-only access to reserved keys from any FeatureStore
+ for fs in [dfs, backbone, other]:
+ assert fs[ReservedKeys.MLI_NOTIFY_CONSUMERS] == expected_value
+
+
+def test_mli_consumers_read_by_backbone() -> None:
+ """Verify that the backbone reads the correct location
+ when using the backbone feature store API instead of mapping API."""
+
+ mock_storage = {}
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ expected_value = "value"
+
+ backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] = expected_value
+
+ # confirm reading via convenience method returns expected value
+ assert backbone.notification_channels[0] == expected_value
+
+
+def test_mli_consumers_write_by_backbone() -> None:
+ """Verify that the backbone writes the correct location
+ when using the backbone feature store API instead of mapping API."""
+
+ mock_storage = {}
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ expected_value = ["value"]
+
+ backbone.notification_channels = expected_value
+
+ # confirm write using convenience method targets expected key
+ assert backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] == ",".join(expected_value)
+
+
+def test_eventpublisher_broadcast_no_factory(test_dir: str) -> None:
+ """Verify that a broadcast operation without any registered subscribers
+ succeeds without raising Exceptions.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ mock_storage = {}
+ consumer_descriptor = storage_path / "test-consumer"
+
+ # NOTE: we're not putting any consumers into the backbone here!
+ backbone = BackboneFeatureStore(mock_storage)
+
+    publisher = EventBroadcaster(backbone)
+    num_receivers = 0
+
+    # publishing this event without any known consumers registered should succeed
+    # but report that it didn't have anybody to send the event to
+    event = OnCreateConsumer(
+        "test_eventpublisher_broadcast_no_factory", consumer_descriptor, filters=[]
+    )
+
+ num_receivers += publisher.send(event)
+
+ # confirm no changes to the backbone occur when fetching the empty consumer key
+ key_in_features_store = ReservedKeys.MLI_NOTIFY_CONSUMERS in backbone
+ assert not key_in_features_store
+
+ # confirm that the broadcast reports no events published
+ assert num_receivers == 0
+ # confirm that the broadcast buffered the event for a later send
+ assert publisher.num_buffered == 1
+
+
+def test_eventpublisher_broadcast_to_empty_consumer_list(test_dir: str) -> None:
+ """Verify that a broadcast operation without any registered subscribers
+ succeeds without raising Exceptions.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ consumer_descriptor = storage_path / "test-consumer"
+
+ # prep our backbone with a consumer list
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ backbone.notification_channels = []
+
+ event = OnCreateConsumer(
+ "test_eventpublisher_broadcast_to_empty_consumer_list",
+ consumer_descriptor,
+ filters=[],
+ )
+ publisher = EventBroadcaster(
+ backbone, channel_factory=FileSystemCommChannel.from_descriptor
+ )
+ num_receivers = publisher.send(event)
+
+ registered_consumers = backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS]
+
+ # confirm that no consumers exist in backbone to send to
+ assert not registered_consumers
+ # confirm that the broadcast reports no events published
+ assert num_receivers == 0
+ # confirm that the broadcast buffered the event for a later send
+ assert publisher.num_buffered == 1
+
+
+def test_eventpublisher_broadcast_without_channel_factory(test_dir: str) -> None:
+ """Verify that a broadcast operation reports an error if no channel
+ factory was supplied for constructing the consumer channels.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ consumer_descriptor = storage_path / "test-consumer"
+
+ # prep our backbone with a consumer list
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ backbone.notification_channels = [consumer_descriptor]
+
+ event = OnCreateConsumer(
+ "test_eventpublisher_broadcast_without_channel_factory",
+ consumer_descriptor,
+ filters=[],
+ )
+ publisher = EventBroadcaster(
+ backbone,
+ # channel_factory=FileSystemCommChannel.from_descriptor # <--- not supplied
+ )
+
+ with pytest.raises(SmartSimError) as ex:
+ publisher.send(event)
+
+ assert "factory" in ex.value.args[0]
+
+
+def test_eventpublisher_broadcast_empties_buffer(test_dir: str) -> None:
+ """Verify that a successful broadcast clears messages from the event
+ buffer when a new message is sent and consumers are registered.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ consumer_descriptor = storage_path / "test-consumer"
+
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ backbone.notification_channels = (consumer_descriptor,)
+
+ publisher = EventBroadcaster(
+ backbone, channel_factory=FileSystemCommChannel.from_descriptor
+ )
+
+ # mock building up some buffered events
+ num_buffered_events = 14
+ for i in range(num_buffered_events):
+ event = OnCreateConsumer(
+ "test_eventpublisher_broadcast_empties_buffer",
+ storage_path / f"test-consumer-{str(i)}",
+ [],
+ )
+ publisher._event_buffer.append(bytes(event))
+
+ event0 = OnCreateConsumer(
+ "test_eventpublisher_broadcast_empties_buffer",
+ storage_path / f"test-consumer-{str(num_buffered_events + 1)}",
+ [],
+ )
+
+ num_receivers = publisher.send(event0)
+ # 1 receiver x 15 total events == 15 events
+ assert num_receivers == num_buffered_events + 1
+
+
+@pytest.mark.parametrize(
+ "num_consumers, num_buffered, expected_num_sent",
+ [
+ pytest.param(0, 7, 0, id="0 x (7+1) - no consumers, multi-buffer"),
+ pytest.param(1, 7, 8, id="1 x (7+1) - single consumer, multi-buffer"),
+ pytest.param(2, 7, 16, id="2 x (7+1) - multi-consumer, multi-buffer"),
+ pytest.param(4, 4, 20, id="4 x (4+1) - multi-consumer, multi-buffer (odd #)"),
+ pytest.param(9, 0, 9, id="13 x (0+1) - multi-consumer, empty buffer"),
+ ],
+)
+def test_eventpublisher_broadcast_returns_total_sent(
+ test_dir: str, num_consumers: int, num_buffered: int, expected_num_sent: int
+) -> None:
+ """Verify that a successful broadcast returns the total number of events
+ sent, including buffered messages.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ :param num_consumers: the number of consumers to mock setting up prior to send
+ :param num_buffered: the number of pre-buffered events to mock up
+ :param expected_num_sent: the expected result from calling send
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ consumers = []
+ for i in range(num_consumers):
+ consumers.append(storage_path / f"test-consumer-{i}")
+
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ backbone.notification_channels = consumers
+
+ publisher = EventBroadcaster(
+ backbone, channel_factory=FileSystemCommChannel.from_descriptor
+ )
+
+ # mock building up some buffered events
+ for i in range(num_buffered):
+ event = OnCreateConsumer(
+ "test_eventpublisher_broadcast_returns_total_sent",
+ storage_path / f"test-consumer-{str(i)}",
+ [],
+ )
+ publisher._event_buffer.append(bytes(event))
+
+ assert publisher.num_buffered == num_buffered
+
+ # this event will trigger clearing anything already in buffer
+ event0 = OnCreateConsumer(
+ "test_eventpublisher_broadcast_returns_total_sent",
+ storage_path / f"test-consumer-{num_buffered}",
+ [],
+ )
+
+    # num_receivers should equal the total sent across all consumers and all events
+ num_receivers = publisher.send(event0)
+
+ assert num_receivers == expected_num_sent
+
+
+def test_eventpublisher_prune_unused_consumer(test_dir: str) -> None:
+ """Verify that any unused consumers are pruned each time a new event is sent.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ consumer_descriptor = storage_path / "test-consumer"
+
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+
+ publisher = EventBroadcaster(
+ backbone, channel_factory=FileSystemCommChannel.from_descriptor
+ )
+
+ event = OnCreateConsumer(
+ "test_eventpublisher_prune_unused_consumer",
+ consumer_descriptor,
+ filters=[],
+ )
+
+    # the only registered consumer is in the event, expect no pruning
+ backbone.notification_channels = (consumer_descriptor,)
+
+ publisher.send(event)
+ assert str(consumer_descriptor) in publisher._channel_cache
+ assert len(publisher._channel_cache) == 1
+
+ # add a new descriptor for another event...
+ consumer_descriptor2 = storage_path / "test-consumer-2"
+ # ... and remove the old descriptor from the backbone when it's looked up
+ backbone.notification_channels = (consumer_descriptor2,)
+
+ event = OnCreateConsumer(
+ "test_eventpublisher_prune_unused_consumer", consumer_descriptor2, filters=[]
+ )
+
+ publisher.send(event)
+
+ assert str(consumer_descriptor2) in publisher._channel_cache
+ assert str(consumer_descriptor) not in publisher._channel_cache
+ assert len(publisher._channel_cache) == 1
+
+ # test multi-consumer pruning by caching some extra channels
+ prune0, prune1, prune2 = "abc", "def", "ghi"
+ publisher._channel_cache[prune0] = "doesnt-matter-if-it-is-pruned"
+ publisher._channel_cache[prune1] = "doesnt-matter-if-it-is-pruned"
+ publisher._channel_cache[prune2] = "doesnt-matter-if-it-is-pruned"
+
+ # add in one of our old channels so we prune the above items, send to these
+ backbone.notification_channels = (consumer_descriptor, consumer_descriptor2)
+
+ publisher.send(event)
+
+ assert str(consumer_descriptor2) in publisher._channel_cache
+
+ # NOTE: we should NOT prune something that isn't used by this message but
+ # does appear in `backbone.notification_channels`
+ assert str(consumer_descriptor) in publisher._channel_cache
+
+ # confirm all of our items that were not in the notification channels are gone
+ for pruned in [prune0, prune1, prune2]:
+ assert pruned not in publisher._channel_cache
+
+ # confirm we have only the two expected items in the channel cache
+ assert len(publisher._channel_cache) == 2
+
+
+def test_eventpublisher_serialize_failure(
+ test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+ """Verify that errors during message serialization are raised to the caller.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ :param monkeypatch: pytest fixture for modifying behavior of existing code
+ with mock implementations
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ storage_path.mkdir(parents=True, exist_ok=True)
+
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ target_descriptor = str(storage_path / "test-consumer")
+
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ publisher = EventBroadcaster(
+ backbone, channel_factory=FileSystemCommChannel.from_descriptor
+ )
+
+ with monkeypatch.context() as patch:
+ event = OnCreateConsumer(
+ "test_eventpublisher_serialize_failure", target_descriptor, filters=[]
+ )
+
+ # patch the __bytes__ implementation to cause pickling to fail during send
+ def bad_bytes(self) -> bytes:
+ return b"abc"
+
+ # this patch causes an attribute error when event pickling is attempted
+ patch.setattr(event, "__bytes__", bad_bytes)
+
+ backbone.notification_channels = (target_descriptor,)
+
+ # send a message into the channel
+ with pytest.raises(AttributeError) as ex:
+ publisher.send(event)
+
+ assert "serialize" in ex.value.args[0]
+
+
+def test_eventpublisher_factory_failure(
+ test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+ """Verify that errors during channel construction are raised to the caller.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ :param monkeypatch: pytest fixture for modifying behavior of existing code
+ with mock implementations
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ storage_path.mkdir(parents=True, exist_ok=True)
+
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ target_descriptor = str(storage_path / "test-consumer")
+
+ def boom(descriptor: str) -> None:
+ raise Exception(f"you shall not pass! {descriptor}")
+
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ publisher = EventBroadcaster(backbone, channel_factory=boom)
+
+ with monkeypatch.context() as patch:
+ event = OnCreateConsumer(
+ "test_eventpublisher_factory_failure", target_descriptor, filters=[]
+ )
+
+ backbone.notification_channels = (target_descriptor,)
+
+ # send a message into the channel
+ with pytest.raises(SmartSimError) as ex:
+ publisher.send(event)
+
+ assert "construct" in ex.value.args[0]
+
+
+def test_eventpublisher_failure(test_dir: str, monkeypatch: pytest.MonkeyPatch) -> None:
+ """Verify that unexpected errors during message send are caught and wrapped in a
+ SmartSimError so they are not propagated directly to the caller.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ :param monkeypatch: pytest fixture for modifying behavior of existing code
+ with mock implementations
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ storage_path.mkdir(parents=True, exist_ok=True)
+
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ target_descriptor = str(storage_path / "test-consumer")
+
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+ publisher = EventBroadcaster(
+ backbone, channel_factory=FileSystemCommChannel.from_descriptor
+ )
+
+ def boom(self) -> None:
+ raise Exception("That was unexpected...")
+
+ with monkeypatch.context() as patch:
+ event = OnCreateConsumer(
+ "test_eventpublisher_failure", target_descriptor, filters=[]
+ )
+
+        # patch the _broadcast implementation to cause send to fail after
+        # the event has been pickled
+ patch.setattr(publisher, "_broadcast", boom)
+
+ backbone.notification_channels = (target_descriptor,)
+
+        # confirm the unexpected exception raised by _broadcast is not allowed
+        # directly out and is instead wrapped in a SmartSimError
+ with pytest.raises(SmartSimError) as ex:
+ publisher.send(event)
+
+ assert "unexpected" in ex.value.args[0]
+
+
+def test_eventconsumer_receive(test_dir: str) -> None:
+ """Verify that a consumer retrieves a message from the given channel.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ storage_path.mkdir(parents=True, exist_ok=True)
+
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ target_descriptor = str(storage_path / "test-consumer")
+
+ backbone = BackboneFeatureStore(mock_storage)
+ comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor)
+ event = OnCreateConsumer(
+ "test_eventconsumer_receive", target_descriptor, filters=[]
+ )
+
+ # simulate a sent event by writing directly to the input comm channel
+ comm_channel.send(bytes(event))
+
+ consumer = EventConsumer(comm_channel, backbone)
+
+ all_received: t.List[OnCreateConsumer] = consumer.recv()
+ assert len(all_received) == 1
+
+ # verify we received the same event that was raised
+ assert all_received[0].category == event.category
+ assert all_received[0].descriptor == event.descriptor
+
+
+@pytest.mark.parametrize("num_sent", [0, 1, 2, 4, 8, 16])
+def test_eventconsumer_receive_multi(test_dir: str, num_sent: int) -> None:
+ """Verify that a consumer retrieves multiple message from the given channel.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+    :param num_sent: parameterized value used to vary the number of enqueued
+    events so validations are checked at multiple queue sizes
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ storage_path.mkdir(parents=True, exist_ok=True)
+
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ target_descriptor = str(storage_path / "test-consumer")
+
+ backbone = BackboneFeatureStore(mock_storage)
+ comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor)
+
+ # simulate multiple sent events by writing directly to the input comm channel
+ for _ in range(num_sent):
+ event = OnCreateConsumer(
+ "test_eventconsumer_receive_multi", target_descriptor, filters=[]
+ )
+ comm_channel.send(bytes(event))
+
+ consumer = EventConsumer(comm_channel, backbone)
+
+ all_received: t.List[OnCreateConsumer] = consumer.recv()
+ assert len(all_received) == num_sent
+
+
+def test_eventconsumer_receive_empty(test_dir: str) -> None:
+ """Verify that a consumer receiving an empty message ignores the
+ message and continues processing.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ storage_path.mkdir(parents=True, exist_ok=True)
+
+ mock_storage = {}
+
+ # note: file-system descriptors are just paths
+ target_descriptor = str(storage_path / "test-consumer")
+
+ backbone = BackboneFeatureStore(mock_storage)
+ comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor)
+
+ # simulate a sent event by writing directly to the input comm channel
+ comm_channel.send(bytes(b""))
+
+ consumer = EventConsumer(comm_channel, backbone)
+
+ messages = consumer.recv()
+
+ # the messages array should be empty
+ assert not messages
+
+
+def test_eventconsumer_eventpublisher_integration(test_dir: str) -> None:
+ """Verify that the publisher and consumer integrate as expected when
+ multiple publishers and consumers are sending simultaneously.
+
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ storage_path.mkdir(parents=True, exist_ok=True)
+
+ mock_storage = {}
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+    mock_fs_descriptor = str(storage_path / "mock-feature-store")
+
+ wmgr_channel = FileSystemCommChannel(storage_path / "test-wmgr")
+ capp_channel = FileSystemCommChannel(storage_path / "test-capp")
+ back_channel = FileSystemCommChannel(storage_path / "test-backend")
+
+ wmgr_consumer_descriptor = wmgr_channel.descriptor
+ capp_consumer_descriptor = capp_channel.descriptor
+ back_consumer_descriptor = back_channel.descriptor
+
+ # create some consumers to receive messages
+ wmgr_consumer = EventConsumer(
+ wmgr_channel,
+ backbone,
+ filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+ )
+ capp_consumer = EventConsumer(
+ capp_channel,
+ backbone,
+ )
+ back_consumer = EventConsumer(
+ back_channel,
+ backbone,
+ filters=[OnCreateConsumer.CONSUMER_CREATED],
+ )
+
+ # create some broadcasters to publish messages
+ mock_worker_mgr = EventBroadcaster(
+ backbone,
+ channel_factory=FileSystemCommChannel.from_descriptor,
+ )
+ mock_client_app = EventBroadcaster(
+ backbone,
+ channel_factory=FileSystemCommChannel.from_descriptor,
+ )
+
+    # register all of the consumers, even though the OnCreateConsumer event
+    # should really trigger registration on its own; event processing is
+    # tested elsewhere
+ backbone.notification_channels = [
+ wmgr_consumer_descriptor,
+ capp_consumer_descriptor,
+ back_consumer_descriptor,
+ ]
+
+ # simulate worker manager sending a notification to backend that it's alive
+ event_1 = OnCreateConsumer(
+ "test_eventconsumer_eventpublisher_integration",
+ wmgr_consumer_descriptor,
+ filters=[],
+ )
+ mock_worker_mgr.send(event_1)
+
+ # simulate the app updating a model a few times
+ event_2 = OnWriteFeatureStore(
+ "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-1"
+ )
+ event_3 = OnWriteFeatureStore(
+ "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-2"
+ )
+ event_4 = OnWriteFeatureStore(
+ "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-1"
+ )
+
+ mock_client_app.send(event_2)
+ mock_client_app.send(event_3)
+ mock_client_app.send(event_4)
+
+    # worker manager should only get updates about feature store writes
+ wmgr_messages = wmgr_consumer.recv()
+ assert len(wmgr_messages) == 3
+
+ # the backend should only receive messages about consumer creation
+ back_messages = back_consumer.recv()
+ assert len(back_messages) == 1
+
+ # hypothetical app has no filters and will get all events
+ app_messages = capp_consumer.recv()
+ assert len(app_messages) == 4
+
+
+@pytest.mark.parametrize("invalid_timeout", [-100.0, -1.0, 0.0])
+def test_eventconsumer_batch_timeout(
+ invalid_timeout: float,
+ test_dir: str,
+) -> None:
+ """Verify that a consumer allows only positive, non-zero values for timeout
+ if it is supplied.
+
+ :param invalid_timeout: any invalid timeout that should fail validation
+ :param test_dir: pytest fixture automatically generating unique working
+ directories for individual test outputs
+ """
+ storage_path = pathlib.Path(test_dir) / "features"
+ storage_path.mkdir(parents=True, exist_ok=True)
+
+ mock_storage = {}
+ backbone = BackboneFeatureStore(mock_storage)
+
+ channel = FileSystemCommChannel(storage_path / "test-wmgr")
+
+ with pytest.raises(ValueError) as ex:
+        # attempt a receive with an invalid (non-positive) batch timeout
+ consumer = EventConsumer(
+ channel,
+ backbone,
+ filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+ )
+ consumer.recv(batch_timeout=invalid_timeout)
+
+ assert "positive" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+ "wait_timeout, exp_wait_max",
+ [
+ # aggregate the 1+1+1 into 3 on remaining parameters
+ pytest.param(1, 1 + 1 + 1, id="1s wait, 3 cycle steps"),
+ pytest.param(2, 3 + 2, id="2s wait, 4 cycle steps"),
+ pytest.param(4, 3 + 2 + 4, id="4s wait, 5 cycle steps"),
+ pytest.param(9, 3 + 2 + 4 + 8, id="9s wait, 6 cycle steps"),
+ # aggregate an entire cycle into 16
+ pytest.param(19.5, 16 + 3 + 2 + 4, id="20s wait, repeat cycle"),
+ ],
+)
+def test_backbone_wait_timeout(wait_timeout: float, exp_wait_max: float) -> None:
+ """Verify that attempts to attach to the worker queue from the protoclient
+ timeout in an appropriate amount of time. Note: due to the backoff, we verify
+ the elapsed time is less than the 15s of a cycle of waits.
+
+ :param wait_timeout: Maximum amount of time (in seconds) to allow the backbone
+ to wait for the requested value to exist
+ :param exp_wait_max: Maximum amount of time (in seconds) to set as the upper
+ bound to allow the delays with backoff to occur
+ :param storage_for_dragon_fs: the dragon storage engine to use
+ """
+
+    # NOTE: exp_wait_max maps to the cycled backoff of [0.1, 0.2, 0.4, 0.8]
+    # with leeway added (roughly 1s per step for the smaller backoff values)
+ start_time = time.time()
+
+ storage = {}
+ backbone = BackboneFeatureStore(storage)
+
+ with pytest.raises(SmartSimError) as ex:
+ backbone.wait_for(["does-not-exist"], wait_timeout)
+
+ assert "timeout" in str(ex.value.args[0]).lower()
+
+ end_time = time.time()
+ elapsed = end_time - start_time
+
+ # confirm that we met our timeout
+ assert elapsed > wait_timeout, f"below configured timeout {wait_timeout}"
+
+ # confirm that the total wait time is aligned with the sleep cycle
+ assert elapsed < exp_wait_max, f"above expected max wait {exp_wait_max}"
diff --git a/tests/dragon_wlm/test_featurestore_integration.py b/tests/dragon_wlm/test_featurestore_integration.py
new file mode 100644
index 0000000000..23fdc55ab6
--- /dev/null
+++ b/tests/dragon_wlm/test_featurestore_integration.py
@@ -0,0 +1,213 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_util import (
+ DEFAULT_CHANNEL_BUFFER_SIZE,
+ create_local,
+)
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
+from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+
+# isort: off
+from dragon.channels import Channel
+
+# isort: on
+
+if t.TYPE_CHECKING:
+ import conftest
+
+
+# The tests in this file must run in a dragon environment
+pytestmark = pytest.mark.dragon
+
+
+@pytest.fixture(scope="module")
+def the_worker_channel() -> DragonCommChannel:
+ """Fixture to create a valid descriptor for a worker channel
+ that can be attached to."""
+ wmgr_channel_ = create_local()
+ wmgr_channel = DragonCommChannel(wmgr_channel_)
+ return wmgr_channel
+
+
+@pytest.mark.parametrize(
+ "num_events, batch_timeout, max_batches_expected",
+ [
+ pytest.param(1, 1.0, 2, id="under 1s timeout"),
+ pytest.param(20, 1.0, 3, id="test 1s timeout 20x"),
+ pytest.param(30, 0.2, 5, id="test 0.2s timeout 30x"),
+ pytest.param(60, 0.4, 4, id="small batches"),
+ pytest.param(100, 0.1, 10, id="many small batches"),
+ ],
+)
+def test_eventconsumer_max_dequeue(
+ num_events: int,
+ batch_timeout: float,
+ max_batches_expected: int,
+ the_worker_channel: DragonCommChannel,
+ the_backbone: BackboneFeatureStore,
+) -> None:
+ """Verify that a consumer does not sit and collect messages indefinitely
+ by checking that a consumer returns after a maximum timeout is exceeded.
+
+ :param num_events: Total number of events to raise in the test
+    :param batch_timeout: Maximum time (in seconds) to spend collecting a single
+    batch of messages
+    :param max_batches_expected: Maximum number of receives that should occur
+    :param the_worker_channel: Fixture providing the channel the consumer listens on
+    :param the_backbone: Fixture providing the backbone feature store to use
+ """
+
+ # create some consumers to receive messages
+ wmgr_consumer = EventConsumer(
+ the_worker_channel,
+ the_backbone,
+ filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+ )
+
+ # create a broadcaster to publish messages
+ mock_client_app = EventBroadcaster(
+ the_backbone,
+ channel_factory=DragonCommChannel.from_descriptor,
+ )
+
+    # register all of the consumers manually; an OnCreateConsumer event would
+    # normally trigger registration, but event processing is tested elsewhere
+ the_backbone.notification_channels = [the_worker_channel.descriptor]
+
+ # simulate the app updating a model a lot of times
+ for key in (f"key-{i}" for i in range(num_events)):
+ event = OnWriteFeatureStore(
+ "test_eventconsumer_max_dequeue", the_backbone.descriptor, key
+ )
+ mock_client_app.send(event, timeout=0.01)
+
+ num_dequeued = 0
+ num_batches = 0
+
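+    # drain the consumer in batches; each recv call returns the messages
+    # collected within a single batch_timeout window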
+ while wmgr_messages := wmgr_consumer.recv(
+ timeout=0.1,
+ batch_timeout=batch_timeout,
+ ):
+        # track total events dequeued and the number of recv batches required
+ num_dequeued += len(wmgr_messages)
+ num_batches += 1
+
+ # make sure we made all the expected dequeue calls and got everything
+ assert num_dequeued == num_events
+ assert num_batches > 0
+ assert num_batches < max_batches_expected, "too many recv calls were made"
+
+
+@pytest.mark.parametrize(
+ "buffer_size",
+ [
+ pytest.param(
+ -1,
+ id="replace negative, default to 500",
+ marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+ ),
+ pytest.param(
+ 0,
+ id="replace zero, default to 500",
+ marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+ ),
+ pytest.param(
+ 1,
+ id="non-zero buffer size: 1",
+ marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+ ),
+ # pytest.param(500, id="maximum size edge case: 500"),
+ pytest.param(
+ 550,
+ id="larger than default: 550",
+ marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+ ),
+ pytest.param(
+ 800,
+ id="much larger then default: 800",
+ marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+ ),
+ pytest.param(
+ 1000,
+ id="very large buffer: 1000, unreliable in dragon-v0.10",
+ marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+ ),
+ ],
+)
+def test_channel_buffer_size(
+ buffer_size: int,
+ the_storage: t.Any,
+) -> None:
+ """Verify that a channel used by an EventBroadcaster can buffer messages
+ until a configured maximum value is exceeded.
+
+ :param buffer_size: Maximum number of messages allowed in a channel buffer
+ :param the_storage: The dragon storage engine to use
+ """
+
+ mock_storage = the_storage
+ backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+
+ wmgr_channel_ = create_local(buffer_size) # <--- vary buffer size
+ wmgr_channel = DragonCommChannel(wmgr_channel_)
+ wmgr_consumer_descriptor = wmgr_channel.descriptor
+
+ # create a broadcaster to publish messages. create no consumers to
+ # push the number of sent messages past the allotted buffer size
+ mock_client_app = EventBroadcaster(
+ backbone,
+ channel_factory=DragonCommChannel.from_descriptor,
+ )
+
+    # register all of the consumers manually; an OnCreateConsumer event would
+    # normally trigger registration, but event processing is tested elsewhere
+ backbone.notification_channels = [wmgr_consumer_descriptor]
+
+ if buffer_size < 1:
+ # NOTE: we set this after creating the channel above to ensure
+ # the default parameter value was used during instantiation
+ buffer_size = DEFAULT_CHANNEL_BUFFER_SIZE
+
+ # simulate the app updating a model a lot of times
+ for key in (f"key-{i}" for i in range(buffer_size)):
+ event = OnWriteFeatureStore(
+ "test_channel_buffer_size", backbone.descriptor, key
+ )
+ mock_client_app.send(event, timeout=0.01)
+
+    # sending 1 more message past the configured buffer size should raise an error
+ with pytest.raises(Exception) as ex:
+ mock_client_app.send(event, timeout=0.01)
diff --git a/tests/dragon_wlm/test_inference_reply.py b/tests/dragon_wlm/test_inference_reply.py
new file mode 100644
index 0000000000..bdc7be14bc
--- /dev/null
+++ b/tests/dragon_wlm/test_inference_reply.py
@@ -0,0 +1,76 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.infrastructure.storage.feature_store import TensorKey
+from smartsim._core.mli.infrastructure.worker.worker import InferenceReply
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+handler = MessageHandler()
+
+
+@pytest.fixture
+def inference_reply() -> InferenceReply:
+ return InferenceReply()
+
+
+@pytest.fixture
+def fs_key() -> TensorKey:
+ return TensorKey("key", "descriptor")
+
+
+@pytest.mark.parametrize(
+ "outputs, expected",
+ [
+ ([b"output bytes"], True),
+ (None, False),
+ ([], False),
+ ],
+)
+def test_has_outputs(monkeypatch, inference_reply, outputs, expected):
+ """Test the has_outputs property with different values for outputs."""
+ monkeypatch.setattr(inference_reply, "outputs", outputs)
+ assert inference_reply.has_outputs == expected
+
+
+@pytest.mark.parametrize(
+ "output_keys, expected",
+ [
+ ([fs_key], True),
+ (None, False),
+ ([], False),
+ ],
+)
+def test_has_output_keys(monkeypatch, inference_reply, output_keys, expected):
+ """Test the has_output_keys property with different values for output_keys."""
+ monkeypatch.setattr(inference_reply, "output_keys", output_keys)
+ assert inference_reply.has_output_keys == expected
diff --git a/tests/dragon_wlm/test_inference_request.py b/tests/dragon_wlm/test_inference_request.py
new file mode 100644
index 0000000000..f5c8b9bdc7
--- /dev/null
+++ b/tests/dragon_wlm/test_inference_request.py
@@ -0,0 +1,118 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.infrastructure.storage.feature_store import TensorKey
+from smartsim._core.mli.infrastructure.worker.worker import InferenceRequest
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+handler = MessageHandler()
+
+
+@pytest.fixture
+def inference_request() -> InferenceRequest:
+ return InferenceRequest()
+
+
+@pytest.fixture
+def fs_key() -> TensorKey:
+ return TensorKey("key", "descriptor")
+
+
+@pytest.mark.parametrize(
+ "raw_model, expected",
+ [
+ (handler.build_model(b"bytes", "Model Name", "V1"), True),
+ (None, False),
+ ],
+)
+def test_has_raw_model(monkeypatch, inference_request, raw_model, expected):
+ """Test the has_raw_model property with different values for raw_model."""
+ monkeypatch.setattr(inference_request, "raw_model", raw_model)
+ assert inference_request.has_raw_model == expected
+
+
+@pytest.mark.parametrize(
+ "model_key, expected",
+ [
+ (fs_key, True),
+ (None, False),
+ ],
+)
+def test_has_model_key(monkeypatch, inference_request, model_key, expected):
+ """Test the has_model_key property with different values for model_key."""
+ monkeypatch.setattr(inference_request, "model_key", model_key)
+ assert inference_request.has_model_key == expected
+
+
+@pytest.mark.parametrize(
+ "raw_inputs, expected",
+ [([b"raw input bytes"], True), (None, False), ([], False)],
+)
+def test_has_raw_inputs(monkeypatch, inference_request, raw_inputs, expected):
+ """Test the has_raw_inputs property with different values for raw_inputs."""
+ monkeypatch.setattr(inference_request, "raw_inputs", raw_inputs)
+ assert inference_request.has_raw_inputs == expected
+
+
+@pytest.mark.parametrize(
+ "input_keys, expected",
+ [([fs_key], True), (None, False), ([], False)],
+)
+def test_has_input_keys(monkeypatch, inference_request, input_keys, expected):
+ """Test the has_input_keys property with different values for input_keys."""
+ monkeypatch.setattr(inference_request, "input_keys", input_keys)
+ assert inference_request.has_input_keys == expected
+
+
+@pytest.mark.parametrize(
+ "output_keys, expected",
+ [([fs_key], True), (None, False), ([], False)],
+)
+def test_has_output_keys(monkeypatch, inference_request, output_keys, expected):
+ """Test the has_output_keys property with different values for output_keys."""
+ monkeypatch.setattr(inference_request, "output_keys", output_keys)
+ assert inference_request.has_output_keys == expected
+
+
+@pytest.mark.parametrize(
+ "input_meta, expected",
+ [
+ ([handler.build_tensor_descriptor("c", "float32", [1, 2, 3])], True),
+ (None, False),
+ ([], False),
+ ],
+)
+def test_has_input_meta(monkeypatch, inference_request, input_meta, expected):
+ """Test the has_input_meta property with different values for input_meta."""
+ monkeypatch.setattr(inference_request, "input_meta", input_meta)
+ assert inference_request.has_input_meta == expected
diff --git a/tests/dragon_wlm/test_protoclient.py b/tests/dragon_wlm/test_protoclient.py
new file mode 100644
index 0000000000..f84417107d
--- /dev/null
+++ b/tests/dragon_wlm/test_protoclient.py
@@ -0,0 +1,313 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import pickle
+import time
+import typing as t
+from unittest.mock import MagicMock
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+# isort: off
+from dragon import fli
+from dragon.data.ddict.ddict import DDict
+
+# from ..ex..high_throughput_inference.mock_app import ProtoClient
+from smartsim._core.mli.client.protoclient import ProtoClient
+
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+WORK_QUEUE_KEY = BackboneFeatureStore.MLI_WORKER_QUEUE
+logger = get_logger(__name__)
+
+
+@pytest.fixture(scope="module")
+def the_worker_queue(the_backbone: BackboneFeatureStore) -> DragonFLIChannel:
+ """Fixture that creates a dragon FLI channel as a stand-in for the
+ worker queue created by the worker.
+
+ :param the_backbone: The backbone feature store to update
+ with the worker queue descriptor.
+ :returns: The attached `DragonFLIChannel`
+ """
+
+ # create the FLI
+ to_worker_channel = create_local()
+ fli_ = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None)
+ comm_channel = DragonFLIChannel(fli_)
+
+ # store the descriptor in the backbone
+ the_backbone.worker_queue = comm_channel.descriptor
+
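+    # smoke-test the channel; a failure here is logged but does not fail the fixture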
+ try:
+ comm_channel.send(b"foo")
+ except Exception as ex:
+        logger.exception("Test send from worker channel failed")
+
+ return comm_channel
+
+
+@pytest.mark.parametrize(
+ "backbone_timeout, exp_wait_max",
+ [
+        # the initial 1+1+1 allowance is aggregated into 3 for the remaining parameters
+ pytest.param(0.5, 1 + 1 + 1, id="0.5s wait, 3 cycle steps"),
+ pytest.param(2, 3 + 2, id="2s wait, 4 cycle steps"),
+ pytest.param(4, 3 + 2 + 4, id="4s wait, 5 cycle steps"),
+ ],
+)
+def test_protoclient_timeout(
+ backbone_timeout: float,
+ exp_wait_max: float,
+ the_backbone: BackboneFeatureStore,
+ monkeypatch: pytest.MonkeyPatch,
+):
+ """Verify that attempts to attach to the worker queue from the protoclient
+ timeout in an appropriate amount of time. Note: due to the backoff, we verify
+ the elapsed time is less than the 15s of a cycle of waits.
+
+ :param backbone_timeout: a timeout for use when configuring a proto client
+ :param exp_wait_max: a ceiling for the expected time spent waiting for
+ the timeout
+ :param the_backbone: a pre-initialized backbone featurestore for setting up
+ the environment variable required by the client
+ """
+
+    # NOTE: exp_wait_max maps to the cycled backoff of [0.1, 0.2, 0.4, 0.8]
+    # with leeway added (roughly 1s per step for the smaller backoff values)
+
+ with monkeypatch.context() as ctx, pytest.raises(SmartSimError) as ex:
+ start_time = time.time()
+ # remove the worker queue value from the backbone if it exists
+ # to ensure the timeout occurs
+ the_backbone.pop(BackboneFeatureStore.MLI_WORKER_QUEUE)
+
+ ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+
+ ProtoClient(timing_on=False, backbone_timeout=backbone_timeout)
+ elapsed = time.time() - start_time
+ logger.info(f"ProtoClient timeout occurred in {elapsed} seconds")
+
+ # confirm that we met our timeout
+ assert (
+ elapsed >= backbone_timeout
+ ), f"below configured timeout {backbone_timeout}"
+
+ # confirm that the total wait time is aligned with the sleep cycle
+ assert elapsed < exp_wait_max, f"above expected max wait {exp_wait_max}"
+
+
+def test_protoclient_initialization_no_backbone(
+ monkeypatch: pytest.MonkeyPatch, the_worker_queue: DragonFLIChannel
+):
+ """Verify that attempting to start the client without required environment variables
+ results in an exception.
+
+    :param the_worker_queue: The worker queue fixture, included to ensure
+    the worker queue environment is correctly configured.
+
+ NOTE: os.environ[BackboneFeatureStore.MLI_BACKBONE] is not set"""
+
+ with monkeypatch.context() as patch, pytest.raises(SmartSimError) as ex:
+ patch.setenv(BackboneFeatureStore.MLI_BACKBONE, "")
+
+ ProtoClient(timing_on=False)
+
+ # confirm the missing value error has been raised
+ assert {"backbone", "configuration"}.issubset(set(ex.value.args[0].split(" ")))
+
+
+def test_protoclient_initialization(
+ the_backbone: BackboneFeatureStore,
+ the_worker_queue: DragonFLIChannel,
+ monkeypatch: pytest.MonkeyPatch,
+):
+ """Verify that attempting to start the client with required env vars results
+ in a fully initialized client.
+
+ :param the_backbone: a pre-initialized backbone featurestore
+ :param the_worker_queue: an FLI channel the client will retrieve
+ from the backbone"""
+
+ with monkeypatch.context() as ctx:
+ ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+ # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+ client = ProtoClient(timing_on=False)
+
+ fs_descriptor = the_backbone.descriptor
+ wq_descriptor = the_worker_queue.descriptor
+
+ # confirm the backbone was attached correctly
+ assert client._backbone is not None
+ assert client._backbone.descriptor == fs_descriptor
+
+ # we expect the backbone to add its descriptor to the local env
+ assert os.environ[BackboneFeatureStore.MLI_BACKBONE] == fs_descriptor
+
+ # confirm the worker queue is created and attached correctly
+ assert client._to_worker_fli is not None
+ assert client._to_worker_fli.descriptor == wq_descriptor
+
+ # we expect the worker queue descriptor to be placed into the backbone
+ # we do NOT expect _from_worker_ch to be placed anywhere. it's a specific callback
+ assert the_backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] == wq_descriptor
+
+ # confirm the worker channels are created
+ assert client._from_worker_ch is not None
+ assert client._to_worker_ch is not None
+
+    # wrap the channels to verify that each produces a descriptor
+ assert DragonCommChannel(client._from_worker_ch).descriptor
+ assert DragonCommChannel(client._to_worker_ch).descriptor
+
+ # confirm a publisher is created
+ assert client._publisher is not None
+
+
+def test_protoclient_write_model(
+ the_backbone: BackboneFeatureStore,
+ the_worker_queue: DragonFLIChannel,
+ monkeypatch: pytest.MonkeyPatch,
+):
+ """Verify that writing a model using the client causes the model data to be
+ written to a feature store.
+
+ :param the_backbone: a pre-initialized backbone featurestore
+    :param the_worker_queue: The worker queue fixture, included to ensure
+    the worker queue environment is correctly configured
+ """
+
+ with monkeypatch.context() as ctx:
+ ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+ # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+ client = ProtoClient(timing_on=False)
+
+ model_key = "my-model"
+ model_bytes = b"12345"
+
+ client.set_model(model_key, model_bytes)
+
+ # confirm the client modified the underlying feature store
+ assert client._backbone[model_key] == model_bytes
+
+
+@pytest.mark.parametrize(
+ "num_listeners, num_model_updates",
+ [(1, 1), (1, 4), (2, 4), (16, 4), (64, 8)],
+)
+def test_protoclient_write_model_notification_sent(
+ the_backbone: BackboneFeatureStore,
+ the_worker_queue: DragonFLIChannel,
+ monkeypatch: pytest.MonkeyPatch,
+ num_listeners: int,
+ num_model_updates: int,
+):
+ """Verify that writing a model sends a key-written event.
+
+ :param the_backbone: a pre-initialized backbone featurestore
+ :param the_worker_queue: an FLI channel the client will retrieve
+ from the backbone
+ :param num_listeners: vary the number of registered listeners
+ to verify that the event is broadcast to everyone
+    :param num_model_updates: vary the number of model updates written
+    to verify the broadcast counts messages sent correctly
+ """
+
+    # sends are mocked below, but the broadcaster will not attempt a send
+    # unless listeners are registered, so provide mock listener descriptors
+ listeners = [f"mock-ch-desc-{i}" for i in range(num_listeners)]
+
+ the_backbone[BackboneFeatureStore.MLI_BACKBONE] = the_backbone.descriptor
+ the_backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] = the_worker_queue.descriptor
+ the_backbone[BackboneFeatureStore.MLI_NOTIFY_CONSUMERS] = ",".join(listeners)
+ the_backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = None
+
+ with monkeypatch.context() as ctx:
+ ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+ # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+ client = ProtoClient(timing_on=False)
+
+ publisher = t.cast(EventBroadcaster, client._publisher)
+
+ # mock attaching to a channel given the mock-ch-desc in backbone
+ mock_send = MagicMock(return_value=None)
+ mock_comm_channel = MagicMock(**{"send": mock_send}, spec=DragonCommChannel)
+ mock_get_comm_channel = MagicMock(return_value=mock_comm_channel)
+ ctx.setattr(publisher, "_get_comm_channel", mock_get_comm_channel)
+
+ model_key = "my-model"
+ model_bytes = b"12345"
+
+ for i in range(num_model_updates):
+ client.set_model(model_key, model_bytes)
+
+ # confirm that a listener channel was attached
+ # once for each registered listener in backbone
+ assert mock_get_comm_channel.call_count == num_listeners * num_model_updates
+
+ # confirm the client raised the key-written event
+ assert (
+ mock_send.call_count == num_listeners * num_model_updates
+ ), f"Expected {num_listeners} sends with {num_listeners} registrations"
+
+ # with at least 1 consumer registered, we can verify the message is sent
+ for call_args in mock_send.call_args_list:
+ send_args = call_args.args
+ event_bytes, timeout = send_args[0], send_args[1]
+
+ assert event_bytes, "Expected event bytes to be supplied to send"
+ assert (
+ timeout == 0.001
+ ), "Expected default timeout on call to `publisher.send`, "
+
+ # confirm the correct event was raised
+ event = t.cast(
+ OnWriteFeatureStore,
+ pickle.loads(event_bytes),
+ )
+ assert event.descriptor == the_backbone.descriptor
+ assert event.key == model_key
diff --git a/tests/dragon_wlm/test_reply_building.py b/tests/dragon_wlm/test_reply_building.py
new file mode 100644
index 0000000000..1b0074ca0e
--- /dev/null
+++ b/tests/dragon_wlm/test_reply_building.py
@@ -0,0 +1,88 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.infrastructure.control.worker_manager import build_failure_reply
+
+if t.TYPE_CHECKING:
+ from smartsim._core.mli.mli_schemas.response.response_capnp import Status
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+@pytest.mark.parametrize(
+ "status, message",
+ [
+ pytest.param("timeout", "Worker timed out", id="timeout"),
+ pytest.param("fail", "Failed while executing", id="fail"),
+ ],
+)
+def test_build_failure_reply(status: "Status", message: str):
+ "Ensures failure replies can be built successfully"
+ response = build_failure_reply(status, message)
+ display_name = response.schema.node.displayName # type: ignore
+ class_name = display_name.split(":")[-1]
+ assert class_name == "Response"
+ assert response.status == status
+ assert response.message == message
+
+
+def test_build_failure_reply_fails():
+ "Ensures ValueError is raised if a Status Enum is not used"
+ with pytest.raises(ValueError) as ex:
+ response = build_failure_reply("not a status enum", "message")
+
+ assert "Error assigning status to response" in ex.value.args[0]
diff --git a/tests/dragon_wlm/test_request_dispatcher.py b/tests/dragon_wlm/test_request_dispatcher.py
new file mode 100644
index 0000000000..8dc0f67a31
--- /dev/null
+++ b/tests/dragon_wlm/test_request_dispatcher.py
@@ -0,0 +1,237 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import gc
+import os
+import time
+import typing as t
+from queue import Empty
+
+import numpy as np
+import pytest
+
+pytest.importorskip("dragon")
+
+
+# isort: off
+import dragon
+
+from dragon.fli import FLInterface
+from dragon.data.ddict.ddict import DDict
+from dragon.managed_memory import MemoryAlloc
+
+import multiprocessing as mp
+
+import torch
+
+# isort: on
+
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.control.dragon_util import (
+ function_as_dragon_proc,
+)
+from smartsim._core.mli.infrastructure.control.request_dispatcher import (
+ RequestBatch,
+ RequestDispatcher,
+)
+from smartsim._core.mli.infrastructure.control.worker_manager import (
+ EnvironmentConfigLoader,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
+ DragonFeatureStore,
+)
+from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
+from smartsim.log import get_logger
+
+from .utils.msg_pump import mock_messages
+
+logger = get_logger(__name__)
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+try:
+ mp.set_start_method("dragon")
+except Exception:
+ pass
+
+
+@pytest.mark.skip("TODO: Fix issue unpickling messages")
+@pytest.mark.parametrize("num_iterations", [4])
+def test_request_dispatcher(
+ num_iterations: int,
+ the_storage: DDict,
+ test_dir: str,
+) -> None:
+ """Test the request dispatcher batching and queueing system
+
+ This also includes setting a queue to disposable, checking that it is no
+ longer referenced by the dispatcher.
+ """
+
+ to_worker_channel = create_local()
+ to_worker_fli = FLInterface(main_ch=to_worker_channel, manager_ch=None)
+ to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli)
+
+ backbone_fs = BackboneFeatureStore(the_storage, allow_reserved_writes=True)
+
+ # NOTE: env vars should be set prior to instantiating EnvironmentConfigLoader
+ # or test environment may be unable to send messages w/queue
+ os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor
+ os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone_fs.descriptor
+
+ config_loader = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=DragonCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+
+ request_dispatcher = RequestDispatcher(
+ batch_timeout=1000,
+ batch_size=2,
+ config_loader=config_loader,
+ worker_type=TorchWorker,
+ mem_pool_size=2 * 1024**2,
+ )
+
+ worker_queue = config_loader.get_queue()
+ if worker_queue is None:
+ logger.warning(
+ "FLI input queue not loaded correctly from config_loader: "
+ f"{config_loader._queue_descriptor}"
+ )
+
+ request_dispatcher._on_start()
+
+ # put some messages into the work queue for the dispatcher to pickup
+ channels = []
+ processes = []
+ for i in range(num_iterations):
+ batch: t.Optional[RequestBatch] = None
+ mem_allocs = []
+ tensors = []
+
+        # NOTE: create callbacks in the test to avoid a local channel being torn
+        # down when mock_messages terminates before the final response message is sent
+
+ callback_channel = DragonCommChannel.from_local()
+ channels.append(callback_channel)
+
+ process = function_as_dragon_proc(
+ mock_messages,
+ [
+ worker_queue.descriptor,
+ backbone_fs.descriptor,
+ i,
+ callback_channel.descriptor,
+ ],
+ [],
+ [],
+ )
+ processes.append(process)
+ process.start()
+ assert process.returncode is None, "The message pump failed to start"
+
+ # give dragon some time to populate the message queues
+ for i in range(15):
+ try:
+ request_dispatcher._on_iteration()
+ batch = request_dispatcher.task_queue.get(timeout=1.0)
+ break
+ except Empty:
+ time.sleep(2)
+ logger.warning(f"Task queue is empty on iteration {i}")
+ continue
+ except Exception as exc:
+ logger.error(f"Task queue exception on iteration {i}")
+ raise exc
+
+ assert batch is not None
+ assert batch.has_valid_requests
+
+ model_key = batch.model_id.key
+
+ try:
+ transform_result = batch.inputs
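+        # each transformed input is a dragon managed-memory allocation; attach
+        # to it and reinterpret the raw buffer to rebuild the torch tensor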
+ for transformed, dims, dtype in zip(
+ transform_result.transformed,
+ transform_result.dims,
+ transform_result.dtypes,
+ ):
+ mem_alloc = MemoryAlloc.attach(transformed)
+ mem_allocs.append(mem_alloc)
+ itemsize = np.empty((1), dtype=dtype).itemsize
+ tensors.append(
+ torch.from_numpy(
+ np.frombuffer(
+ mem_alloc.get_memview()[0 : np.prod(dims) * itemsize],
+ dtype=dtype,
+ ).reshape(dims)
+ )
+ )
+
+ assert len(batch.requests) == 2
+ assert batch.model_id.key == model_key
+ assert model_key in request_dispatcher._queues
+ assert model_key in request_dispatcher._active_queues
+ assert len(request_dispatcher._queues[model_key]) == 1
+ assert request_dispatcher._queues[model_key][0].empty()
+ assert request_dispatcher._queues[model_key][0].model_id.key == model_key
+ assert len(tensors) == 1
+ assert tensors[0].shape == torch.Size([2, 2])
+
+ for tensor in tensors:
+ for sample_idx in range(tensor.shape[0]):
+ tensor_in = tensor[sample_idx]
+ tensor_out = (sample_idx + 1) * torch.ones(
+ (2,), dtype=torch.float32
+ )
+ assert torch.equal(tensor_in, tensor_out)
+
+ except Exception as exc:
+ raise exc
+ finally:
+ for mem_alloc in mem_allocs:
+ mem_alloc.free()
+
+ request_dispatcher._active_queues[model_key].make_disposable()
+ assert request_dispatcher._active_queues[model_key].can_be_removed
+
+ request_dispatcher._on_iteration()
+
+ assert model_key not in request_dispatcher._active_queues
+ assert model_key not in request_dispatcher._queues
+
+ # Try to remove the dispatcher and free the memory
+ del request_dispatcher
+ gc.collect()
diff --git a/tests/dragon_wlm/test_torch_worker.py b/tests/dragon_wlm/test_torch_worker.py
new file mode 100644
index 0000000000..2a9e7d01bd
--- /dev/null
+++ b/tests/dragon_wlm/test_torch_worker.py
@@ -0,0 +1,221 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import io
+import typing as t
+
+import numpy as np
+import pytest
+import torch
+
+dragon = pytest.importorskip("dragon")
+import dragon.globalservices.pool as dragon_gs_pool
+from dragon.managed_memory import MemoryAlloc, MemoryPool
+from torch import nn
+from torch.nn import functional as F
+
+from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey
+from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
+from smartsim._core.mli.infrastructure.worker.worker import (
+ ExecuteResult,
+ FetchInputResult,
+ FetchModelResult,
+ InferenceRequest,
+ LoadModelResult,
+ RequestBatch,
+ TransformInputResult,
+)
+from smartsim._core.mli.message_handler import MessageHandler
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+# simple MNIST in PyTorch
+class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.conv1 = nn.Conv2d(1, 32, 3, 1)
+ self.conv2 = nn.Conv2d(32, 64, 3, 1)
+ self.dropout1 = nn.Dropout(0.25)
+ self.dropout2 = nn.Dropout(0.5)
+ self.fc1 = nn.Linear(9216, 128)
+ self.fc2 = nn.Linear(128, 10)
+
+ def forward(self, x, y):
+ x = self.conv1(x)
+ x = F.relu(x)
+ x = self.conv2(x)
+ x = F.relu(x)
+ x = F.max_pool2d(x, 2)
+ x = self.dropout1(x)
+ x = torch.flatten(x, 1)
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ output = F.log_softmax(x, dim=1)
+ return output
+
+
+torch_device = {"cpu": "cpu", "gpu": "cuda"}
+
+
+def get_batch() -> torch.Tensor:
+ return torch.rand(20, 1, 28, 28)
+
+
+def create_torch_model():
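+    # trace the network with example inputs and serialize the TorchScript module to bytes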
+ n = Net()
+ example_forward_input = get_batch()
+ module = torch.jit.trace(n, [example_forward_input, example_forward_input])
+ model_buffer = io.BytesIO()
+ torch.jit.save(module, model_buffer)
+ return model_buffer.getvalue()
+
+
+def get_request() -> InferenceRequest:
+
+ tensors = [get_batch() for _ in range(2)]
+ tensor_numpy = [tensor.numpy() for tensor in tensors]
+ serialized_tensors_descriptors = [
+ MessageHandler.build_tensor_descriptor("c", "float32", list(tensor.shape))
+ for tensor in tensors
+ ]
+
+ return InferenceRequest(
+ model_key=ModelKey(key="model", descriptor="xyz"),
+ callback=None,
+ raw_inputs=tensor_numpy,
+ input_keys=None,
+ input_meta=serialized_tensors_descriptors,
+ output_keys=None,
+ raw_model=create_torch_model(),
+ batch_size=0,
+ )
+
+
+def get_request_batch_from_request(
+ request: InferenceRequest, inputs: t.Optional[TransformInputResult] = None
+) -> RequestBatch:
+
+ return RequestBatch([request], inputs, request.model_key)
+
+
+sample_request: InferenceRequest = get_request()
+sample_request_batch: RequestBatch = get_request_batch_from_request(sample_request)
+worker = TorchWorker()
+
+
+def test_load_model(mlutils) -> None:
+ fetch_model_result = FetchModelResult(sample_request.raw_model)
+ load_model_result = worker.load_model(
+ sample_request_batch, fetch_model_result, mlutils.get_test_device().lower()
+ )
+
+ assert load_model_result.model(
+ get_batch().to(torch_device[mlutils.get_test_device().lower()]),
+ get_batch().to(torch_device[mlutils.get_test_device().lower()]),
+ ).shape == torch.Size((20, 10))
+
+
+def test_transform_input(mlutils) -> None:
+ fetch_input_result = FetchInputResult(
+ sample_request.raw_inputs, sample_request.input_meta
+ )
+
+ mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc)
+
+ transform_input_result = worker.transform_input(
+ sample_request_batch, [fetch_input_result], mem_pool
+ )
+
+ batch = get_batch().numpy()
+ assert transform_input_result.slices[0] == slice(0, batch.shape[0])
+
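+    # verify each input round-trips through the dragon memory pool unchanged by
+    # attaching to the allocation and rebuilding the tensor from the raw buffer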
+ for tensor_index in range(2):
+ assert torch.Size(transform_input_result.dims[tensor_index]) == batch.shape
+ assert transform_input_result.dtypes[tensor_index] == str(batch.dtype)
+ mem_alloc = MemoryAlloc.attach(transform_input_result.transformed[tensor_index])
+ itemsize = batch.itemsize
+ tensor = torch.from_numpy(
+ np.frombuffer(
+ mem_alloc.get_memview()[
+ 0 : np.prod(transform_input_result.dims[tensor_index]) * itemsize
+ ],
+ dtype=transform_input_result.dtypes[tensor_index],
+ ).reshape(transform_input_result.dims[tensor_index])
+ )
+
+ assert torch.equal(
+ tensor, torch.from_numpy(sample_request.raw_inputs[tensor_index])
+ )
+
+ mem_pool.destroy()
+
+
+def test_execute(mlutils) -> None:
+ load_model_result = LoadModelResult(
+ Net().to(torch_device[mlutils.get_test_device().lower()])
+ )
+ fetch_input_result = FetchInputResult(
+ sample_request.raw_inputs, sample_request.input_meta
+ )
+
+ request_batch = get_request_batch_from_request(sample_request, fetch_input_result)
+
+ mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc)
+
+ transform_result = worker.transform_input(
+ request_batch, [fetch_input_result], mem_pool
+ )
+
+ execute_result = worker.execute(
+ request_batch,
+ load_model_result,
+ transform_result,
+ mlutils.get_test_device().lower(),
+ )
+
+ assert all(
+ result.shape == torch.Size((20, 10)) for result in execute_result.predictions
+ )
+
+ mem_pool.destroy()
+
+
+def test_transform_output(mlutils):
+ tensors = [torch.rand((20, 10)) for _ in range(2)]
+ execute_result = ExecuteResult(tensors, [slice(0, 20)])
+
+ transformed_output = worker.transform_output(sample_request_batch, execute_result)
+
+ assert transformed_output[0].outputs == [item.numpy().tobytes() for item in tensors]
+    assert transformed_output[0].shape is None
+ assert transformed_output[0].order == "c"
+ assert transformed_output[0].dtype == "float32"
diff --git a/tests/dragon_wlm/test_worker_manager.py b/tests/dragon_wlm/test_worker_manager.py
new file mode 100644
index 0000000000..20370bea7e
--- /dev/null
+++ b/tests/dragon_wlm/test_worker_manager.py
@@ -0,0 +1,313 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import io
+import logging
+import pathlib
+import time
+
+import pytest
+
+torch = pytest.importorskip("torch")
+dragon = pytest.importorskip("dragon")
+
+import multiprocessing as mp
+
+try:
+ mp.set_start_method("dragon")
+except Exception:
+ pass
+
+import os
+
+import torch.nn as nn
+from dragon import fli
+
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.control.worker_manager import (
+ EnvironmentConfigLoader,
+ WorkerManager,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
+ DragonFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.dragon_util import create_ddict
+from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
+from smartsim._core.mli.message_handler import MessageHandler
+from smartsim.log import get_logger
+
+from .utils.channel import FileSystemCommChannel
+
+logger = get_logger(__name__)
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+class MiniModel(nn.Module):
+ """A torch model that can be executed by the default torch worker"""
+
+ def __init__(self):
+ """Initialize the model."""
+ super().__init__()
+
+ self._name = "mini-model"
+ self._net = torch.nn.Linear(2, 1)
+
+ def forward(self, input):
+ """Execute a forward pass."""
+ return self._net(input)
+
+ @property
+ def bytes(self) -> bytes:
+ """Retrieve the serialized model
+
+ :returns: The byte stream of the model file
+ """
+ buffer = io.BytesIO()
+ scripted = torch.jit.trace(self._net, self.get_batch())
+ torch.jit.save(scripted, buffer)
+ return buffer.getvalue()
+
+ @classmethod
+ def get_batch(cls) -> "torch.Tensor":
+ """Generate a single batch of data with the correct
+ shape for inference.
+
+ :returns: The batch as a torch tensor
+ """
+ return torch.randn((100, 2), dtype=torch.float32)
+
+
+def create_model(model_path: pathlib.Path) -> pathlib.Path:
+ """Create a simple torch model and persist to disk for
+ testing purposes.
+
+ :param model_path: The path to the torch model file
+ """
+ if not model_path.parent.exists():
+ model_path.parent.mkdir(parents=True, exist_ok=True)
+
+ model_path.unlink(missing_ok=True)
+
+ mini_model = MiniModel()
+ torch.save(mini_model, model_path)
+
+ return model_path
+
+
+def load_model() -> bytes:
+ """Create a simple torch model in memory for testing."""
+ mini_model = MiniModel()
+ return mini_model.bytes
+
+
+def mock_messages(
+ feature_store_root_dir: pathlib.Path,
+ comm_channel_root_dir: pathlib.Path,
+ kill_queue: mp.Queue,
+) -> None:
+ """Mock event producer for triggering the inference pipeline.
+
+ :param feature_store_root_dir: Path to a directory where a
+ FileSystemFeatureStore can read & write results
+ :param comm_channel_root_dir: Path to a directory where a
+ FileSystemCommChannel can read & write messages
+ :param kill_queue: Queue used by unit test to stop mock_message process
+ """
+ feature_store_root_dir.mkdir(parents=True, exist_ok=True)
+ comm_channel_root_dir.mkdir(parents=True, exist_ok=True)
+
+ iteration_number = 0
+
+ config_loader = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=FileSystemCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+ backbone = config_loader.get_backbone()
+
+ worker_queue = config_loader.get_queue()
+ if worker_queue is None:
+ queue_desc = config_loader._queue_descriptor
+        logger.warning(
+ f"FLI input queue not loaded correctly from config_loader: {queue_desc}"
+ )
+
+ model_key = "mini-model"
+ model_bytes = load_model()
+ backbone[model_key] = model_bytes
+
+ while True:
+ if not kill_queue.empty():
+ return
+ iteration_number += 1
+ time.sleep(1)
+
+ channel_key = comm_channel_root_dir / f"{iteration_number}/channel.txt"
+ callback_channel = FileSystemCommChannel(pathlib.Path(channel_key))
+
+ batch = MiniModel.get_batch()
+ shape = batch.shape
+ batch_bytes = batch.numpy().tobytes()
+
+ logger.debug(f"Model content: {backbone[model_key][:20]}")
+
+ input_descriptor = MessageHandler.build_tensor_descriptor(
+ "f", "float32", list(shape)
+ )
+
+ # The first request is always the metadata...
+ request = MessageHandler.build_request(
+ reply_channel=callback_channel.descriptor,
+ model=MessageHandler.build_model(model_bytes, "mini-model", "1.0"),
+ inputs=[input_descriptor],
+ outputs=[],
+ output_descriptors=[],
+ custom_attributes=None,
+ )
+ request_bytes = MessageHandler.serialize_request(request)
+        fli_channel: DragonFLIChannel = worker_queue
+
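+        # send the serialized request followed by the raw tensor bytes
+        # as a single multipart message on the worker queue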
+ multipart_message = [request_bytes, batch_bytes]
+        fli_channel.send_multiple(multipart_message)
+
+ logger.info("published message")
+
+ if iteration_number > 5:
+ return
+
+
+def mock_mli_infrastructure_mgr() -> None:
+ """Create resources normally instanatiated by the infrastructure
+ management portion of the DragonBackend.
+ """
+ config_loader = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=FileSystemCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+
+ integrated_worker = TorchWorker
+
+ worker_manager = WorkerManager(
+ config_loader,
+ integrated_worker,
+ as_service=True,
+ cooldown=10,
+ device="cpu",
+ dispatcher_queue=mp.Queue(maxsize=0),
+ )
+ worker_manager.execute()
+
+
+@pytest.fixture
+def prepare_environment(test_dir: str) -> pathlib.Path:
+ """Cleanup prior outputs to run demo repeatedly.
+
+ :param test_dir: the directory to prepare
+ :returns: The path to the log file
+ """
+ path = pathlib.Path(f"{test_dir}/workermanager.log")
+ logging.basicConfig(filename=path.absolute(), level=logging.DEBUG)
+ return path
+
+
+def test_worker_manager(prepare_environment: pathlib.Path) -> None:
+ """Test the worker manager.
+
+ :param prepare_environment: Pass this fixture to configure
+ global resources before the worker manager executes
+ """
+
+ test_path = prepare_environment
+ fs_path = test_path / "feature_store"
+ comm_path = test_path / "comm_store"
+
+ mgr_per_node = 1
+ num_nodes = 2
+ mem_per_node = 128 * 1024**2
+
+ storage = create_ddict(num_nodes, mgr_per_node, mem_per_node)
+ backbone = BackboneFeatureStore(storage, allow_reserved_writes=True)
+
+ to_worker_channel = create_local()
+ to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None)
+
+ to_worker_fli_comm_channel = DragonFLIChannel(to_worker_fli)
+
+ # NOTE: env vars must be set prior to instantiating EnvironmentConfigLoader
+ # or test environment may be unable to send messages w/queue
+ os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = (
+ to_worker_fli_comm_channel.descriptor
+ )
+ os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor
+
+ config_loader = EnvironmentConfigLoader(
+ featurestore_factory=DragonFeatureStore.from_descriptor,
+ callback_factory=FileSystemCommChannel.from_descriptor,
+ queue_factory=DragonFLIChannel.from_descriptor,
+ )
+ integrated_worker_type = TorchWorker
+
+ worker_manager = WorkerManager(
+ config_loader,
+ integrated_worker_type,
+ as_service=True,
+ cooldown=5,
+ device="cpu",
+ dispatcher_queue=mp.Queue(maxsize=0),
+ )
+
+ worker_queue = config_loader.get_queue()
+ if worker_queue is None:
+        logger.warning(
+ f"FLI input queue not loaded correctly from config_loader: {config_loader._queue_descriptor}"
+ )
+ backbone.worker_queue = to_worker_fli_comm_channel.descriptor
+
+ # create a mock client application to populate the request queue
+ kill_queue = mp.Queue()
+ msg_pump = mp.Process(
+ target=mock_messages,
+ args=(fs_path, comm_path, kill_queue),
+ )
+ msg_pump.start()
+
+ # create a process to execute commands
+ process = mp.Process(target=mock_mli_infrastructure_mgr)
+
+ # let it send some messages before starting the worker manager
+ msg_pump.join(timeout=5)
+ process.start()
+ msg_pump.join(timeout=5)
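+    # signal the message pump to stop, then give both processes a chance to exit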
+ kill_queue.put_nowait("kill!")
+ process.join(timeout=5)
+ msg_pump.kill()
+ process.kill()
diff --git a/tests/dragon_wlm/utils/__init__.py b/tests/dragon_wlm/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/dragon_wlm/utils/channel.py b/tests/dragon_wlm/utils/channel.py
new file mode 100644
index 0000000000..4c46359c2d
--- /dev/null
+++ b/tests/dragon_wlm/utils/channel.py
@@ -0,0 +1,125 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import pathlib
+import threading
+import typing as t
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class FileSystemCommChannel(CommChannelBase):
+ """Passes messages by writing to a file"""
+
+ def __init__(self, key: pathlib.Path) -> None:
+ """Initialize the FileSystemCommChannel instance.
+
+ :param key: a path to the root directory of the feature store
+ """
+ self._lock = threading.RLock()
+
+ super().__init__(key.as_posix())
+ self._file_path = key
+
+ if not self._file_path.parent.exists():
+ self._file_path.parent.mkdir(parents=True)
+
+ self._file_path.touch()
+
+ def send(self, value: bytes, timeout: float = 0) -> None:
+        """Send a message through the underlying communication channel.
+
+ :param value: The value to send
+ :param timeout: maximum time to wait (in seconds) for messages to send
+ """
+ with self._lock:
+ # write as text so we can add newlines as delimiters
+ with open(self._file_path, "a") as fp:
+ encoded_value = base64.b64encode(value).decode("utf-8")
+ fp.write(f"{encoded_value}\n")
+ logger.debug(f"FileSystemCommChannel {self._file_path} sent message")
+
+ def recv(self, timeout: float = 0) -> t.List[bytes]:
+ """Receives message(s) through the underlying communication channel.
+
+ :param timeout: maximum time to wait (in seconds) for messages to arrive
+ :returns: the received message
+ :raises SmartSimError: if the descriptor points to a missing file
+ """
+ with self._lock:
+ messages: t.List[bytes] = []
+ if not self._file_path.exists():
+ raise SmartSimError("Empty channel")
+
+ # read as text so we can split on newlines
+ with open(self._file_path, "r") as fp:
+ lines = fp.readlines()
+
+ if lines:
+ line = lines.pop(0)
+ event_bytes = base64.b64decode(line.encode("utf-8"))
+ messages.append(event_bytes)
+
+ self.clear()
+
+ # remove the first message only, write remainder back...
+ if len(lines) > 0:
+ with open(self._file_path, "w") as fp:
+ fp.writelines(lines)
+
+ logger.debug(
+ f"FileSystemCommChannel {self._file_path} received message"
+ )
+
+ return messages
+
+ def clear(self) -> None:
+ """Create an empty file for events."""
+ if self._file_path.exists():
+ self._file_path.unlink()
+ self._file_path.touch()
+
+ @classmethod
+ def from_descriptor(
+ cls,
+ descriptor: str,
+ ) -> "FileSystemCommChannel":
+ """A factory method that creates an instance from a descriptor string.
+
+ :param descriptor: The descriptor that uniquely identifies the resource
+ :returns: An attached FileSystemCommChannel
+ """
+ try:
+ path = pathlib.Path(descriptor)
+ return FileSystemCommChannel(path)
+        except Exception:
+            logger.warning(f"failed to create fs comm channel: {descriptor}")
+            raise
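+
+
+# Hedged usage sketch (not exercised by the test suite): a round trip through
+# the file-backed channel defined above. Messages are stored as base64-encoded
+# lines, so a send followed by a recv returns the original bytes. The temporary
+# directory used here is an assumption made purely for illustration.
+if __name__ == "__main__":
+    import tempfile
+
+    demo_path = pathlib.Path(tempfile.mkdtemp()) / "demo_channel.txt"
+    demo_channel = FileSystemCommChannel(demo_path)
+    demo_channel.send(b"hello")
+    assert demo_channel.recv() == [b"hello"]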
diff --git a/tests/dragon_wlm/utils/msg_pump.py b/tests/dragon_wlm/utils/msg_pump.py
new file mode 100644
index 0000000000..8d69e57c63
--- /dev/null
+++ b/tests/dragon_wlm/utils/msg_pump.py
@@ -0,0 +1,225 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import io
+import logging
+import pathlib
+import sys
+import time
+import typing as t
+
+import pytest
+
+pytest.importorskip("torch")
+pytest.importorskip("dragon")
+
+
+# isort: off
+import dragon
+import multiprocessing as mp
+import torch
+import torch.nn as nn
+
+# isort: on
+
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+ BackboneFeatureStore,
+)
+from smartsim._core.mli.message_handler import MessageHandler
+from smartsim.log import get_logger
+
+logger = get_logger(__name__, log_level=logging.DEBUG)
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+try:
+ mp.set_start_method("dragon")
+except Exception:
+ pass
+
+
+class MiniModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self._name = "mini-model"
+ self._net = torch.nn.Linear(2, 1)
+
+ def forward(self, input):
+ return self._net(input)
+
+ @property
+ def bytes(self) -> bytes:
+ """Returns the model serialized to a byte stream"""
+ buffer = io.BytesIO()
+ scripted = torch.jit.trace(self._net, self.get_batch())
+ torch.jit.save(scripted, buffer)
+ return buffer.getvalue()
+
+ @classmethod
+ def get_batch(cls) -> "torch.Tensor":
+ return torch.randn((100, 2), dtype=torch.float32)
+
+
+def load_model() -> bytes:
+ """Create a simple torch model in memory for testing"""
+ mini_model = MiniModel()
+ return mini_model.bytes
+
+
+def persist_model_file(model_path: pathlib.Path) -> pathlib.Path:
+    """Create a simple torch model and persist it to disk for
+    testing purposes.
+
+    :param model_path: The location on disk to write the model file
+    :returns: Path to the model file
+    """
+ # test_path = pathlib.Path(work_dir)
+ if not model_path.parent.exists():
+ model_path.parent.mkdir(parents=True, exist_ok=True)
+
+ model_path.unlink(missing_ok=True)
+
+ model = torch.nn.Linear(2, 1)
+ torch.save(model, model_path)
+
+ return model_path
+
+
+def _mock_messages(
+ dispatch_fli_descriptor: str,
+ fs_descriptor: str,
+ parent_iteration: int,
+ callback_descriptor: str,
+) -> None:
+ """Mock event producer for triggering the inference pipeline."""
+ model_key = "mini-model"
+    # mock_messages sends 2 messages, so we offset by 2 * (# of iterations in caller)
+ offset = 2 * parent_iteration
+
+ feature_store = BackboneFeatureStore.from_descriptor(fs_descriptor)
+ request_dispatcher_queue = DragonFLIChannel.from_descriptor(dispatch_fli_descriptor)
+
+ feature_store[model_key] = load_model()
+
+ for iteration_number in range(2):
+ logged_iteration = offset + iteration_number
+ logger.debug(f"Sending mock message {logged_iteration}")
+
+ output_key = f"output-{iteration_number}"
+
+ tensor = (
+ (iteration_number + 1) * torch.ones((1, 2), dtype=torch.float32)
+ ).numpy()
+ fsd = feature_store.descriptor
+
+ tensor_desc = MessageHandler.build_tensor_descriptor(
+ "c", "float32", list(tensor.shape)
+ )
+
+ message_tensor_output_key = MessageHandler.build_tensor_key(output_key, fsd)
+ message_model_key = MessageHandler.build_model_key(model_key, fsd)
+
+ request = MessageHandler.build_request(
+ reply_channel=callback_descriptor,
+ model=message_model_key,
+ inputs=[tensor_desc],
+ outputs=[message_tensor_output_key],
+ output_descriptors=[],
+ custom_attributes=None,
+ )
+
+ logger.info(f"Sending request {iteration_number} to request_dispatcher_queue")
+ request_bytes = MessageHandler.serialize_request(request)
+
+ logger.info("Sending msg_envelope")
+
+ # cuid = request_dispatcher_queue._channel.cuid
+ # logger.info(f"\tInternal cuid: {cuid}")
+
+ # send the header & body together so they arrive together
+ try:
+ request_dispatcher_queue.send_multiple([request_bytes, tensor.tobytes()])
+ logger.info(f"\tenvelope 0: {request_bytes[:5]}...")
+ logger.info(f"\tenvelope 1: {tensor.tobytes()[:5]}...")
+        except Exception:
+ logger.exception("Unable to send request envelope")
+
+ logger.info("All messages sent")
+
+ # keep the process alive for an extra 15 seconds to let the processor
+ # have access to the channels before they're destroyed
+ for _ in range(15):
+ time.sleep(1)
+
+
+def mock_messages(
+ dispatch_fli_descriptor: str,
+ fs_descriptor: str,
+ parent_iteration: int,
+ callback_descriptor: str,
+) -> int:
+    """Mock event producer for triggering the inference pipeline. Used
+    when the pump is started via multiprocessing."""
+ logger.info(f"{dispatch_fli_descriptor=}")
+ logger.info(f"{fs_descriptor=}")
+ logger.info(f"{parent_iteration=}")
+ logger.info(f"{callback_descriptor=}")
+
+    try:
+        _mock_messages(
+            dispatch_fli_descriptor,
+            fs_descriptor,
+            parent_iteration,
+            callback_descriptor,
+        )
+    except Exception:
+        logger.exception("The mock message pump failed")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+ import argparse
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--dispatch-fli-descriptor", type=str)
+    parser.add_argument("--fs-descriptor", type=str)
+    parser.add_argument("--parent-iteration", type=int)
+    parser.add_argument("--callback-descriptor", type=str)
+
+    args = parser.parse_args()
+
+ return_code = mock_messages(
+ args.dispatch_fli_descriptor,
+ args.fs_descriptor,
+ args.parent_iteration,
+ args.callback_descriptor,
+ )
+ sys.exit(return_code)
diff --git a/tests/dragon_wlm/utils/worker.py b/tests/dragon_wlm/utils/worker.py
new file mode 100644
index 0000000000..0582cae566
--- /dev/null
+++ b/tests/dragon_wlm/utils/worker.py
@@ -0,0 +1,104 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import io
+import typing as t
+
+import torch
+
+import smartsim._core.mli.infrastructure.worker.worker as mliw
+import smartsim.error as sse
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class IntegratedTorchWorker(mliw.MachineLearningWorkerBase):
+ """A minimum implementation of a worker that executes a PyTorch model"""
+
+ # @staticmethod
+ # def deserialize(request: InferenceRequest) -> t.List[t.Any]:
+ # # request.input_meta
+ # # request.raw_inputs
+ # return request
+
+ @staticmethod
+ def load_model(
+ request: mliw.InferenceRequest, fetch_result: mliw.FetchModelResult, device: str
+ ) -> mliw.LoadModelResult:
+ model_bytes = fetch_result.model_bytes or request.raw_model
+ if not model_bytes:
+ raise ValueError("Unable to load model without reference object")
+
+ model: torch.nn.Module = torch.load(io.BytesIO(model_bytes))
+ result = mliw.LoadModelResult(model)
+ return result
+
+ @staticmethod
+ def transform_input(
+ request: mliw.InferenceRequest,
+ fetch_result: mliw.FetchInputResult,
+ device: str,
+ ) -> mliw.TransformInputResult:
+ # extra metadata for assembly can be found in request.input_meta
+ raw_inputs = request.raw_inputs or fetch_result.inputs
+
+ result: t.List[torch.Tensor] = []
+ # should this happen here?
+ # consider - fortran to c data layout
+ # is there an intermediate representation before really doing torch.load?
+ if raw_inputs:
+ result = [torch.load(io.BytesIO(item)) for item in raw_inputs]
+
+ return mliw.TransformInputResult(result)
+
+ @staticmethod
+ def execute(
+ request: mliw.InferenceRequest,
+ load_result: mliw.LoadModelResult,
+ transform_result: mliw.TransformInputResult,
+ ) -> mliw.ExecuteResult:
+ if not load_result.model:
+ raise sse.SmartSimError("Model must be loaded to execute")
+
+ model = load_result.model
+ results = [model(tensor) for tensor in transform_result.transformed]
+
+ execute_result = mliw.ExecuteResult(results)
+ return execute_result
+
+ @staticmethod
+ def transform_output(
+ request: mliw.InferenceRequest,
+ execute_result: mliw.ExecuteResult,
+ result_device: str,
+ ) -> mliw.TransformOutputResult:
+ # send the original tensors...
+        execute_result.predictions = [
+            prediction.detach() for prediction in execute_result.predictions
+        ]
+        # todo: solve sending all tensor metadata that coincides with each prediction
+ return mliw.TransformOutputResult(
+ execute_result.predictions, [1], "c", "float32"
+ )
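+
+
+# Hedged sketch (illustrative only): the byte-level convention `transform_input`
+# above assumes for raw inputs, i.e. each tensor arrives as the bytes produced
+# by `torch.save` and is rehydrated with `torch.load` before execution.
+if __name__ == "__main__":
+    demo_buffer = io.BytesIO()
+    torch.save(torch.ones((1, 2)), demo_buffer)
+    rehydrated = torch.load(io.BytesIO(demo_buffer.getvalue()))
+    assert tuple(rehydrated.shape) == (1, 2)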
diff --git a/tests/mli/__init__.py b/tests/mli/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/mli/channel.py b/tests/mli/channel.py
new file mode 100644
index 0000000000..4c46359c2d
--- /dev/null
+++ b/tests/mli/channel.py
@@ -0,0 +1,125 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import pathlib
+import threading
+import typing as t
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class FileSystemCommChannel(CommChannelBase):
+ """Passes messages by writing to a file"""
+
+ def __init__(self, key: pathlib.Path) -> None:
+ """Initialize the FileSystemCommChannel instance.
+
+ :param key: a path to the root directory of the feature store
+ """
+ self._lock = threading.RLock()
+
+ super().__init__(key.as_posix())
+ self._file_path = key
+
+ if not self._file_path.parent.exists():
+ self._file_path.parent.mkdir(parents=True)
+
+ self._file_path.touch()
+
+ def send(self, value: bytes, timeout: float = 0) -> None:
+        """Send a message through the underlying communication channel.
+
+ :param value: The value to send
+ :param timeout: maximum time to wait (in seconds) for messages to send
+ """
+ with self._lock:
+ # write as text so we can add newlines as delimiters
+ with open(self._file_path, "a") as fp:
+ encoded_value = base64.b64encode(value).decode("utf-8")
+ fp.write(f"{encoded_value}\n")
+ logger.debug(f"FileSystemCommChannel {self._file_path} sent message")
+
+ def recv(self, timeout: float = 0) -> t.List[bytes]:
+ """Receives message(s) through the underlying communication channel.
+
+ :param timeout: maximum time to wait (in seconds) for messages to arrive
+ :returns: the received message
+ :raises SmartSimError: if the descriptor points to a missing file
+ """
+ with self._lock:
+ messages: t.List[bytes] = []
+ if not self._file_path.exists():
+ raise SmartSimError("Empty channel")
+
+ # read as text so we can split on newlines
+ with open(self._file_path, "r") as fp:
+ lines = fp.readlines()
+
+ if lines:
+ line = lines.pop(0)
+ event_bytes = base64.b64decode(line.encode("utf-8"))
+ messages.append(event_bytes)
+
+ self.clear()
+
+ # remove the first message only, write remainder back...
+ if len(lines) > 0:
+ with open(self._file_path, "w") as fp:
+ fp.writelines(lines)
+
+ logger.debug(
+ f"FileSystemCommChannel {self._file_path} received message"
+ )
+
+ return messages
+
+ def clear(self) -> None:
+ """Create an empty file for events."""
+ if self._file_path.exists():
+ self._file_path.unlink()
+ self._file_path.touch()
+
+ @classmethod
+ def from_descriptor(
+ cls,
+ descriptor: str,
+ ) -> "FileSystemCommChannel":
+ """A factory method that creates an instance from a descriptor string.
+
+ :param descriptor: The descriptor that uniquely identifies the resource
+ :returns: An attached FileSystemCommChannel
+ """
+ try:
+ path = pathlib.Path(descriptor)
+ return FileSystemCommChannel(path)
+        except Exception:
+            logger.warning(f"failed to create fs comm channel: {descriptor}")
+            raise
diff --git a/tests/mli/feature_store.py b/tests/mli/feature_store.py
new file mode 100644
index 0000000000..7bc18253c8
--- /dev/null
+++ b/tests/mli/feature_store.py
@@ -0,0 +1,144 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pathlib
+import typing as t
+
+import smartsim.error as sse
+from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class MemoryFeatureStore(FeatureStore):
+ """A feature store with values persisted only in local memory"""
+
+ def __init__(self, storage: t.Optional[t.Dict[str, bytes]] = None) -> None:
+ """Initialize the MemoryFeatureStore instance"""
+ super().__init__("in-memory-fs")
+ if storage is None:
+            storage = {"_": b"abc"}
+ self._storage: t.Dict[str, bytes] = storage
+
+ def _get(self, key: str) -> bytes:
+ """Retrieve an item using key
+
+ :param key: Unique key of an item to retrieve from the feature store"""
+ if key not in self._storage:
+ raise sse.SmartSimError(f"{key} not found in feature store")
+ return self._storage[key]
+
+    def _set(self, key: str, value: bytes) -> None:
+        """Assign a value to a key in the feature store.
+
+        :param key: Unique key of an item to set in the feature store
+        :param value: Value to persist in the feature store"""
+        self._check_reserved(key)
+        self._storage[key] = value
+
+    def _contains(self, key: str) -> bool:
+        """Membership operator to test for a key existing within the feature store.
+
+        :param key: Unique key of an item to retrieve from the feature store
+        :returns: `True` if the key is found, `False` otherwise"""
+        return key in self._storage
+
+
+class FileSystemFeatureStore(FeatureStore):
+ """Alternative feature store implementation for testing. Stores all
+ data on the file system"""
+
+    def __init__(self, storage_dir: t.Optional[t.Union[pathlib.Path, str]] = None) -> None:
+ """Initialize the FileSystemFeatureStore instance
+
+ :param storage_dir: (optional) root directory to store all data relative to"""
+ if isinstance(storage_dir, str):
+ storage_dir = pathlib.Path(storage_dir)
+ self._storage_dir = storage_dir
+ super().__init__(storage_dir.as_posix())
+
+ def _get(self, key: str) -> bytes:
+ """Retrieve an item using key
+
+ :param key: Unique key of an item to retrieve from the feature store"""
+ path = self._key_path(key)
+ if not path.exists():
+ raise sse.SmartSimError(f"{path} not found in feature store")
+ return path.read_bytes()
+
+ def _set(self, key: str, value: bytes) -> None:
+ """Assign a value using key
+
+ :param key: Unique key of an item to set in the feature store
+ :param value: Value to persist in the feature store"""
+ path = self._key_path(key, create=True)
+ if isinstance(value, str):
+ value = value.encode("utf-8")
+ path.write_bytes(value)
+
+ def _contains(self, key: str) -> bool:
+ """Membership operator to test for a key existing within the feature store.
+
+ :param key: Unique key of an item to retrieve from the feature store
+ :returns: `True` if the key is found, `False` otherwise"""
+ path = self._key_path(key)
+ return path.exists()
+
+    def _key_path(self, key: str, create: bool = False) -> pathlib.Path:
+        """Given a key, return a path that is optionally combined with a base
+        directory used by the FileSystemFeatureStore.
+
+        :param key: Unique key of an item to retrieve from the feature store
+        :param create: If `True`, create any missing parent directories
+        :returns: The file path associated with the key"""
+ value = pathlib.Path(key)
+
+ if self._storage_dir:
+ value = self._storage_dir / key
+
+ if create:
+ value.parent.mkdir(parents=True, exist_ok=True)
+
+ return value
+
+ @classmethod
+ def from_descriptor(
+ cls,
+ descriptor: str,
+ ) -> "FileSystemFeatureStore":
+ """A factory method that creates an instance from a descriptor string
+
+ :param descriptor: The descriptor that uniquely identifies the resource
+ :returns: An attached FileSystemFeatureStore"""
+        try:
+            path = pathlib.Path(descriptor)
+            path.mkdir(parents=True, exist_ok=True)
+            if not path.is_dir():
+                raise ValueError("FileSystemFeatureStore requires a directory path")
+            return FileSystemFeatureStore(path)
+        except Exception:
+            logger.error(f"Error while creating FileSystemFeatureStore: {descriptor}")
+            raise
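+
+
+# Hedged usage sketch (not exercised by the test suite): a set/get round trip
+# through the file-backed store defined above, using its protected helpers
+# directly. The temporary directory is an assumption made for illustration and
+# the key is assumed not to be reserved by the base FeatureStore.
+if __name__ == "__main__":
+    import tempfile
+
+    demo_store = FileSystemFeatureStore(tempfile.mkdtemp())
+    demo_store._set("demo-key", b"demo-value")
+    assert demo_store._get("demo-key") == b"demo-value"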
diff --git a/tests/mli/test_integrated_torch_worker.py b/tests/mli/test_integrated_torch_worker.py
new file mode 100644
index 0000000000..4d93358bfb
--- /dev/null
+++ b/tests/mli/test_integrated_torch_worker.py
@@ -0,0 +1,271 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pathlib
+
+import pytest
+import torch
+
+# The tests in this file belong to the group_b group
+pytestmark = pytest.mark.group_b
+
+# retrieved from pytest fixtures
+is_dragon = pytest.test_launcher == "dragon"
+torch_available = (
+ "torch" in []
+) # todo: update test to replace installed_redisai_backends()
+
+
+@pytest.fixture
+def persist_torch_model(test_dir: str) -> pathlib.Path:
+ test_path = pathlib.Path(test_dir)
+ model_path = test_path / "basic.pt"
+
+ model = torch.nn.Linear(2, 1)
+ torch.save(model, model_path)
+
+ return model_path
+
+
+# todo: move deserialization tests into suite for worker manager where serialization occurs
+
+
+# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+# def test_deserialize_direct_request(persist_torch_model: pathlib.Path) -> None:
+#     """Verify that a direct request is deserialized properly"""
+# worker = mli.IntegratedTorchWorker
+# # feature_store = mli.MemoryFeatureStore()
+
+# model_bytes = persist_torch_model.read_bytes()
+# input_tensor = torch.randn(2)
+
+# expected_callback_channel = b"faux_channel_descriptor_bytes"
+# callback_channel = mli.DragonCommChannel.find(expected_callback_channel)
+
+# message_tensor_input = MessageHandler.build_tensor(
+# input_tensor, "c", "float32", [2]
+# )
+
+# request = MessageHandler.build_request(
+# reply_channel=callback_channel.descriptor,
+# model=model_bytes,
+# inputs=[message_tensor_input],
+# outputs=[],
+# custom_attributes=None,
+# )
+
+# msg_bytes = MessageHandler.serialize_request(request)
+
+# inference_request = worker.deserialize(msg_bytes)
+# assert inference_request.callback._descriptor == expected_callback_channel
+
+
+# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+# def test_deserialize_indirect_request(persist_torch_model: pathlib.Path) -> None:
+# """Verify that an indirect request is deserialized correctly"""
+# worker = mli.IntegratedTorchWorker
+# # feature_store = mli.MemoryFeatureStore()
+
+# model_key = "persisted-model"
+# # model_bytes = persist_torch_model.read_bytes()
+# # feature_store[model_key] = model_bytes
+
+# input_key = f"demo-input"
+# # input_tensor = torch.randn(2)
+# # feature_store[input_key] = input_tensor
+
+# expected_callback_channel = b"faux_channel_descriptor_bytes"
+# callback_channel = mli.DragonCommChannel.find(expected_callback_channel)
+
+# output_key = f"demo-output"
+
+# message_tensor_output_key = MessageHandler.build_tensor_key(output_key)
+# message_tensor_input_key = MessageHandler.build_tensor_key(input_key)
+# message_model_key = MessageHandler.build_model_key(model_key)
+
+# request = MessageHandler.build_request(
+# reply_channel=callback_channel.descriptor,
+# model=message_model_key,
+# inputs=[message_tensor_input_key],
+# outputs=[message_tensor_output_key],
+# custom_attributes=None,
+# )
+
+# msg_bytes = MessageHandler.serialize_request(request)
+
+# inference_request = worker.deserialize(msg_bytes)
+# assert inference_request.callback._descriptor == expected_callback_channel
+
+
+# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+# def test_deserialize_mixed_mode_indirect_inputs(
+# persist_torch_model: pathlib.Path,
+# ) -> None:
+# """Verify that a mixed mode (combining direct and indirect inputs, models, outputs)
+# with indirect inputs is deserialized correctly"""
+# worker = mli.IntegratedTorchWorker
+# # feature_store = mli.MemoryFeatureStore()
+
+# # model_key = "persisted-model"
+# model_bytes = persist_torch_model.read_bytes()
+# # feature_store[model_key] = model_bytes
+
+# input_key = f"demo-input"
+# # input_tensor = torch.randn(2)
+# # feature_store[input_key] = input_tensor
+
+# expected_callback_channel = b"faux_channel_descriptor_bytes"
+# callback_channel = mli.DragonCommChannel.find(expected_callback_channel)
+
+# output_key = f"demo-output"
+
+# message_tensor_output_key = MessageHandler.build_tensor_key(output_key)
+# message_tensor_input_key = MessageHandler.build_tensor_key(input_key)
+# # message_model_key = MessageHandler.build_model_key(model_key)
+
+# request = MessageHandler.build_request(
+# reply_channel=callback_channel.descriptor,
+# model=model_bytes,
+# inputs=[message_tensor_input_key],
+# # outputs=[message_tensor_output_key],
+# outputs=[],
+# custom_attributes=None,
+# )
+
+# msg_bytes = MessageHandler.serialize_request(request)
+
+# inference_request = worker.deserialize(msg_bytes)
+# assert inference_request.callback._descriptor == expected_callback_channel
+
+
+# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+# def test_deserialize_mixed_mode_indirect_outputs(
+# persist_torch_model: pathlib.Path,
+# ) -> None:
+# """Verify that a mixed mode (combining direct and indirect inputs, models, outputs)
+# with indirect outputs is deserialized correctly"""
+# worker = mli.IntegratedTorchWorker
+# # feature_store = mli.MemoryFeatureStore()
+
+# # model_key = "persisted-model"
+# model_bytes = persist_torch_model.read_bytes()
+# # feature_store[model_key] = model_bytes
+
+# input_key = f"demo-input"
+# input_tensor = torch.randn(2)
+# # feature_store[input_key] = input_tensor
+
+# expected_callback_channel = b"faux_channel_descriptor_bytes"
+# callback_channel = mli.DragonCommChannel.find(expected_callback_channel)
+
+# output_key = f"demo-output"
+
+# message_tensor_output_key = MessageHandler.build_tensor_key(output_key)
+# # message_tensor_input_key = MessageHandler.build_tensor_key(input_key)
+# # message_model_key = MessageHandler.build_model_key(model_key)
+# message_tensor_input = MessageHandler.build_tensor(
+# input_tensor, "c", "float32", [2]
+# )
+
+# request = MessageHandler.build_request(
+# reply_channel=callback_channel.descriptor,
+# model=model_bytes,
+# inputs=[message_tensor_input],
+# # outputs=[message_tensor_output_key],
+# outputs=[message_tensor_output_key],
+# custom_attributes=None,
+# )
+
+# msg_bytes = MessageHandler.serialize_request(request)
+
+# inference_request = worker.deserialize(msg_bytes)
+# assert inference_request.callback._descriptor == expected_callback_channel
+
+
+# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+# def test_deserialize_mixed_mode_indirect_model(
+# persist_torch_model: pathlib.Path,
+# ) -> None:
+# """Verify that a mixed mode (combining direct and indirect inputs, models, outputs)
+# with indirect outputs is deserialized correctly"""
+# worker = mli.IntegratedTorchWorker
+# # feature_store = mli.MemoryFeatureStore()
+
+# model_key = "persisted-model"
+# # model_bytes = persist_torch_model.read_bytes()
+# # feature_store[model_key] = model_bytes
+
+# # input_key = f"demo-input"
+# input_tensor = torch.randn(2)
+# # feature_store[input_key] = input_tensor
+
+# expected_callback_channel = b"faux_channel_descriptor_bytes"
+# callback_channel = mli.DragonCommChannel.find(expected_callback_channel)
+
+# output_key = f"demo-output"
+
+# # message_tensor_output_key = MessageHandler.build_tensor_key(output_key)
+# # message_tensor_input_key = MessageHandler.build_tensor_key(input_key)
+# message_model_key = MessageHandler.build_model_key(model_key)
+# message_tensor_input = MessageHandler.build_tensor(
+# input_tensor, "c", "float32", [2]
+# )
+
+# request = MessageHandler.build_request(
+# reply_channel=callback_channel.descriptor,
+# model=message_model_key,
+# inputs=[message_tensor_input],
+# # outputs=[message_tensor_output_key],
+# outputs=[],
+# custom_attributes=None,
+# )
+
+# msg_bytes = MessageHandler.serialize_request(request)
+
+# inference_request = worker.deserialize(msg_bytes)
+# assert inference_request.callback._descriptor == expected_callback_channel
+
+
+# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed")
+# def test_serialize(test_dir: str, persist_torch_model: pathlib.Path) -> None:
+# """Verify that the worker correctly executes reply serialization"""
+# worker = mli.IntegratedTorchWorker
+
+# reply = mli.InferenceReply()
+# reply.output_keys = ["foo", "bar"]
+
+# # use the worker implementation of reply serialization to get bytes for
+# # use on the callback channel
+# reply_bytes = worker.serialize_reply(reply)
+# assert reply_bytes is not None
+
+# # deserialize to verity the mapping in the worker.serialize_reply was correct
+# actual_reply = MessageHandler.deserialize_response(reply_bytes)
+
+# actual_tensor_keys = [tk.key for tk in actual_reply.result.keys]
+# assert set(actual_tensor_keys) == set(reply.output_keys)
+# assert actual_reply.status == 200
+# assert actual_reply.statusMessage == "success"
diff --git a/tests/mli/test_service.py b/tests/mli/test_service.py
new file mode 100644
index 0000000000..41595ca80b
--- /dev/null
+++ b/tests/mli/test_service.py
@@ -0,0 +1,290 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import datetime
+import time
+import typing as t
+
+import pytest
+
+from smartsim._core.entrypoints.service import Service
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+
+class SimpleService(Service):
+ """Mock implementation of a service that counts method invocations
+ using the base class event hooks."""
+
+ def __init__(
+ self,
+ log: t.List[str],
+ quit_after: int = -1,
+ as_service: bool = False,
+ cooldown: float = 0,
+ loop_delay: float = 0,
+ hc_freq: float = -1,
+ run_for: float = 0,
+ ) -> None:
+ super().__init__(as_service, cooldown, loop_delay, hc_freq)
+ self._log = log
+ self._quit_after = quit_after
+ self.num_starts = 0
+ self.num_shutdowns = 0
+ self.num_health_checks = 0
+ self.num_cooldowns = 0
+ self.num_delays = 0
+ self.num_iterations = 0
+ self.num_can_shutdown = 0
+ self.run_for = run_for
+ self.start_time = time.time()
+
+ @property
+ def runtime(self) -> float:
+ return time.time() - self.start_time
+
+    def _can_shutdown(self) -> bool:
+        self.num_can_shutdown += 1
+
+        if self._quit_after > -1 and self.num_iterations >= self._quit_after:
+            return True
+        if self.run_for > 0:
+            return self.runtime >= self.run_for
+
+        return False
+
+ def _on_start(self) -> None:
+ self.num_starts += 1
+
+ def _on_shutdown(self) -> None:
+ self.num_shutdowns += 1
+
+ def _on_health_check(self) -> None:
+ self.num_health_checks += 1
+
+ def _on_cooldown_elapsed(self) -> None:
+ self.num_cooldowns += 1
+
+ def _on_delay(self) -> None:
+ self.num_delays += 1
+
+    def _on_iteration(self) -> None:
+        self.num_iterations += 1
+
+
+def test_service_init() -> None:
+ """Verify expected default values after Service initialization"""
+ activity_log: t.List[str] = []
+ service = SimpleService(activity_log)
+
+ assert service._as_service is False
+ assert service._cooldown == 0
+ assert service._loop_delay == 0
+
+
+def test_service_run_once() -> None:
+ """Verify the service completes after a single call to _on_iteration"""
+ activity_log: t.List[str] = []
+ service = SimpleService(activity_log)
+
+ service.execute()
+
+ assert service.num_iterations == 1
+ assert service.num_starts == 1
+ assert service.num_cooldowns == 0 # it never exceeds a cooldown period
+ assert service.num_can_shutdown == 0 # it automatically exits in run once
+ assert service.num_shutdowns == 1
+
+
+@pytest.mark.parametrize(
+ "num_iterations",
+ [
+ pytest.param(0, id="Immediate Shutdown"),
+ pytest.param(1, id="1x"),
+ pytest.param(2, id="2x"),
+ pytest.param(4, id="4x"),
+ pytest.param(8, id="8x"),
+ pytest.param(16, id="16x"),
+ pytest.param(32, id="32x"),
+ ],
+)
+def test_service_run_until_can_shutdown(num_iterations: int) -> None:
+ """Verify the service completes after a dynamic number of iterations
+ based on the return value of `_can_shutdown`"""
+ activity_log: t.List[str] = []
+
+ service = SimpleService(activity_log, quit_after=num_iterations, as_service=True)
+
+ service.execute()
+
+ if num_iterations == 0:
+ # no matter what, it should always execute the _on_iteration method
+ assert service.num_iterations == 1
+ else:
+ # the shutdown check follows on_iteration. there will be one last call
+ assert service.num_iterations == num_iterations
+
+ assert service.num_starts == 1
+ assert service.num_shutdowns == 1
+
+
+@pytest.mark.parametrize(
+ "cooldown",
+ [
+ pytest.param(1, id="1s"),
+ pytest.param(3, id="3s"),
+ pytest.param(5, id="5s"),
+ ],
+)
+def test_service_cooldown(cooldown: int) -> None:
+ """Verify that the cooldown period is respected"""
+ activity_log: t.List[str] = []
+
+ service = SimpleService(
+ activity_log,
+ quit_after=1,
+ as_service=True,
+ cooldown=cooldown,
+ loop_delay=0,
+ )
+
+ ts0 = datetime.datetime.now()
+ service.execute()
+ ts1 = datetime.datetime.now()
+
+ fudge_factor = 1.1 # allow a little bit of wiggle room for the loop
+ duration_in_seconds = (ts1 - ts0).total_seconds()
+
+ assert duration_in_seconds <= cooldown * fudge_factor
+ assert service.num_cooldowns == 1
+ assert service.num_shutdowns == 1
+
+
+@pytest.mark.parametrize(
+ "delay, num_iterations",
+ [
+ pytest.param(1, 3, id="1s delay, 3x"),
+        pytest.param(3, 2, id="3s delay, 2x"),
+ pytest.param(5, 1, id="5s delay, 1x"),
+ ],
+)
+def test_service_delay(delay: int, num_iterations: int) -> None:
+ """Verify that a delay is correctly added between iterations"""
+ activity_log: t.List[str] = []
+
+ service = SimpleService(
+ activity_log,
+ quit_after=num_iterations,
+ as_service=True,
+ cooldown=0,
+ loop_delay=delay,
+ )
+
+ ts0 = datetime.datetime.now()
+ service.execute()
+ ts1 = datetime.datetime.now()
+
+ # the expected duration is the sum of the delay between each iteration
+ expected_duration = (num_iterations + 1) * delay
+ duration_in_seconds = (ts1 - ts0).total_seconds()
+
+ assert duration_in_seconds <= expected_duration
+ assert service.num_cooldowns == 0
+ assert service.num_shutdowns == 1
+
+
+@pytest.mark.parametrize(
+ "health_check_freq, run_for",
+ [
+        pytest.param(1, 5.5, id="1s freq, 5x"),
+ pytest.param(5, 10.5, id="5s freq, 2x"),
+ pytest.param(0.1, 5.1, id="0.1s freq, 50x"),
+ ],
+)
+def test_service_health_check_freq(health_check_freq: float, run_for: float) -> None:
+    """Verify that the health check frequency is honored.
+
+    :param health_check_freq: The desired frequency of the health check
+    :param run_for: A fixed duration to allow the service to run
+    """
+ activity_log: t.List[str] = []
+
+ service = SimpleService(
+ activity_log,
+ quit_after=-1,
+ as_service=True,
+ cooldown=0,
+ hc_freq=health_check_freq,
+ run_for=run_for,
+ )
+
+ ts0 = datetime.datetime.now()
+ service.execute()
+ ts1 = datetime.datetime.now()
+
+    # the expected number of health checks is the fixed runtime divided by the frequency
+    expected_hc_count = run_for // health_check_freq
+
+ # allow some wiggle room for frequency comparison
+ assert expected_hc_count - 2 <= service.num_health_checks <= expected_hc_count + 2
+
+ assert service.num_cooldowns == 0
+ assert service.num_shutdowns == 1
+
+
+def test_service_health_check_freq_unbound() -> None:
+    """Verify that a health check frequency of zero is treated as
+    "always on" and the health check hook is invoked on every loop iteration."""
+ health_check_freq: float = 0.0
+ run_for: float = 5
+
+ activity_log: t.List[str] = []
+
+ service = SimpleService(
+ activity_log,
+ quit_after=-1,
+ as_service=True,
+ cooldown=0,
+ hc_freq=health_check_freq,
+ run_for=run_for,
+ )
+
+ service.execute()
+
+ # allow some wiggle room for frequency comparison
+ assert service.num_health_checks == service.num_iterations
+ assert service.num_cooldowns == 0
+ assert service.num_shutdowns == 1
diff --git a/tests/mli/worker.py b/tests/mli/worker.py
new file mode 100644
index 0000000000..0582cae566
--- /dev/null
+++ b/tests/mli/worker.py
@@ -0,0 +1,104 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import io
+import typing as t
+
+import torch
+
+import smartsim._core.mli.infrastructure.worker.worker as mliw
+import smartsim.error as sse
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class IntegratedTorchWorker(mliw.MachineLearningWorkerBase):
+ """A minimum implementation of a worker that executes a PyTorch model"""
+
+ # @staticmethod
+ # def deserialize(request: InferenceRequest) -> t.List[t.Any]:
+ # # request.input_meta
+ # # request.raw_inputs
+ # return request
+
+ @staticmethod
+ def load_model(
+ request: mliw.InferenceRequest, fetch_result: mliw.FetchModelResult, device: str
+ ) -> mliw.LoadModelResult:
+ model_bytes = fetch_result.model_bytes or request.raw_model
+ if not model_bytes:
+ raise ValueError("Unable to load model without reference object")
+
+ model: torch.nn.Module = torch.load(io.BytesIO(model_bytes))
+ result = mliw.LoadModelResult(model)
+ return result
+
+ @staticmethod
+ def transform_input(
+ request: mliw.InferenceRequest,
+ fetch_result: mliw.FetchInputResult,
+ device: str,
+ ) -> mliw.TransformInputResult:
+ # extra metadata for assembly can be found in request.input_meta
+ raw_inputs = request.raw_inputs or fetch_result.inputs
+
+ result: t.List[torch.Tensor] = []
+ # should this happen here?
+ # consider - fortran to c data layout
+ # is there an intermediate representation before really doing torch.load?
+ if raw_inputs:
+ result = [torch.load(io.BytesIO(item)) for item in raw_inputs]
+
+ return mliw.TransformInputResult(result)
+
+ @staticmethod
+ def execute(
+ request: mliw.InferenceRequest,
+ load_result: mliw.LoadModelResult,
+ transform_result: mliw.TransformInputResult,
+ ) -> mliw.ExecuteResult:
+ if not load_result.model:
+ raise sse.SmartSimError("Model must be loaded to execute")
+
+ model = load_result.model
+ results = [model(tensor) for tensor in transform_result.transformed]
+
+ execute_result = mliw.ExecuteResult(results)
+ return execute_result
+
+ @staticmethod
+ def transform_output(
+ request: mliw.InferenceRequest,
+ execute_result: mliw.ExecuteResult,
+ result_device: str,
+ ) -> mliw.TransformOutputResult:
+ # send the original tensors...
+        execute_result.predictions = [
+            prediction.detach() for prediction in execute_result.predictions
+        ]
+        # todo: solve sending all tensor metadata that coincides with each prediction
+ return mliw.TransformOutputResult(
+ execute_result.predictions, [1], "c", "float32"
+ )
diff --git a/tests/test_dragon_comm_utils.py b/tests/test_dragon_comm_utils.py
new file mode 100644
index 0000000000..a6f9c206a4
--- /dev/null
+++ b/tests/test_dragon_comm_utils.py
@@ -0,0 +1,257 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import pathlib
+import uuid
+
+import pytest
+
+from smartsim.error.errors import SmartSimError
+
+dragon = pytest.importorskip("dragon")
+
+# isort: off
+import dragon.channels as dch
+import dragon.infrastructure.parameters as dp
+import dragon.managed_memory as dm
+import dragon.fli as fli
+
+# isort: on
+
+from smartsim._core.mli.comm.channel import dragon_util
+from smartsim.log import get_logger
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+logger = get_logger(__name__)
+
+
+@pytest.fixture(scope="function")
+def the_pool() -> dm.MemoryPool:
+    """Attach to the default memory pool of the current process."""
+ raw_pool_descriptor = dp.this_process.default_pd
+ descriptor_ = base64.b64decode(raw_pool_descriptor)
+
+ pool = dm.MemoryPool.attach(descriptor_)
+ return pool
+
+
+@pytest.fixture(scope="function")
+def the_channel() -> dch.Channel:
+ """Creates a Channel attached to the local memory pool."""
+ channel = dch.Channel.make_process_local()
+ return channel
+
+
+@pytest.fixture(scope="function")
+def the_fli(the_channel) -> fli.FLInterface:
+ """Creates an FLI attached to the local memory pool."""
+ fli_ = fli.FLInterface(main_ch=the_channel, manager_ch=None)
+ return fli_
+
+
+def test_descriptor_to_channel_empty() -> None:
+ """Verify that `descriptor_to_channel` raises an exception when
+ provided with an empty descriptor."""
+ descriptor = ""
+
+ with pytest.raises(ValueError) as ex:
+ dragon_util.descriptor_to_channel(descriptor)
+
+ assert "empty" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+ "descriptor",
+ ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()],
+)
+def test_descriptor_to_channel_b64fail(descriptor: str) -> None:
+ """Verify that `descriptor_to_channel` raises an exception when
+ provided with an incorrectly encoded descriptor.
+
+ :param descriptor: A descriptor that is not properly base64 encoded
+ """
+
+ with pytest.raises(ValueError) as ex:
+ dragon_util.descriptor_to_channel(descriptor)
+
+ assert "base64" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+ "descriptor",
+ [str(uuid.uuid4())],
+)
+def test_descriptor_to_channel_channel_fail(descriptor: str) -> None:
+ """Verify that `descriptor_to_channel` raises an exception when a correctly
+ formatted descriptor that does not describe a real channel is passed.
+
+    :param descriptor: A well-formed descriptor that does not refer to an existing channel
+ """
+
+ with pytest.raises(SmartSimError) as ex:
+ dragon_util.descriptor_to_channel(descriptor)
+
+ # ensure we're receiving the right exception
+ assert "address" in ex.value.args[0]
+ assert "channel" in ex.value.args[0]
+
+
+def test_descriptor_to_channel_channel_not_available(the_channel: dch.Channel) -> None:
+ """Verify that `descriptor_to_channel` raises an exception when a channel
+ is no longer available.
+
+ :param the_channel: A dragon channel
+ """
+
+ # get a good descriptor & wipe out the channel so it can't be attached
+ descriptor = dragon_util.channel_to_descriptor(the_channel)
+ the_channel.destroy()
+
+ with pytest.raises(SmartSimError) as ex:
+ dragon_util.descriptor_to_channel(descriptor)
+
+ assert "address" in ex.value.args[0]
+
+
+def test_descriptor_to_channel_happy_path(the_channel: dch.Channel) -> None:
+ """Verify that `descriptor_to_channel` works as expected when provided
+ a valid descriptor
+
+ :param the_channel: A dragon channel
+ """
+
+ # get a good descriptor
+ descriptor = dragon_util.channel_to_descriptor(the_channel)
+
+ reattached = dragon_util.descriptor_to_channel(descriptor)
+ assert reattached
+
+ # and just make sure creation of the descriptor is transitive
+ assert dragon_util.channel_to_descriptor(reattached) == descriptor
+
+
+def test_descriptor_to_fli_empty() -> None:
+ """Verify that `descriptor_to_fli` raises an exception when
+ provided with an empty descriptor."""
+ descriptor = ""
+
+ with pytest.raises(ValueError) as ex:
+ dragon_util.descriptor_to_fli(descriptor)
+
+ assert "empty" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+ "descriptor",
+ ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()],
+)
+def test_descriptor_to_fli_b64fail(descriptor: str) -> None:
+ """Verify that `descriptor_to_fli` raises an exception when
+ provided with an incorrectly encoded descriptor.
+
+ :param descriptor: A descriptor that is not properly base64 encoded
+ """
+
+ with pytest.raises(ValueError) as ex:
+ dragon_util.descriptor_to_fli(descriptor)
+
+ assert "base64" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+ "descriptor",
+ [str(uuid.uuid4())],
+)
+def test_descriptor_to_fli_fli_fail(descriptor: str) -> None:
+ """Verify that `descriptor_to_fli` raises an exception when a correctly
+ formatted descriptor that does not describe a real FLI is passed.
+
+    :param descriptor: A well-formed descriptor that does not refer to an existing FLI
+ """
+
+ with pytest.raises(SmartSimError) as ex:
+ dragon_util.descriptor_to_fli(descriptor)
+
+ # ensure we're receiving the right exception
+ assert "address" in ex.value.args[0]
+ assert "fli" in ex.value.args[0].lower()
+
+
+def test_descriptor_to_fli_fli_not_available(
+ the_fli: fli.FLInterface, the_channel: dch.Channel
+) -> None:
+ """Verify that `descriptor_to_fli` raises an exception when a channel
+ is no longer available.
+
+ :param the_fli: A dragon FLInterface
+ :param the_channel: A dragon channel
+ """
+
+ # get a good descriptor & wipe out the FLI so it can't be attached
+ descriptor = dragon_util.channel_to_descriptor(the_fli)
+ the_fli.destroy()
+ the_channel.destroy()
+
+ with pytest.raises(SmartSimError) as ex:
+ dragon_util.descriptor_to_fli(descriptor)
+
+ # ensure we're receiving the right exception
+ assert "address" in ex.value.args[0]
+
+
+def test_descriptor_to_fli_happy_path(the_fli: fli.FLInterface) -> None:
+ """Verify that `descriptor_to_fli` works as expected when provided
+ a valid descriptor
+
+ :param the_fli: A dragon FLInterface
+ """
+
+ # get a good descriptor
+ descriptor = dragon_util.channel_to_descriptor(the_fli)
+
+ reattached = dragon_util.descriptor_to_fli(descriptor)
+ assert reattached
+
+ # and just make sure creation of the descriptor is transitive
+ assert dragon_util.channel_to_descriptor(reattached) == descriptor
+
+
+def test_pool_to_descriptor_empty() -> None:
+ """Verify that `pool_to_descriptor` raises an exception when
+ provided with a null pool."""
+
+ with pytest.raises(ValueError) as ex:
+ dragon_util.pool_to_descriptor(None)
+
+
+def test_pool_to_happy_path(the_pool) -> None:
+ """Verify that `pool_to_descriptor` creates a descriptor
+ when supplied with a valid memory pool."""
+
+ descriptor = dragon_util.pool_to_descriptor(the_pool)
+ assert descriptor
diff --git a/tests/test_generator.py b/tests/test_generator.py
index 3915526a8b..f949d8f663 100644
--- a/tests/test_generator.py
+++ b/tests/test_generator.py
@@ -85,15 +85,13 @@ def as_executable_sequence(self):
def mock_job() -> unittest.mock.MagicMock:
"""Fixture to create a mock Job."""
job = unittest.mock.MagicMock(
- **{
- "entity": EchoHelloWorldEntity(),
- "name": "test_job",
- "get_launch_steps": unittest.mock.MagicMock(
- side_effect=lambda: NotImplementedError()
- ),
- },
+ entity=EchoHelloWorldEntity(),
+ get_launch_steps=unittest.mock.MagicMock(
+ side_effect=lambda: NotImplementedError()
+ ),
spec=Job,
)
+ job.name = "test_job"
yield job
diff --git a/tests/test_message_handler/__init__.py b/tests/test_message_handler/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/test_message_handler/test_build_model.py b/tests/test_message_handler/test_build_model.py
new file mode 100644
index 0000000000..56c1c8764c
--- /dev/null
+++ b/tests/test_message_handler/test_build_model.py
@@ -0,0 +1,72 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+handler = MessageHandler()
+
+
+def test_build_model_successful():
+ expected_data = b"model data"
+ expected_name = "model name"
+ expected_version = "v0.0.1"
+ model = handler.build_model(expected_data, expected_name, expected_version)
+ assert model.data == expected_data
+ assert model.name == expected_name
+ assert model.version == expected_version
+
+
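+# Each parameter set below supplies one argument of the wrong type; the
+# builder is expected to reject it with a ValueError.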
+@pytest.mark.parametrize(
+ "data, name, version",
+ [
+ pytest.param(
+ 100,
+ "model name",
+ "v0.0.1",
+ id="bad data type",
+ ),
+ pytest.param(
+ b"model data",
+ 1,
+ "v0.0.1",
+ id="bad name type",
+ ),
+ pytest.param(
+ b"model data",
+ "model name",
+ 0.1,
+ id="bad version type",
+ ),
+ ],
+)
+def test_build_model_unsuccessful(data, name, version):
+ with pytest.raises(ValueError):
+ model = handler.build_model(data, name, version)
diff --git a/tests/test_message_handler/test_build_model_key.py b/tests/test_message_handler/test_build_model_key.py
new file mode 100644
index 0000000000..6c9b3dc951
--- /dev/null
+++ b/tests/test_message_handler/test_build_model_key.py
@@ -0,0 +1,47 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+handler = MessageHandler()
+
+
+def test_build_model_key_successful():
+ fsd = "mock-feature-store-descriptor"
+ model_key = handler.build_model_key("tensor_key", fsd)
+ assert model_key.key == "tensor_key"
+ assert model_key.descriptor == fsd
+
+
+def test_build_model_key_unsuccessful():
+ with pytest.raises(ValueError):
+ fsd = "mock-feature-store-descriptor"
+ model_key = handler.build_model_key(100, fsd)
diff --git a/tests/test_message_handler/test_build_request_attributes.py b/tests/test_message_handler/test_build_request_attributes.py
new file mode 100644
index 0000000000..5b1e09b0aa
--- /dev/null
+++ b/tests/test_message_handler/test_build_request_attributes.py
@@ -0,0 +1,55 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+handler = MessageHandler()
+
+
+def test_build_torch_request_attributes_successful():
+ attribute = handler.build_torch_request_attributes("sparse")
+ assert attribute.tensorType == "sparse"
+
+
+def test_build_torch_request_attributes_unsuccessful():
+ with pytest.raises(ValueError):
+ attribute = handler.build_torch_request_attributes("invalid!")
+
+
+def test_build_tf_request_attributes_successful():
+ attribute = handler.build_tf_request_attributes(name="tfcnn", tensor_type="sparse")
+ assert attribute.tensorType == "sparse"
+ assert attribute.name == "tfcnn"
+
+
+def test_build_tf_request_attributes_unsuccessful():
+ with pytest.raises(ValueError):
+ attribute = handler.build_tf_request_attributes("tf_fail", "invalid!")
diff --git a/tests/test_message_handler/test_build_tensor_desc.py b/tests/test_message_handler/test_build_tensor_desc.py
new file mode 100644
index 0000000000..45126fb16c
--- /dev/null
+++ b/tests/test_message_handler/test_build_tensor_desc.py
@@ -0,0 +1,90 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+handler = MessageHandler()
+
+
+@pytest.mark.parametrize(
+ "dtype, order, dimension",
+ [
+ pytest.param(
+ "int8",
+ "c",
+ [3, 2, 5],
+ id="small torch tensor",
+ ),
+ pytest.param(
+ "int64",
+ "c",
+ [1040, 1040, 3],
+ id="medium torch tensor",
+ ),
+ ],
+)
+def test_build_tensor_descriptor_successful(dtype, order, dimension):
+ built_tensor_descriptor = handler.build_tensor_descriptor(order, dtype, dimension)
+ assert built_tensor_descriptor is not None
+ assert built_tensor_descriptor.order == order
+ assert built_tensor_descriptor.dataType == dtype
+ for i, j in zip(built_tensor_descriptor.dimensions, dimension):
+ assert i == j
+
+
+@pytest.mark.parametrize(
+ "dtype, order, dimension",
+ [
+ pytest.param(
+ "bad_order",
+ "int8",
+ [3, 2, 5],
+ id="bad order type",
+ ),
+ pytest.param(
+ "f",
+ "bad_num_type",
+ [3, 2, 5],
+ id="bad numerical type",
+ ),
+ pytest.param(
+ "f",
+ "int8",
+ "bad shape type",
+ id="bad shape type",
+ ),
+ ],
+)
+def test_build_tensor_descriptor_unsuccessful(order, dtype, dimension):
+ with pytest.raises(ValueError):
+ built_tensor_descriptor = handler.build_tensor_descriptor(
+ order, dtype, dimension
+ )
diff --git a/tests/test_message_handler/test_build_tensor_key.py b/tests/test_message_handler/test_build_tensor_key.py
new file mode 100644
index 0000000000..6a28b80c4f
--- /dev/null
+++ b/tests/test_message_handler/test_build_tensor_key.py
@@ -0,0 +1,46 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+handler = MessageHandler()
+
+
+def test_build_tensor_key_successful():
+ fsd = "mock-feature-store-descriptor"
+ tensor_key = handler.build_tensor_key("tensor_key", fsd)
+ assert tensor_key.key == "tensor_key"
+
+
+def test_build_tensor_key_unsuccessful():
+ with pytest.raises(ValueError):
+ fsd = "mock-feature-store-descriptor"
+ tensor_key = handler.build_tensor_key(100, fsd)
diff --git a/tests/test_message_handler/test_output_descriptor.py b/tests/test_message_handler/test_output_descriptor.py
new file mode 100644
index 0000000000..beb9a47657
--- /dev/null
+++ b/tests/test_message_handler/test_output_descriptor.py
@@ -0,0 +1,78 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+handler = MessageHandler()
+
+fsd = "mock-feature-store-descriptor"
+tensor_key = handler.build_tensor_key("key", fsd)
+
+
+@pytest.mark.parametrize(
+ "order, keys, dtype, dimension",
+ [
+ pytest.param("c", [tensor_key], "int8", [1, 2, 3, 4], id="all specified"),
+ pytest.param(
+ "c", [tensor_key, tensor_key], "none", [1, 2, 3, 4], id="none dtype"
+ ),
+ pytest.param("c", [tensor_key], "int8", [], id="empty dimensions"),
+ pytest.param("c", [], "int8", [1, 2, 3, 4], id="empty keys"),
+ ],
+)
+def test_build_output_tensor_descriptor_successful(order, keys, dtype, dimension):
+ built_descriptor = handler.build_output_tensor_descriptor(
+ order, keys, dtype, dimension
+ )
+ assert built_descriptor is not None
+ assert built_descriptor.order == order
+ assert len(built_descriptor.optionalKeys) == len(keys)
+ assert built_descriptor.optionalDatatype == dtype
+ for i, j in zip(built_descriptor.optionalDimension, dimension):
+ assert i == j
+
+
+@pytest.mark.parametrize(
+ "order, keys, dtype, dimension",
+ [
+ pytest.param("bad_order", [], "int8", [3, 2, 5], id="bad order type"),
+ pytest.param(
+ "f", [tensor_key], "bad_num_type", [3, 2, 5], id="bad numerical type"
+ ),
+ pytest.param("f", [tensor_key], "int8", "bad shape type", id="bad shape type"),
+ pytest.param("f", ["tensor_key"], "int8", [3, 2, 5], id="bad key type"),
+ ],
+)
+def test_build_output_tensor_descriptor_unsuccessful(order, keys, dtype, dimension):
+ with pytest.raises(ValueError):
+ built_tensor = handler.build_output_tensor_descriptor(
+ order, keys, dtype, dimension
+ )
diff --git a/tests/test_message_handler/test_request.py b/tests/test_message_handler/test_request.py
new file mode 100644
index 0000000000..a60818f7dd
--- /dev/null
+++ b/tests/test_message_handler/test_request.py
@@ -0,0 +1,449 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+fsd = "mock-feature-store-descriptor"
+
+model_key = MessageHandler.build_model_key("model_key", fsd)
+model = MessageHandler.build_model(b"model data", "model_name", "v0.0.1")
+
+input_key1 = MessageHandler.build_tensor_key("input_key1", fsd)
+input_key2 = MessageHandler.build_tensor_key("input_key2", fsd)
+
+output_key1 = MessageHandler.build_tensor_key("output_key1", fsd)
+output_key2 = MessageHandler.build_tensor_key("output_key2", fsd)
+
+output_descriptor1 = MessageHandler.build_output_tensor_descriptor(
+ "c", [output_key1, output_key2], "int64", []
+)
+output_descriptor2 = MessageHandler.build_output_tensor_descriptor("f", [], "auto", [])
+output_descriptor3 = MessageHandler.build_output_tensor_descriptor(
+ "c", [output_key1], "none", [1, 2, 3]
+)
+torch_attributes = MessageHandler.build_torch_request_attributes("sparse")
+tf_attributes = MessageHandler.build_tf_request_attributes(
+ name="tf", tensor_type="sparse"
+)
+
+tensor_1 = MessageHandler.build_tensor_descriptor("c", "int8", [1])
+tensor_2 = MessageHandler.build_tensor_descriptor("c", "int64", [3, 2])
+tensor_3 = MessageHandler.build_tensor_descriptor("f", "int8", [1])
+tensor_4 = MessageHandler.build_tensor_descriptor("f", "int64", [3, 2])
+
+
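+# The module-level requests below cover both input shapes exercised by these
+# tests: "indirect" requests reference inputs/outputs by feature-store key,
+# while "direct" requests embed tensor descriptors inline with no output keys.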
+tf_indirect_request = MessageHandler.build_request(
+ b"reply",
+ model,
+ [input_key1, input_key2],
+ [output_key1, output_key2],
+ [output_descriptor1, output_descriptor2, output_descriptor3],
+ tf_attributes,
+)
+
+tf_direct_request = MessageHandler.build_request(
+ b"reply",
+ model,
+ [tensor_3, tensor_4],
+ [],
+ [output_descriptor1, output_descriptor2],
+ tf_attributes,
+)
+
+torch_indirect_request = MessageHandler.build_request(
+ b"reply",
+ model,
+ [input_key1, input_key2],
+ [output_key1, output_key2],
+ [output_descriptor1, output_descriptor2, output_descriptor3],
+ torch_attributes,
+)
+
+torch_direct_request = MessageHandler.build_request(
+ b"reply",
+ model,
+ [tensor_1, tensor_2],
+ [],
+ [output_descriptor1, output_descriptor2],
+ torch_attributes,
+)
+
+
+@pytest.mark.parametrize(
+ "reply_channel, model, input, output, output_descriptors, custom_attributes",
+ [
+ pytest.param(
+ "reply channel",
+ model_key,
+ [input_key1, input_key2],
+ [output_key1, output_key2],
+ [output_descriptor1],
+ torch_attributes,
+ ),
+ pytest.param(
+ "another reply channel",
+ model,
+ [input_key1],
+ [output_key2],
+ [output_descriptor1],
+ tf_attributes,
+ ),
+ pytest.param(
+ "another reply channel",
+ model,
+ [input_key1],
+ [output_key2],
+ [output_descriptor1],
+ torch_attributes,
+ ),
+ pytest.param(
+ "reply channel",
+ model_key,
+ [input_key1],
+ [output_key1],
+ [output_descriptor1],
+ None,
+ ),
+ ],
+)
+def test_build_request_indirect_successful(
+ reply_channel, model, input, output, output_descriptors, custom_attributes
+):
+ built_request = MessageHandler.build_request(
+ reply_channel,
+ model,
+ input,
+ output,
+ output_descriptors,
+ custom_attributes,
+ )
+ assert built_request is not None
+ assert built_request.replyChannel.descriptor == reply_channel
+ if built_request.model.which() == "key":
+ assert built_request.model.key.key == model.key
+ else:
+ assert built_request.model.data.data == model.data
+ assert built_request.model.data.name == model.name
+ assert built_request.model.data.version == model.version
+ assert built_request.input.which() == "keys"
+ assert built_request.input.keys[0].key == input[0].key
+ assert len(built_request.input.keys) == len(input)
+ assert len(built_request.output) == len(output)
+ for i, j in zip(built_request.outputDescriptors, output_descriptors):
+ assert i.order == j.order
+ if built_request.customAttributes.which() == "tf":
+ assert (
+ built_request.customAttributes.tf.tensorType == custom_attributes.tensorType
+ )
+ elif built_request.customAttributes.which() == "torch":
+ assert (
+ built_request.customAttributes.torch.tensorType
+ == custom_attributes.tensorType
+ )
+ else:
+ assert built_request.customAttributes.none == custom_attributes
+
+
+@pytest.mark.parametrize(
+ "reply_channel, model, input, output, output_descriptors, custom_attributes",
+ [
+ pytest.param(
+ [],
+ model_key,
+ [input_key1, input_key2],
+ [output_key1, output_key2],
+ [output_descriptor1],
+ tf_attributes,
+ id="bad channel",
+ ),
+ pytest.param(
+ "reply channel",
+ "bad model",
+ [input_key1],
+ [output_key2],
+ [output_descriptor1],
+ torch_attributes,
+ id="bad model",
+ ),
+ pytest.param(
+ "reply channel",
+ model_key,
+ ["input_key1", "input_key2"],
+ [output_key1, output_key2],
+ [output_descriptor1],
+ tf_attributes,
+ id="bad inputs",
+ ),
+ pytest.param(
+ "reply channel",
+ model_key,
+ [torch_attributes],
+ [output_key1, output_key2],
+ [output_descriptor1],
+ torch_attributes,
+ id="bad input schema type",
+ ),
+ pytest.param(
+ "reply channel",
+ model_key,
+ [input_key1],
+ ["output_key1", "output_key2"],
+ [output_descriptor1],
+ tf_attributes,
+ id="bad outputs",
+ ),
+ pytest.param(
+ "reply channel",
+ model_key,
+ [input_key1],
+ [torch_attributes],
+ [output_descriptor1],
+ tf_attributes,
+ id="bad output schema type",
+ ),
+ pytest.param(
+ "reply channel",
+ model_key,
+ [input_key1],
+ [output_key1, output_key2],
+ [output_descriptor1],
+ "bad attributes",
+ id="bad custom attributes",
+ ),
+ pytest.param(
+ "reply channel",
+ model_key,
+ [input_key1],
+ [output_key1, output_key2],
+ [output_descriptor1],
+ model_key,
+ id="bad custom attributes schema type",
+ ),
+ pytest.param(
+ "reply channel",
+ model_key,
+ [input_key1],
+ [output_key1, output_key2],
+ "bad descriptors",
+ torch_attributes,
+ id="bad output descriptors",
+ ),
+ ],
+)
+def test_build_request_indirect_unsuccessful(
+ reply_channel, model, input, output, output_descriptors, custom_attributes
+):
+ with pytest.raises(ValueError):
+ built_request = MessageHandler.build_request(
+ reply_channel,
+ model,
+ input,
+ output,
+ output_descriptors,
+ custom_attributes,
+ )
+
+
+@pytest.mark.parametrize(
+ "reply_channel, model, input, output, output_descriptors, custom_attributes",
+ [
+ pytest.param(
+ "reply channel",
+ model_key,
+ [tensor_1, tensor_2],
+ [],
+ [output_descriptor2],
+ torch_attributes,
+ ),
+ pytest.param(
+ "another reply channel",
+ model,
+ [tensor_1],
+ [],
+ [output_descriptor3],
+ tf_attributes,
+ ),
+ pytest.param(
+ "another reply channel",
+ model,
+ [tensor_2],
+ [],
+ [output_descriptor1],
+ tf_attributes,
+ ),
+ pytest.param(
+ "another reply channel",
+ model,
+ [tensor_1],
+ [],
+ [output_descriptor1],
+ None,
+ ),
+ ],
+)
+def test_build_request_direct_successful(
+ reply_channel, model, input, output, output_descriptors, custom_attributes
+):
+ built_request = MessageHandler.build_request(
+ reply_channel,
+ model,
+ input,
+ output,
+ output_descriptors,
+ custom_attributes,
+ )
+ assert built_request is not None
+ assert built_request.replyChannel.descriptor == reply_channel
+ if built_request.model.which() == "key":
+ assert built_request.model.key.key == model.key
+ else:
+ assert built_request.model.data.data == model.data
+ assert built_request.model.data.name == model.name
+ assert built_request.model.data.version == model.version
+ assert built_request.input.which() == "descriptors"
+ assert len(built_request.input.descriptors) == len(input)
+ assert len(built_request.output) == len(output)
+ for i, j in zip(built_request.outputDescriptors, output_descriptors):
+ assert i.order == j.order
+ if built_request.customAttributes.which() == "tf":
+ assert (
+ built_request.customAttributes.tf.tensorType == custom_attributes.tensorType
+ )
+ elif built_request.customAttributes.which() == "torch":
+ assert (
+ built_request.customAttributes.torch.tensorType
+ == custom_attributes.tensorType
+ )
+ else:
+ assert built_request.customAttributes.none == custom_attributes
+
+
+@pytest.mark.parametrize(
+ "reply_channel, model, input, output, output_descriptors, custom_attributes",
+ [
+ pytest.param(
+ [],
+ model_key,
+ [tensor_3, tensor_4],
+ [],
+ [output_descriptor2],
+ tf_attributes,
+ id="bad channel",
+ ),
+ pytest.param(
+ b"reply channel",
+ "bad model",
+ [tensor_4],
+ [],
+ [output_descriptor2],
+ tf_attributes,
+ id="bad model",
+ ),
+ pytest.param(
+ b"reply channel",
+ model_key,
+ ["input_key1", "input_key2"],
+ [],
+ [output_descriptor2],
+ torch_attributes,
+ id="bad inputs",
+ ),
+ pytest.param(
+ b"reply channel",
+ model_key,
+ [],
+ ["output_key1", "output_key2"],
+ [output_descriptor2],
+ tf_attributes,
+ id="bad outputs",
+ ),
+ pytest.param(
+ b"reply channel",
+ model_key,
+ [tensor_4],
+ [],
+ [output_descriptor2],
+ "bad attributes",
+ id="bad custom attributes",
+ ),
+ pytest.param(
+ b"reply_channel",
+ model_key,
+ [tensor_3, tensor_4],
+ [],
+ ["output_descriptor2"],
+ torch_attributes,
+ id="bad output descriptors",
+ ),
+ ],
+)
+def test_build_request_direct_unsuccessful(
+ reply_channel, model, input, output, output_descriptors, custom_attributes
+):
+ with pytest.raises(ValueError):
+ built_request = MessageHandler.build_request(
+ reply_channel,
+ model,
+ input,
+ output,
+ output_descriptors,
+ custom_attributes,
+ )
+
+
+@pytest.mark.parametrize(
+ "req",
+ [
+ pytest.param(tf_indirect_request, id="tf indirect"),
+ pytest.param(tf_direct_request, id="tf direct"),
+ pytest.param(torch_indirect_request, id="indirect"),
+ pytest.param(torch_direct_request, id="direct"),
+ ],
+)
+def test_serialize_request_successful(req):
+ serialized = MessageHandler.serialize_request(req)
+ assert type(serialized) == bytes
+
+ deserialized = MessageHandler.deserialize_request(serialized)
+ assert deserialized.to_dict() == req.to_dict()
+
+
+def test_serialization_fails():
+ with pytest.raises(ValueError):
+ bad_request = MessageHandler.serialize_request(tensor_1)
+
+
+def test_deserialization_fails():
+ with pytest.raises(ValueError):
+ new_req = torch_direct_request.copy()
+ req_bytes = MessageHandler.serialize_request(new_req)
+ req_bytes = req_bytes + b"extra bytes"
+ deser = MessageHandler.deserialize_request(req_bytes)
diff --git a/tests/test_message_handler/test_response.py b/tests/test_message_handler/test_response.py
new file mode 100644
index 0000000000..86774132ec
--- /dev/null
+++ b/tests/test_message_handler/test_response.py
@@ -0,0 +1,191 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from smartsim._core.mli.message_handler import MessageHandler
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+fsd = "mock-feature-store-descriptor"
+
+result_key1 = MessageHandler.build_tensor_key("result_key1", fsd)
+result_key2 = MessageHandler.build_tensor_key("result_key2", fsd)
+
+torch_attributes = MessageHandler.build_torch_response_attributes()
+tf_attributes = MessageHandler.build_tf_response_attributes()
+
+tensor1 = MessageHandler.build_tensor_descriptor("c", "int8", [1])
+tensor2 = MessageHandler.build_tensor_descriptor("c", "int64", [3, 2])
+
+
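+# Mirroring the request tests, responses carry either tensor keys ("indirect")
+# or tensor descriptors ("direct") as their result payload.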
+tf_indirect_response = MessageHandler.build_response(
+ "complete",
+ "Success!",
+ [result_key1, result_key2],
+ tf_attributes,
+)
+
+tf_direct_response = MessageHandler.build_response(
+ "complete",
+ "Success again!",
+ [tensor2, tensor1],
+ tf_attributes,
+)
+
+torch_indirect_response = MessageHandler.build_response(
+ "complete",
+ "Success!",
+ [result_key1, result_key2],
+ torch_attributes,
+)
+
+torch_direct_response = MessageHandler.build_response(
+ "complete",
+ "Success again!",
+ [tensor1, tensor2],
+ torch_attributes,
+)
+
+
+@pytest.mark.parametrize(
+ "status, status_message, result, custom_attribute",
+ [
+ pytest.param(
+ 200,
+ "Yay, it worked!",
+ [tensor1, tensor2],
+ None,
+ id="tensor descriptor list",
+ ),
+ pytest.param(
+ 200,
+ "Yay, it worked!",
+ [result_key1, result_key2],
+ tf_attributes,
+ id="tensor key list",
+ ),
+ ],
+)
+def test_build_response_successful(status, status_message, result, custom_attribute):
+ response = MessageHandler.build_response(
+ status=status,
+ message=status_message,
+ result=result,
+ custom_attributes=custom_attribute,
+ )
+ assert response is not None
+ assert response.status == status
+ assert response.message == status_message
+ if response.result.which() == "keys":
+ assert response.result.keys[0].to_dict() == result[0].to_dict()
+ else:
+ assert response.result.descriptors[0].to_dict() == result[0].to_dict()
+
+
+@pytest.mark.parametrize(
+ "status, status_message, result, custom_attribute",
+ [
+ pytest.param(
+ "bad status",
+ "Yay, it worked!",
+ [tensor1, tensor2],
+ None,
+ id="bad status",
+ ),
+ pytest.param(
+ "complete",
+ 200,
+ [tensor2],
+ torch_attributes,
+ id="bad status message",
+ ),
+ pytest.param(
+ "complete",
+ "Yay, it worked!",
+ ["result_key1", "result_key2"],
+ tf_attributes,
+ id="bad result",
+ ),
+ pytest.param(
+ "complete",
+ "Yay, it worked!",
+ [tf_attributes],
+ tf_attributes,
+ id="bad result type",
+ ),
+ pytest.param(
+ "complete",
+ "Yay, it worked!",
+ [tensor2, tensor1],
+ "custom attributes",
+ id="bad custom attributes",
+ ),
+ pytest.param(
+ "complete",
+ "Yay, it worked!",
+ [tensor2, tensor1],
+ result_key1,
+ id="bad custom attributes type",
+ ),
+ ],
+)
+def test_build_response_unsuccessful(status, status_message, result, custom_attribute):
+ with pytest.raises(ValueError):
+ response = MessageHandler.build_response(
+ status, status_message, result, custom_attribute
+ )
+
+
+@pytest.mark.parametrize(
+ "response",
+ [
+ pytest.param(torch_indirect_response, id="indirect"),
+ pytest.param(torch_direct_response, id="direct"),
+ pytest.param(tf_indirect_response, id="tf indirect"),
+ pytest.param(tf_direct_response, id="tf direct"),
+ ],
+)
+def test_serialize_response(response):
+ serialized = MessageHandler.serialize_response(response)
+ assert type(serialized) == bytes
+
+ deserialized = MessageHandler.deserialize_response(serialized)
+ assert deserialized.to_dict() == response.to_dict()
+
+
+def test_serialization_fails():
+ with pytest.raises(ValueError):
+ bad_response = MessageHandler.serialize_response(result_key1)
+
+
+def test_deserialization_fails():
+ with pytest.raises(ValueError):
+ new_resp = torch_direct_response.copy()
+ resp_bytes = MessageHandler.serialize_response(new_resp)
+ resp_bytes = resp_bytes + b"extra bytes"
+ deser = MessageHandler.deserialize_response(resp_bytes)
diff --git a/tests/test_node_prioritizer.py b/tests/test_node_prioritizer.py
new file mode 100644
index 0000000000..905c0ecc90
--- /dev/null
+++ b/tests/test_node_prioritizer.py
@@ -0,0 +1,553 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import random
+import threading
+import typing as t
+
+import pytest
+
+from smartsim._core.launcher.dragon.pqueue import NodePrioritizer, PrioritizerFilter
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+# The tests in this file belong to the group_b group
+pytestmark = pytest.mark.group_b
+
+
+logger = get_logger(__name__)
+
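+# These tests exercise NodePrioritizer's reference counting: increment and
+# decrement adjust a host's ref count directly, while next/next_n return only
+# unassigned ("clean") hosts, optionally restricted by a PrioritizerFilter
+# (CPU/GPU) or an explicit host list.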
+
+class MockNode:
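+    """Minimal stand-in for a compute node, exposing only the attributes
+    (hostname, num_cpus, num_gpus) that the prioritizer tests rely on."""
+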
+ def __init__(self, hostname: str, num_cpus: int, num_gpus: int) -> None:
+ self.hostname = hostname
+ self.num_cpus = num_cpus
+ self.num_gpus = num_gpus
+
+
+def mock_node_hosts(
+ num_cpu_nodes: int, num_gpu_nodes: int
+) -> t.Tuple[t.List[str], t.List[str]]:
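+    """Generate parallel lists of CPU-only and GPU hostnames for a mock cluster."""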
+ cpu_hosts = [f"cpu-node-{i}" for i in range(num_cpu_nodes)]
+ gpu_hosts = [f"gpu-node-{i}" for i in range(num_gpu_nodes)]
+
+ return cpu_hosts, gpu_hosts
+
+
+def mock_node_builder(num_cpu_nodes: int, num_gpu_nodes: int) -> t.List[MockNode]:
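+    """Create MockNode instances: CPU-only hosts with 4 CPUs / 0 GPUs followed
+    by GPU hosts with 4 CPUs / 4 GPUs."""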
+ nodes = []
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+
+ nodes.extend(MockNode(hostname, 4, 0) for hostname in cpu_hosts)
+ nodes.extend(MockNode(hostname, 4, 4) for hostname in gpu_hosts)
+
+ return nodes
+
+
+def test_node_prioritizer_init_null() -> None:
+ """Verify that the priorizer reports failures to send a valid node set
+ if a null value is passed"""
+ lock = threading.RLock()
+ with pytest.raises(SmartSimError) as ex:
+ NodePrioritizer(None, lock)
+
+ assert "Missing" in ex.value.args[0]
+
+
+def test_node_prioritizer_init_empty() -> None:
+ """Verify that the priorizer reports failures to send a valid node set
+ if an empty list is passed"""
+ lock = threading.RLock()
+ with pytest.raises(SmartSimError) as ex:
+ NodePrioritizer([], lock)
+
+ assert "Missing" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+ "num_cpu_nodes,num_gpu_nodes", [(1, 1), (2, 1), (1, 2), (8, 4), (1000, 200)]
+)
+def test_node_prioritizer_init_ok(num_cpu_nodes: int, num_gpu_nodes: int) -> None:
+ """Verify that initialization with a valid node list results in the
+ appropriate cpu & gpu ref counts, and complete ref map"""
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ # perform prioritizer initialization
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # get a copy of all the expected host names
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ all_hosts = cpu_hosts + gpu_hosts
+ assert len(all_hosts) == num_cpu_nodes + num_gpu_nodes
+
+ # verify tracking data is initialized correctly for all nodes
+ for hostname in all_hosts:
+ # show that the ref map is tracking the node
+ assert hostname in p._nodes
+
+ tracking_info = p.get_tracking_info(hostname)
+
+ # show that the node is created w/zero ref counts
+ assert tracking_info.num_refs == 0
+
+ # show that the node is created and marked as not dirty (unchanged)
+ # assert tracking_info.is_dirty == False
+
+ # iterate through known cpu node keys and verify prioritizer initialization
+ for hostname in cpu_hosts:
+ # show that the device ref counters are appropriately assigned
+ cpu_ref = next((n for n in p._cpu_refs if n.hostname == hostname), None)
+ assert cpu_ref, "CPU-only node not found in cpu ref set"
+
+ gpu_ref = next((n for n in p._gpu_refs if n.hostname == hostname), None)
+ assert not gpu_ref, "CPU-only node should not be found in gpu ref set"
+
+ # iterate through known GPU node keys and verify prioritizer initialization
+ for hostname in gpu_hosts:
+ # show that the device ref counters are appropriately assigned
+ gpu_ref = next((n for n in p._gpu_refs if n.hostname == hostname), None)
+ assert gpu_ref, "GPU-only node not found in gpu ref set"
+
+ cpu_ref = next((n for n in p._cpu_refs if n.hostname == hostname), None)
+ assert not cpu_ref, "GPU-only node should not be found in cpu ref set"
+
+ # verify we have all hosts in the ref map
+ assert set(p._nodes.keys()) == set(all_hosts)
+
+ # verify we have no extra hosts in ref map
+ assert len(p._nodes.keys()) == len(set(all_hosts))
+
+
+def test_node_prioritizer_direct_increment() -> None:
+ """Verify that performing the increment operation causes the expected
+ side effect on the intended records"""
+
+ num_cpu_nodes, num_gpu_nodes = 32, 8
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ exclude_index = 2
+ exclude_host0 = cpu_hosts[exclude_index]
+ exclude_host1 = gpu_hosts[exclude_index]
+ exclusions = [exclude_host0, exclude_host1]
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # let's increment each element in a predictable way and verify
+ for node in nodes:
+ if node.hostname in exclusions:
+ # expect 1 cpu and 1 gpu node at zero and not incremented
+ continue
+
+ if node.num_gpus == 0:
+ num_increments = random.randint(0, num_cpu_nodes - 1)
+ else:
+ num_increments = random.randint(0, num_gpu_nodes - 1)
+
+ # increment this node some random number of times
+ for _ in range(num_increments):
+ p.increment(node.hostname)
+
+ # ... and verify the correct incrementing is applied
+ tracking_info = p.get_tracking_info(node.hostname)
+ assert tracking_info.num_refs == num_increments
+
+ # verify the excluded cpu node was never changed
+ tracking_info0 = p.get_tracking_info(exclude_host0)
+ assert tracking_info0.num_refs == 0
+
+ # verify the excluded gpu node was never changed
+ tracking_info1 = p.get_tracking_info(exclude_host1)
+ assert tracking_info1.num_refs == 0
+
+
+def test_node_prioritizer_indirect_increment() -> None:
+ """Verify that performing the increment operation indirectly affects
+ each available node until we run out of nodes to return"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 0
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # verify starting state
+ for node in p._nodes.values():
+ tracking_info = p.get_tracking_info(node.hostname)
+
+ assert node.num_refs == 0 # <--- ref count starts at zero
+ assert tracking_info.num_refs == 0 # <--- ref count starts at zero
+
+    # perform indirect increments by retrieving each node via `next`
+ for node in p._nodes.values():
+ tracking_info = p.get_tracking_info(node.hostname)
+
+ # apply `next` operation and verify tracking info reflects new ref
+ node = p.next(PrioritizerFilter.CPU)
+ tracking_info = p.get_tracking_info(node.hostname)
+
+ # verify side-effects
+ assert tracking_info.num_refs > 0 # <--- ref count should now be > 0
+
+ # we expect it to give back only "clean" nodes from next*
+ assert tracking_info.is_dirty == False # NOTE: this is "hidden" by protocol
+
+ # every node should be incremented now. prioritizer shouldn't have anything to give
+ tracking_info = p.next(PrioritizerFilter.CPU)
+ assert tracking_info is None # <--- get_next shouldn't have any nodes to give
+
+
+def test_node_prioritizer_indirect_decrement_availability() -> None:
+ """Verify that a node who is decremented (dirty) is made assignable
+ on a subsequent request"""
+
+ num_cpu_nodes, num_gpu_nodes = 1, 0
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # increment our only node...
+ p.increment(cpu_hosts[0])
+
+ tracking_info = p.next()
+ assert tracking_info is None, "No nodes should be assignable"
+
+ # perform a decrement...
+ p.decrement(cpu_hosts[0])
+
+ # ... and confirm that the node is available again
+ tracking_info = p.next()
+ assert tracking_info is not None, "A node should be assignable"
+
+
+def test_node_prioritizer_multi_increment() -> None:
+ """Verify that retrieving multiple nodes via `next_n` API correctly
+ increments reference counts and returns appropriate results"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 0
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # Mark some nodes as dirty to verify retrieval
+ p.increment(cpu_hosts[0])
+ assert p.get_tracking_info(cpu_hosts[0]).num_refs > 0
+
+ p.increment(cpu_hosts[2])
+ assert p.get_tracking_info(cpu_hosts[2]).num_refs > 0
+
+ p.increment(cpu_hosts[4])
+ assert p.get_tracking_info(cpu_hosts[4]).num_refs > 0
+
+ # use next_n w/the minimum allowed value
+ all_tracking_info = p.next_n(1, PrioritizerFilter.CPU) # <---- next_n(1)
+
+ # confirm the number requested is honored
+ assert len(all_tracking_info) == 1
+ # ensure no unavailable node is returned
+ assert all_tracking_info[0].hostname not in [
+ cpu_hosts[0],
+ cpu_hosts[2],
+ cpu_hosts[4],
+ ]
+
+ # use next_n w/value that exceeds available number of open nodes
+ # 3 direct increments in setup, 1 out of next_n(1), 4 left
+ all_tracking_info = p.next_n(5, PrioritizerFilter.CPU)
+
+ # confirm that no nodes are returned, even though 4 out of 5 requested are available
+ assert len(all_tracking_info) == 0
+
+
+def test_node_prioritizer_multi_increment_validate_n() -> None:
+ """Verify that retrieving multiple nodes via `next_n` API correctly
+ reports failures when the request size is above pool size"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 0
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # we have 8 total cpu nodes available... request too many nodes
+ all_tracking_info = p.next_n(9, PrioritizerFilter.CPU)
+ assert len(all_tracking_info) == 0
+
+ all_tracking_info = p.next_n(num_cpu_nodes * 1000, PrioritizerFilter.CPU)
+ assert len(all_tracking_info) == 0
+
+
+def test_node_prioritizer_indirect_direct_interleaved_increments() -> None:
+ """Verify that interleaving indirect and direct increments results in
+ expected ref counts"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 4
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # perform some set of non-popped increments
+ p.increment(gpu_hosts[1])
+ p.increment(gpu_hosts[3])
+ p.increment(gpu_hosts[3])
+
+ # increment 0th item 1x
+ p.increment(cpu_hosts[0])
+
+    # increment 3rd item 2x
+ p.increment(cpu_hosts[3])
+ p.increment(cpu_hosts[3])
+
+ # increment last item 3x
+ p.increment(cpu_hosts[7])
+ p.increment(cpu_hosts[7])
+ p.increment(cpu_hosts[7])
+
+ tracking_info = p.get_tracking_info(gpu_hosts[1])
+ assert tracking_info.num_refs == 1
+
+ tracking_info = p.get_tracking_info(gpu_hosts[3])
+ assert tracking_info.num_refs == 2
+
+ nodes = [n for n in p._nodes.values() if n.num_refs == 0 and n.num_gpus == 0]
+
+ # we should skip the 0-th item in the heap due to direct increment
+ tracking_info = p.next(PrioritizerFilter.CPU)
+ assert tracking_info.num_refs == 1
+ # confirm we get a cpu node
+ assert "cpu-node" in tracking_info.hostname
+
+ # this should pull the next item right out
+ tracking_info = p.next(PrioritizerFilter.CPU)
+ assert tracking_info.num_refs == 1
+ assert "cpu-node" in tracking_info.hostname
+
+ # ensure we pull from gpu nodes and the 0th item is returned
+ tracking_info = p.next(PrioritizerFilter.GPU)
+ assert tracking_info.num_refs == 1
+ assert "gpu-node" in tracking_info.hostname
+
+    # we should step over the 3rd node on this iteration
+ tracking_info = p.next(PrioritizerFilter.CPU)
+ assert tracking_info.num_refs == 1
+ assert "cpu-node" in tracking_info.hostname
+
+ # and ensure that heap also steps over a direct increment
+ tracking_info = p.next(PrioritizerFilter.GPU)
+ assert tracking_info.num_refs == 1
+ assert "gpu-node" in tracking_info.hostname
+
+ # and another GPU request should return nothing
+ tracking_info = p.next(PrioritizerFilter.GPU)
+ assert tracking_info is None
+
+
+def test_node_prioritizer_decrement_floor() -> None:
+ """Verify that repeatedly decrementing ref counts does not
+ allow negative ref counts"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 4
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # try a ton of decrements on all the items in the prioritizer
+ for _ in range(len(nodes) * 100):
+ index = random.randint(0, num_cpu_nodes - 1)
+ p.decrement(cpu_hosts[index])
+
+ index = random.randint(0, num_gpu_nodes - 1)
+ p.decrement(gpu_hosts[index])
+
+ for node in nodes:
+ tracking_info = p.get_tracking_info(node.hostname)
+ assert tracking_info.num_refs == 0
+
+
+@pytest.mark.parametrize("num_requested", [1, 2, 3])
+def test_node_prioritizer_multi_increment_subheap(num_requested: int) -> None:
+ """Verify that retrieving multiple nodes via `next_n` API correctly
+ increments reference counts and returns appropriate results
+ when requesting an in-bounds number of nodes"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 0
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # Mark some nodes as dirty to verify retrieval
+ p.increment(cpu_hosts[0])
+ p.increment(cpu_hosts[2])
+ p.increment(cpu_hosts[4])
+
+ hostnames = [cpu_hosts[0], cpu_hosts[1], cpu_hosts[2], cpu_hosts[3], cpu_hosts[5]]
+
+ # request n == {num_requested} nodes from set of 3 available
+ all_tracking_info = p.next_n(
+ num_requested,
+ hosts=hostnames,
+ ) # <---- w/0,2,4 assigned, only 1,3,5 from hostnames can work
+
+ # all parameterizations should result in a matching output size
+ assert len(all_tracking_info) == num_requested
+
+
+def test_node_prioritizer_multi_increment_subheap_assigned() -> None:
+ """Verify that retrieving multiple nodes via `next_n` API does
+ not return anything when the number requested cannot be satisfied
+ by the given subheap due to prior assignment"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 0
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # Mark some nodes as dirty to verify retrieval
+ p.increment(cpu_hosts[0])
+ p.increment(cpu_hosts[2])
+
+ hostnames = [
+ cpu_hosts[0],
+ "x" + cpu_hosts[2],
+ ] # <--- we can't get 2 from 1 valid node name
+
+    # request 2 nodes from a host list that cannot satisfy the request
+ num_requested = 2
+ all_tracking_info = p.next_n(num_requested, hosts=hostnames)
+
+ # w/0,2 assigned, nothing can be returned
+ assert len(all_tracking_info) == 0
+
+
+def test_node_prioritizer_empty_subheap_next_w_no_hosts() -> None:
+ """Verify that retrieving multiple nodes via `next_n` API does
+ with an empty host list uses the entire available host list"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 0
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # Mark some nodes as dirty to verify retrieval
+ p.increment(cpu_hosts[0])
+ p.increment(cpu_hosts[2])
+
+    hostnames = []
+
+    # request a node with no host filter; the full pool should be used
+    node = p.next(hosts=hostnames)
+    assert node
+
+
+def test_node_prioritizer_empty_subheap_next_n_w_hosts() -> None:
+ """Verify that retrieving multiple nodes via `next_n` API does
+ not blow up with an empty host list"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 0
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # Mark some nodes as dirty to verify retrieval
+ p.increment(cpu_hosts[0])
+ p.increment(cpu_hosts[2])
+
+ hostnames = []
+
+    # request a single node via next_n with an empty host filter
+ num_requested = 1
+ node = p.next_n(num_requested, hosts=hostnames)
+ assert node is not None
+
+
+@pytest.mark.parametrize("num_requested", [-100, -1, 0])
+def test_node_prioritizer_empty_subheap_next_n(num_requested: int) -> None:
+ """Verify that retrieving a node via `next_n` API does
+ not allow a request with num_items < 1"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 0
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # Mark some nodes as dirty to verify retrieval
+ p.increment(cpu_hosts[0])
+ p.increment(cpu_hosts[2])
+
+ # request n == {num_requested} nodes from set of 3 available
+ with pytest.raises(ValueError) as ex:
+ p.next_n(num_requested)
+
+ assert "Number of items requested" in ex.value.args[0]
+
+
+@pytest.mark.parametrize("num_requested", [-100, -1, 0])
+def test_node_prioritizer_subheap_next_n_invalid_count(num_requested: int) -> None:
+    """Verify that retrieving multiple nodes via `next_n` API does
+    not allow a request with num_items < 1 when a host subset is supplied"""
+
+ num_cpu_nodes, num_gpu_nodes = 8, 0
+ cpu_hosts, gpu_hosts = mock_node_hosts(num_cpu_nodes, num_gpu_nodes)
+ nodes = mock_node_builder(num_cpu_nodes, num_gpu_nodes)
+
+ lock = threading.RLock()
+ p = NodePrioritizer(nodes, lock)
+
+ # Mark some nodes as dirty to verify retrieval
+ p.increment(cpu_hosts[0])
+ p.increment(cpu_hosts[2])
+
+ hostnames = [cpu_hosts[0], cpu_hosts[2]]
+
+ # request n == {num_requested} nodes from set of 3 available
+ with pytest.raises(ValueError) as ex:
+ p.next_n(num_requested, hosts=hostnames)
+
+ assert "Number of items requested" in ex.value.args[0]